// rocket_codegen/derive/responder.rs

1use quote::ToTokens;
2use devise::{*, ext::{TypeExt, SpanDiagnosticExt}};
3use proc_macro2::TokenStream;
4
5use crate::exports::*;
6use crate::syn_ext::{TypeExt as _, GenericsExt as _};
7use crate::http_codegen::{ContentType, Status};
8
9#[derive(Debug, Default, FromMeta)]
10struct ItemAttr {
11    content_type: Option<SpanWrapped<ContentType>>,
12    status: Option<SpanWrapped<Status>>,
13}
14
15#[derive(Default, FromMeta)]
16struct FieldAttr {
17    ignore: bool,
18}
19
20pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream {
21    let impl_tokens = quote!(impl<'r, 'o: 'r> #_response::Responder<'r, 'o>);
22    DeriveGenerator::build_for(input, impl_tokens)
23        .support(Support::Struct | Support::Enum | Support::Lifetime | Support::Type)
24        .replace_generic(1, 0)
25        .type_bound_mapper(MapperBuild::new()
26            .try_enum_map(|m, e| mapper::enum_null(m, e))
27            .try_fields_map(|_, fields| {
28                let generic_idents = fields.parent.input().generics().type_idents();
29                let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span());
30                let mut types = fields.iter()
31                    .map(|f| (f, &f.field.inner.ty))
32                    .map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty))));
33
34                let mut bounds = vec![];
35                if let Some((_, ty)) = types.next() {
36                    if !ty.is_concrete(&generic_idents) {
37                        let span = ty.span();
38                        bounds.push(quote_spanned!(span => #ty: #_response::Responder<'r, 'o>));
39                    }
40                }
41
42                for (f, ty) in types {
43                    let attr = FieldAttr::one_from_attrs("response", &f.attrs)?.unwrap_or_default();
44                    if ty.is_concrete(&generic_idents) || attr.ignore {
45                        continue;
46                    }
47
48                    bounds.push(quote_spanned! { ty.span() =>
49                        #ty: ::std::convert::Into<#_http::Header<'o>>
50                    });
51                }
52
53                Ok(quote!(#(#bounds,)*))
54            })
55        )
56        .validator(ValidatorBuild::new()
57            .input_validate(|_, i| match i.generics().lifetimes().count() > 1 {
58                true => Err(i.generics().span().error("only one lifetime is supported")),
59                false => Ok(())
60            })
61            .fields_validate(|_, fields| match fields.is_empty() {
62                true => Err(fields.span().error("need at least one field")),
63                false => Ok(())
64            })
65        )
66        .inner_mapper(MapperBuild::new()
67            .with_output(|_, output| quote! {
68                fn respond_to(self, __req: &'r #Request<'_>) -> #_response::Result<'o> {
69                    #output
70                }
71            })
72            .try_fields_map(|_, fields| {
73                fn set_header_tokens<T: ToTokens + Spanned>(item: T) -> TokenStream {
74                    quote_spanned!(item.span() => __res.set_header(#item);)
75                }
76
77                let attr = ItemAttr::one_from_attrs("response", fields.parent.attrs())?
78                    .unwrap_or_default();
79
80                let responder = fields.iter().next().map(|f| {
81                    let (accessor, ty) = (f.accessor(), f.ty.with_stripped_lifetimes());
82                    quote_spanned! { f.span().into() =>
83                        let mut __res = <#ty as #_response::Responder>::respond_to(
84                            #accessor, __req
85                        )?;
86                    }
87                }).expect("have at least one field");
88
89                let mut headers = vec![];
90                for field in fields.iter().skip(1) {
91                    let attr = FieldAttr::one_from_attrs("response", &field.attrs)?
92                        .unwrap_or_default();
93
94                    if !attr.ignore {
95                        headers.push(set_header_tokens(field.accessor()));
96                    }
97                }
98
99                let content_type = attr.content_type.map(set_header_tokens);
100                let status = attr.status.map(|status| {
101                    quote_spanned!(status.span() => __res.set_status(#status);)
102                });
103
104                Ok(quote! {
105                    #responder
106                    #(#headers)*
107                    #content_type
108                    #status
109                    #_Ok(__res)
110                })
111            })
112        )
113        .to_tokens()
114}