// rocket_codegen/attribute/route/mod.rs

1mod parse;
2
3use std::hash::Hash;
4
5use devise::{Spanned, SpanWrapped, Result, FromMeta, Diagnostic};
6use devise::ext::TypeExt as _;
7use proc_macro2::{TokenStream, Span};
8
9use crate::proc_macro_ext::StringLit;
10use crate::syn_ext::{IdentExt, TypeExt as _};
11use crate::http_codegen::{Method, Optional};
12use crate::attribute::param::Guard;
13use crate::exports::mixed;
14
15use self::parse::{Route, Attribute, MethodAttribute};
16
17impl Route {
18    pub fn guards(&self) -> impl Iterator<Item = &Guard> {
19        self.param_guards()
20            .chain(self.query_guards())
21            .chain(self.data_guard.iter())
22            .chain(self.request_guards.iter())
23    }
24
25    pub fn param_guards(&self) -> impl Iterator<Item = &Guard> {
26        self.path_params.iter().filter_map(|p| p.guard())
27    }
28
29    pub fn query_guards(&self) -> impl Iterator<Item = &Guard> {
30        self.query_params.iter().filter_map(|p| p.guard())
31    }
32}
33
/// Generates the declaration that binds every query-parameter guard of
/// `route`, or returns `None` when the route declares no query parameters.
///
/// The emitted statement has the form `let (ident, ...) = { ... };`. At
/// request time it:
///
///   1. initializes a `FromForm` context for each query guard using
///      `Options::Lenient`;
///   2. walks `__req.query_fields()`. Any field whose raw `(name, value)`
///      equals one of the route's static query parameters is skipped. Every
///      other field goes to the first guard whose pattern matches the field's
///      key. A trailing (`<param..>`) guard's pattern is `_`, so it receives
///      any field not claimed by an earlier arm;
///   3. finalizes each context, attaching the guard's name to any errors;
///   4. if any errors were collected, logs them as warnings and returns
///      `Outcome::Forward` with `Status::UnprocessableEntity`; otherwise it
///      binds the finalized values.
fn query_decls(route: &Route) -> Option<TokenStream> {
    use devise::ext::{Split2, Split6};

    // NOTE(review): the second clause is implied by the first, because
    // `query_guards()` only filters `query_params`. Also note that a route
    // whose query parameters are all static still gets a declaration; it
    // binds nothing (`let () = { ... };`).
    if route.query_params.is_empty() && route.query_guards().next().is_none() {
        return None;
    }

    define_spanned_export!(Span::call_site() =>
        __req, __data, _log, _form, Outcome, _Ok, _Err, _Some, _None, Status
    );

    // Record all of the static parameters for later filtering. A static
    // parameter `name=value` is split at its first `=`; one without `=` is
    // recorded with an empty value.
    let (raw_name, raw_value) = route.query_params.iter()
        .filter_map(|s| s.r#static())
        .map(|name| match name.find('=') {
            Some(i) => (&name[..i], &name[i + 1..]),
            None => (name.as_str(), "")
        })
        .split2();

    // Now record all of the dynamic parameters. For each query guard, build
    // its match pattern, its binding identifier, and the `FromForm`
    // `init` / `push_value` / `finalize` expressions.
    let (name, matcher, ident, init_expr, push_expr, finalize_expr) = route.query_guards()
        .map(|guard| {
            let (name, ty) = (&guard.name, &guard.ty);
            let ident = guard.fn_ident.rocketized().with_span(ty.span());
            // A trailing guard matches any key; others match their own name.
            let matcher = match guard.trailing {
                true => quote_spanned!(name.span() => _),
                _ => quote!(#name)
            };

            define_spanned_export!(ty.span() => FromForm, _form);

            // NOTE(review): `ty` is rebound here to the `<T as FromForm>`
            // token stream, so the `ty.span()` calls below take the span of
            // that stream rather than of the original type. That stream is
            // built with the type's span, so the two are presumably
            // equivalent — confirm.
            let ty = quote_spanned!(ty.span() => <#ty as #FromForm>);
            let init = quote_spanned!(ty.span() => #ty::init(#_form::Options::Lenient));
            let finalize = quote_spanned!(ty.span() => #ty::finalize(#ident));
            // A trailing guard receives the field as-is. Other guards receive
            // `_f.shift()`, presumably the field with its leading (matched)
            // key removed — verify against `ValueField::shift`.
            let push = match guard.trailing {
                true => quote_spanned!(ty.span() => #ty::push_value(&mut #ident, _f)),
                _ => quote_spanned!(ty.span() => #ty::push_value(&mut #ident, _f.shift())),
            };

            (name, matcher, ident, init, push, finalize)
        })
        .split6();

    #[allow(non_snake_case)]
    Some(quote! {
        let (#(#ident),*) = {
            let mut __e = #_form::Errors::new();
            #(let mut #ident = #init_expr;)*

            for _f in #__req.query_fields() {
                let _raw = (_f.name.source().as_str(), _f.value);
                let _key = _f.name.key_lossy().as_str();
                match (_raw, _key) {
                    // Skip static parameters so <param..> doesn't see them.
                    #(((#raw_name, #raw_value), _) => { /* skip */ },)*
                    #((_, #matcher) => #push_expr,)*
                    _ => { /* in case we have no trailing, ignore all else */ },
                }
            }

            // Finalize every guard, collecting (rather than returning on) the
            // first error so that all failures are reported together.
            #(
                let #ident = match #finalize_expr {
                    #_Ok(_v) => #_Some(_v),
                    #_Err(_err) => {
                        __e.extend(_err.with_name(#_form::NameView::new(#name)));
                        #_None
                    },
                };
            )*

            if !__e.is_empty() {
                #_log::warn_!("Query string failed to match route declaration.");
                for _err in __e { #_log::warn_!("{}", _err); }
                return #Outcome::Forward((#__data, #Status::UnprocessableEntity));
            }

            // No errors were recorded, so every finalized value is `Some`.
            (#(#ident.unwrap()),*)
        };
    })
}
115
116fn request_guard_decl(guard: &Guard) -> TokenStream {
117    let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
118    define_spanned_export!(ty.span() =>
119        __req, __data, _request, _log, FromRequest, Outcome
120    );
121
122    quote_spanned! { ty.span() =>
123        let #ident: #ty = match <#ty as #FromRequest>::from_request(#__req).await {
124            #Outcome::Success(__v) => __v,
125            #Outcome::Forward(__e) => {
126                #_log::warn_!("Request guard `{}` is forwarding.", stringify!(#ty));
127                return #Outcome::Forward((#__data, __e));
128            },
129            #Outcome::Error((__c, __e)) => {
130                #_log::warn_!("Request guard `{}` failed: {:?}.", stringify!(#ty), __e);
131                return #Outcome::Error(__c);
132            }
133        };
134    }
135}
136
137fn param_guard_decl(guard: &Guard) -> TokenStream {
138    let (i, name, ty) = (guard.index, &guard.name, &guard.ty);
139    define_spanned_export!(ty.span() =>
140        __req, __data, _log, _None, _Some, _Ok, _Err,
141        Outcome, FromSegments, FromParam, Status
142    );
143
144    // Returned when a dynamic parameter fails to parse.
145    let parse_error = quote!({
146        #_log::warn_!("Parameter guard `{}: {}` is forwarding: {:?}.",
147            #name, stringify!(#ty), __error);
148
149        #Outcome::Forward((#__data, #Status::UnprocessableEntity))
150    });
151
152    // All dynamic parameters should be found if this function is being called;
153    // that's the point of statically checking the URI parameters.
154    let expr = match guard.trailing {
155        false => quote_spanned! { ty.span() =>
156            match #__req.routed_segment(#i) {
157                #_Some(__s) => match <#ty as #FromParam>::from_param(__s) {
158                    #_Ok(__v) => __v,
159                    #_Err(__error) => return #parse_error,
160                },
161                #_None => {
162                    #_log::error_!("Internal invariant broken: dyn param not found.");
163                    #_log::error_!("Please report this to the Rocket issue tracker.");
164                    #_log::error_!("https://github.com/rwf2/Rocket/issues");
165                    return #Outcome::Forward((#__data, #Status::InternalServerError));
166                }
167            }
168        },
169        true => quote_spanned! { ty.span() =>
170            match <#ty as #FromSegments>::from_segments(#__req.routed_segments(#i..)) {
171                #_Ok(__v) => __v,
172                #_Err(__error) => return #parse_error,
173            }
174        },
175    };
176
177    let ident = guard.fn_ident.rocketized();
178    quote!(let #ident: #ty = #expr;)
179}
180
181fn data_guard_decl(guard: &Guard) -> TokenStream {
182    let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
183    define_spanned_export!(ty.span() => _log, __req, __data, FromData, Outcome);
184
185    quote_spanned! { ty.span() =>
186        let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await {
187            #Outcome::Success(__d) => __d,
188            #Outcome::Forward((__d, __e)) => {
189                #_log::warn_!("Data guard `{}` is forwarding.", stringify!(#ty));
190                return #Outcome::Forward((__d, __e));
191            }
192            #Outcome::Error((__c, __e)) => {
193                #_log::warn_!("Data guard `{}` failed: {:?}.", stringify!(#ty), __e);
194                return #Outcome::Error(__c);
195            }
196        };
197    }
198}
199
200fn internal_uri_macro_decl(route: &Route) -> TokenStream {
201    // FIXME: Is this the right order? Does order matter?
202    let uri_args = route.param_guards()
203        .chain(route.query_guards())
204        .map(|guard| (&guard.fn_ident, &guard.ty))
205        .map(|(ident, ty)| quote!(#ident: #ty));
206
207    // Generate a unique macro name based on the route's metadata.
208    let macro_name = route.handler.sig.ident.prepend(crate::URI_MACRO_PREFIX);
209    let inner_macro_name = macro_name.uniqueify_with(|mut hasher| {
210        route.handler.sig.ident.hash(&mut hasher);
211        route.attr.uri.path().hash(&mut hasher);
212        route.attr.uri.query().hash(&mut hasher)
213    });
214
215    let route_uri = route.attr.uri.to_string();
216
217    quote_spanned! { Span::call_site() =>
218        #[doc(hidden)]
219        #[macro_export]
220        /// Rocket generated URI macro.
221        macro_rules! #inner_macro_name {
222            ($($token:tt)*) => {{
223                rocket::rocket_internal_uri!(#route_uri, (#(#uri_args),*), $($token)*)
224            }};
225        }
226
227        #[doc(hidden)]
228        #[allow(unused)]
229        pub use #inner_macro_name as #macro_name;
230    }
231}
232
233fn responder_outcome_expr(route: &Route) -> TokenStream {
234    let ret_span = match route.handler.sig.output {
235        syn::ReturnType::Default => route.handler.sig.ident.span(),
236        syn::ReturnType::Type(_, ref ty) => ty.span()
237    };
238
239    let user_handler_fn_name = &route.handler.sig.ident;
240    let parameter_names = route.arguments.map.values()
241        .map(|(ident, _)| ident.rocketized());
242
243    let _await = route.handler.sig.asyncness
244        .map(|a| quote_spanned!(a.span() => .await));
245
246    define_spanned_export!(ret_span => __req, _route);
247    quote_spanned! { mixed(ret_span) =>
248        let ___responder = #user_handler_fn_name(#(#parameter_names),*) #_await;
249        #_route::Outcome::from(#__req, ___responder)
250    }
251}
252
253fn sentinels_expr(route: &Route) -> TokenStream {
254    let ret_ty = match route.handler.sig.output {
255        syn::ReturnType::Default => None,
256        syn::ReturnType::Type(_, ref ty) => Some(ty.with_stripped_lifetimes())
257    };
258
259    let generic_idents: Vec<_> = route.handler.sig.generics
260        .type_params()
261        .map(|p| &p.ident)
262        .collect();
263
264    // Note: for a given route, we need to emit a valid graph of eligible
265    // sentinels. This means that we don't have broken links, where a child
266    // points to a parent that doesn't exist. The concern is that the
267    // `is_concrete()` filter will cause a break in the graph.
268    //
269    // Here's a proof by cases for why this can't happen:
270    //    1. if `is_concrete()` returns `false` for a (valid) type, it returns
271    //       false for all of its parents. we consider this an axiom; this is
272    //       the point of `is_concrete()`. the type is filtered out, so the
273    //       theorem vacuously holds
274    //    2. if `is_concrete()` returns `true`, for a type `T`, it either:
275    //      * returns `false` for the parent. by 1) it will return false for
276    //        _all_ parents of the type, so no node in the graph can consider,
277    //        directly or indirectly, `T` to be a child, and thus there are no
278    //        broken links; the theorem holds
279    //      * returns `true` for the parent, and so the type has a parent, and
280    //      the theorem holds.
281    //    3. these are all the cases. QED.
282
283    const TY_MACS: &[&str] = &["ReaderStream", "TextStream", "ByteStream", "EventStream"];
284
285    fn ty_mac_mapper(tokens: &TokenStream) -> Option<syn::Type> {
286        use crate::bang::typed_stream::Input;
287
288        match syn::parse2(tokens.clone()).ok()? {
289            Input::Type(ty, ..) => Some(ty),
290            Input::Tokens(..) => None
291        }
292    }
293
294    let eligible_types = route.guards()
295        .map(|guard| &guard.ty)
296        .chain(ret_ty.as_ref().into_iter())
297        .flat_map(|ty| ty.unfold_with_ty_macros(TY_MACS, ty_mac_mapper))
298        .filter(|ty| ty.is_concrete(&generic_idents))
299        .map(|child| (child.parent, child.ty));
300
301    let sentinel = eligible_types.map(|(parent, ty)| {
302        define_spanned_export!(ty.span() => _sentinel);
303
304        match parent {
305            Some(p) if p.is_concrete(&generic_idents) => {
306                quote_spanned!(ty.span() => #_sentinel::resolve!(#ty, #p))
307            }
308            Some(_) | None => quote_spanned!(ty.span() => #_sentinel::resolve!(#ty)),
309        }
310    });
311
312    quote!(::std::vec![#(#sentinel),*])
313}
314
315fn codegen_route(route: Route) -> Result<TokenStream> {
316    use crate::exports::*;
317
318    // Generate the declarations for all of the guards.
319    let request_guards = route.request_guards.iter().map(request_guard_decl);
320    let param_guards = route.param_guards().map(param_guard_decl);
321    let query_guards = query_decls(&route);
322    let data_guard = route.data_guard.as_ref().map(data_guard_decl);
323
324    // Extract the sentinels from the route.
325    let sentinels = sentinels_expr(&route);
326
327    // Gather info about the function.
328    let (vis, handler_fn) = (&route.handler.vis, &route.handler);
329    let deprecated = handler_fn.attrs.iter().find(|a| a.path().is_ident("deprecated"));
330    let handler_fn_name = &handler_fn.sig.ident;
331    let internal_uri_macro = internal_uri_macro_decl(&route);
332    let responder_outcome = responder_outcome_expr(&route);
333
334    let method = route.attr.method;
335    let uri = route.attr.uri.to_string();
336    let rank = Optional(route.attr.rank);
337    let format = Optional(route.attr.format.as_ref());
338
339    Ok(quote! {
340        #handler_fn
341
342        #[doc(hidden)]
343        #[allow(nonstandard_style)]
344        /// Rocket code generated proxy structure.
345        #deprecated #vis struct #handler_fn_name {  }
346
347        /// Rocket code generated proxy static conversion implementations.
348        #[allow(nonstandard_style, deprecated, clippy::style)]
349        impl #handler_fn_name {
350            fn into_info(self) -> #_route::StaticInfo {
351                fn monomorphized_function<'__r>(
352                    #__req: &'__r #Request<'_>,
353                    #__data: #Data<'__r>
354                ) -> #_route::BoxFuture<'__r> {
355                    #_Box::pin(async move {
356                        #(#request_guards)*
357                        #(#param_guards)*
358                        #query_guards
359                        #data_guard
360
361                        #responder_outcome
362                    })
363                }
364
365                #_route::StaticInfo {
366                    name: stringify!(#handler_fn_name),
367                    method: #method,
368                    uri: #uri,
369                    handler: monomorphized_function,
370                    format: #format,
371                    rank: #rank,
372                    sentinels: #sentinels,
373                }
374            }
375
376            #[doc(hidden)]
377            pub fn into_route(self) -> #Route {
378                self.into_info().into()
379            }
380        }
381
382        /// Rocket code generated wrapping URI macro.
383        #internal_uri_macro
384    })
385}
386
387fn complete_route(args: TokenStream, input: TokenStream) -> Result<TokenStream> {
388    let function: syn::ItemFn = syn::parse2(input)
389        .map_err(Diagnostic::from)
390        .map_err(|diag| diag.help("`#[route]` can only be used on functions"))?;
391
392    let attr_tokens = quote!(route(#args));
393    let attribute = Attribute::from_meta(&syn::parse2(attr_tokens)?)?;
394    codegen_route(Route::from(attribute, function)?)
395}
396
397fn incomplete_route(
398    method: crate::http::Method,
399    args: TokenStream,
400    input: TokenStream
401) -> Result<TokenStream> {
402    let method_str = method.to_string().to_lowercase();
403    // FIXME(proc_macro): there should be a way to get this `Span`.
404    let method_span = StringLit::new(format!("#[{}]", method), Span::call_site())
405        .subspan(2..2 + method_str.len());
406
407    let method_ident = syn::Ident::new(&method_str, method_span);
408
409    let function: syn::ItemFn = syn::parse2(input)
410        .map_err(Diagnostic::from)
411        .map_err(|d| d.help(format!("#[{}] can only be used on functions", method_str)))?;
412
413    let full_attr = quote!(#method_ident(#args));
414    let method_attribute = MethodAttribute::from_meta(&syn::parse2(full_attr)?)?;
415
416    let attribute = Attribute {
417        method: SpanWrapped {
418            full_span: method_span, key_span: None, span: method_span, value: Method(method)
419        },
420        uri: method_attribute.uri,
421        data: method_attribute.data,
422        format: method_attribute.format,
423        rank: method_attribute.rank,
424    };
425
426    codegen_route(Route::from(attribute, function)?)
427}
428
429pub fn route_attribute<M: Into<Option<crate::http::Method>>>(
430    method: M,
431    args: proc_macro::TokenStream,
432    input: proc_macro::TokenStream
433) -> TokenStream {
434    let result = match method.into() {
435        Some(method) => incomplete_route(method, args.into(), input.into()),
436        None => complete_route(args.into(), input.into())
437    };
438
439    result.unwrap_or_else(|diag| diag.emit_as_item_tokens())
440}