rocket_codegen/attribute/entry/launch.rs

1use devise::{Spanned, Result};
2use devise::ext::SpanDiagnosticExt;
3use proc_macro2::{TokenStream, Span};
4
5use super::EntryAttr;
6use crate::exports::mixed;
7
/// `#[rocket::launch]`: generates a `main` function that evaluates the
/// attributed function's body to produce a `Rocket<Build>` instance, then calls
/// `.launch()` on that instance inside of a `rocket::async_main`.
///
/// The attributed function itself is kept in the output, marked
/// `#[allow(dead_code)]`.
pub struct Launch;
12
13/// Determines if `f` likely spawns an async task, returning the spawn call.
14fn likely_spawns(f: &syn::ItemFn) -> Option<&syn::ExprCall> {
15    use syn::visit::{self, Visit};
16
17    struct SpawnFinder<'a>(Option<&'a syn::ExprCall>);
18
19    impl<'ast> Visit<'ast> for SpawnFinder<'ast> {
20        fn visit_expr_call(&mut self, i: &'ast syn::ExprCall) {
21            if self.0.is_some() {
22                return;
23            }
24
25            if let syn::Expr::Path(ref e) = *i.func {
26                let mut segments = e.path.segments.clone();
27                if let Some(last) = segments.pop() {
28                    if last.value().ident != "spawn" {
29                        return visit::visit_expr_call(self, i);
30                    }
31
32                    if let Some(prefix) = segments.pop() {
33                        if prefix.value().ident == "tokio" {
34                            self.0 = Some(i);
35                            return;
36                        }
37                    }
38
39                    if let Some(syn::Expr::Async(_)) = i.args.first() {
40                        self.0 = Some(i);
41                        return;
42                    }
43                }
44            };
45
46            visit::visit_expr_call(self, i);
47        }
48    }
49
50    let mut v = SpawnFinder(None);
51    v.visit_item_fn(f);
52    v.0
53}
54
55impl EntryAttr for Launch {
56    const REQUIRES_ASYNC: bool = false;
57
58    fn function(f: &mut syn::ItemFn) -> Result<TokenStream> {
59        if f.sig.ident == "main" {
60            return Err(Span::call_site()
61                .error("attribute cannot be applied to `main` function")
62                .note("this attribute generates a `main` function")
63                .span_note(f.sig.ident.span(), "this function cannot be `main`"));
64        }
65
66        // Always infer the type as `Rocket<Build>`.
67        if let syn::ReturnType::Type(_, ref mut ty) = &mut f.sig.output {
68            if let syn::Type::Infer(_) = &mut **ty {
69                let new = quote_spanned!(ty.span() => ::rocket::Rocket<::rocket::Build>);
70                *ty = syn::parse2(new).expect("path is type");
71            }
72        }
73
74        let ty = match &f.sig.output {
75            syn::ReturnType::Type(_, ty) => ty,
76            _ => return Err(Span::call_site()
77                .error("attribute can only be applied to functions that return a value")
78                .span_note(f.sig.span(), "this function must return a value"))
79        };
80
81        let block = &f.block;
82        let rocket = quote_spanned!(mixed(ty.span()) => {
83            let ___rocket: #ty = #block;
84            let ___rocket: ::rocket::Rocket<::rocket::Build> = ___rocket;
85            ___rocket
86        });
87
88        let launch = match f.sig.asyncness {
89            Some(_) => quote_spanned!(ty.span() => async move { #rocket.launch().await }),
90            None => quote_spanned!(ty.span() => #rocket.launch()),
91        };
92
93        if f.sig.asyncness.is_none() {
94            if let Some(call) = likely_spawns(f) {
95                call.span()
96                    .warning("task is being spawned outside an async context")
97                    .span_help(f.sig.span(), "declare this function as `async fn` \
98                                              to require async execution")
99                    .span_note(Span::call_site(), "`#[launch]` call is here")
100                    .emit_as_expr_tokens();
101            }
102        }
103
104        let (vis, mut sig) = (&f.vis, f.sig.clone());
105        sig.ident = syn::Ident::new("main", sig.ident.span());
106        sig.output = syn::ReturnType::Default;
107        sig.asyncness = None;
108
109        Ok(quote_spanned!(block.span() =>
110            #[allow(dead_code)] #f
111
112            #vis #sig {
113                let _ = ::rocket::async_main(#launch);
114            }
115        ))
116    }
117}