rocket_codegen/derive/
form_field.rs

1use std::collections::HashSet;
2
3use devise::{*, ext::{TypeExt, SpanDiagnosticExt}};
4
5use syn::{visit_mut::VisitMut, visit::Visit};
6use proc_macro2::{TokenStream, TokenTree, Span};
7use quote::{ToTokens, TokenStreamExt};
8
9use crate::syn_ext::IdentExt;
10use crate::name::Name;
11
// File-local replacement for `quote::quote_spanned!`: the span actually handed
// to `quote` is `Span::call_site().located_at($span)` rather than `$span`.
macro_rules! quote_spanned {
    ($span:expr => $($tokens:tt)*) => {
        quote::quote_spanned!(
            proc_macro2::Span::call_site().located_at($span) => $($tokens)*
        )
    };
}
19
/// A form field name tagged with how it should be compared to other names.
///
/// Both variants wrap a `Name`. The `PartialEq` impl in this file compares two
/// `Cased` names directly and goes through `Name::as_uncased_str()` as soon as
/// an `Uncased` name is involved.
#[derive(Debug)]
pub enum FieldName {
    /// A name that `FromMeta` managed to parse directly as a `Name`, or one
    /// derived from a field's identifier.
    Cased(Name),
    /// A name parsed through the `uncased` fallback in `FromMeta`, or one
    /// produced for an enum variant.
    Uncased(Name),
}
25
/// Parsed contents of one `field` attribute on a struct field (or on the
/// field's parent, which `validators` and `default` also consult).
#[derive(FromMeta)]
pub struct FieldAttr {
    /// Explicit form field name; collected by `FieldExt::field_names`.
    pub name: Option<FieldName>,
    /// Validation expression; rewritten and wrapped by `validators`.
    pub validate: Option<SpanWrapped<syn::Expr>>,
    /// Default value expression; wrapped by `default_expr` (usually in
    /// `Some(..)`).
    pub default: Option<syn::Expr>,
    /// Default value expression that `default` emits without any wrapping.
    pub default_with: Option<syn::Expr>,
}

impl FieldAttr {
    /// Attribute name passed to `FieldAttr::from_attrs`.
    const NAME: &'static str = "field";
}
37
/// Form-related helpers; implemented in this file for `Field<'_>`.
pub(crate) trait FieldExt {
    /// The field's identifier, if it has one.
    fn ident(&self) -> Option<&syn::Ident>;
    /// The field as a `syn::Member`: named when an identifier exists,
    /// otherwise indexed.
    fn member(&self) -> syn::Member;
    /// Identifier used for this field in the generated `FromForm` context.
    /// This is not the field's own identifier.
    fn context_ident(&self) -> syn::Ident;
    /// All names given through `field` attributes; falls back to the field's
    /// identifier when no attribute supplies a name.
    fn field_names(&self) -> Result<Vec<FieldName>>;
    /// The first entry of `field_names()`, if any.
    fn first_field_name(&self) -> Result<Option<FieldName>>;
    /// The field's type with lifetimes stripped.
    fn stripped_ty(&self) -> syn::Type;
    /// Tokens for an `Option` of a `NameBuf`: `Some(..)` built from the first
    /// field name, or a typed `None` when there is no name.
    fn name_buf_opt(&self) -> Result<TokenStream>;
}
47
/// Parsed contents of one `field` attribute on an enum variant.
#[derive(FromMeta)]
pub struct VariantAttr {
    /// The form value associated with the variant.
    pub value: Name,
}

impl VariantAttr {
    /// Attribute name passed to `VariantAttr::from_attrs`.
    const NAME: &'static str = "field";
}
56
/// Form-value helpers; implemented in this file for `Variant<'_>`.
pub(crate) trait VariantExt {
    /// The first value given through a `field` attribute, or the variant's
    /// identifier when there is none. Always an `Uncased` name.
    fn first_form_field_value(&self) -> Result<FieldName>;
    /// Every value given through `field` attributes, or a single entry made
    /// from the variant's identifier when there is none.
    fn form_field_values(&self) -> Result<Vec<FieldName>>;
}
61
62impl VariantExt for Variant<'_> {
63    fn first_form_field_value(&self) -> Result<FieldName> {
64        let value = VariantAttr::from_attrs(VariantAttr::NAME, &self.attrs)?
65            .into_iter()
66            .next()
67            .map(|attr| FieldName::Uncased(attr.value))
68            .unwrap_or_else(|| FieldName::Uncased(Name::from(&self.ident)));
69
70        Ok(value)
71    }
72
73    fn form_field_values(&self) -> Result<Vec<FieldName>> {
74        let attr_values = VariantAttr::from_attrs(VariantAttr::NAME, &self.attrs)?
75            .into_iter()
76            .map(|attr| FieldName::Uncased(attr.value))
77            .collect::<Vec<_>>();
78
79        if attr_values.is_empty() {
80            return Ok(vec![FieldName::Uncased(Name::from(&self.ident))]);
81        }
82
83        Ok(attr_values)
84    }
85}
86
87impl FromMeta for FieldName {
88    fn from_meta(meta: &MetaItem) -> Result<Self> {
89        // These are used during parsing.
90        const CONTROL_CHARS: &[char] = &['&', '=', '?', '.', '[', ']'];
91
92        fn is_valid_field_name(s: &str) -> bool {
93            // The HTML5 spec (4.10.18.1) says 'isindex' is not allowed.
94            if s == "isindex" || s.is_empty() {
95                return false
96            }
97
98            // We allow all visible ASCII characters except `CONTROL_CHARS`.
99            s.chars().all(|c| c.is_ascii_graphic() && !CONTROL_CHARS.contains(&c))
100        }
101
102        let field_name = match Name::from_meta(meta) {
103            Ok(name) => FieldName::Cased(name),
104            Err(_) => {
105                #[derive(FromMeta)]
106                struct Inner {
107                    #[meta(naked)]
108                    uncased: Name
109                }
110
111                let expr = meta.expr()?;
112                let item: MetaItem = syn::parse2(quote!(#expr))?;
113                let inner = Inner::from_meta(&item)?;
114                FieldName::Uncased(inner.uncased)
115            }
116        };
117
118        if !is_valid_field_name(field_name.as_str()) {
119            let chars = CONTROL_CHARS.iter()
120                .map(|c| format!("{:?}", c))
121                .collect::<Vec<_>>()
122                .join(", ");
123
124            return Err(meta.value_span()
125                .error("invalid form field name")
126                .help(format!("field name cannot be `isindex` or contain {}", chars)));
127        }
128
129        Ok(field_name)
130    }
131}
132
133impl std::ops::Deref for FieldName {
134    type Target = Name;
135
136    fn deref(&self) -> &Self::Target {
137        match self {
138            FieldName::Cased(n) | FieldName::Uncased(n) => n,
139        }
140    }
141}
142
143impl ToTokens for FieldName {
144    fn to_tokens(&self, tokens: &mut TokenStream) {
145        (self as &Name).to_tokens(tokens)
146    }
147}
148
149impl PartialEq for FieldName {
150    fn eq(&self, other: &Self) -> bool {
151        use FieldName::*;
152
153        match (self, other) {
154            (Cased(a), Cased(b)) => a == b,
155            (Cased(a), Uncased(u)) | (Uncased(u), Cased(a)) => a == u.as_uncased_str(),
156            (Uncased(u1), Uncased(u2)) => u1.as_uncased_str() == u2.as_uncased_str(),
157        }
158    }
159}
160
161fn member_to_ident(member: syn::Member) -> syn::Ident {
162    match member {
163        syn::Member::Named(ident) => ident,
164        syn::Member::Unnamed(i) => format_ident!("__{}", i.index, span = i.span),
165    }
166}
167
168impl FieldExt for Field<'_> {
169    fn ident(&self) -> Option<&syn::Ident> {
170        self.ident.as_ref()
171    }
172
173    fn member(&self) -> syn::Member {
174        match self.ident().cloned() {
175            Some(ident) => syn::Member::Named(ident),
176            None => syn::Member::Unnamed(syn::Index {
177                index: self.index as u32,
178                span: self.ty.span()
179            })
180        }
181    }
182
183    /// Returns the ident used by the context generated for the `FromForm` impl.
184    /// This is _not_ the field's ident and should not be used as such.
185    fn context_ident(&self) -> syn::Ident {
186        member_to_ident(self.member())
187    }
188
189    // With named existentials, this could return an `impl Iterator`...
190    fn field_names(&self) -> Result<Vec<FieldName>> {
191        let attr_names = FieldAttr::from_attrs(FieldAttr::NAME, &self.attrs)?
192            .into_iter()
193            .filter_map(|attr| attr.name)
194            .collect::<Vec<_>>();
195
196        if attr_names.is_empty() {
197            if let Some(ident) = self.ident() {
198                return Ok(vec![FieldName::Cased(Name::from(ident))]);
199            }
200        }
201
202        Ok(attr_names)
203    }
204
205    fn first_field_name(&self) -> Result<Option<FieldName>> {
206        Ok(self.field_names()?.into_iter().next())
207    }
208
209    fn stripped_ty(&self) -> syn::Type {
210        self.ty.with_stripped_lifetimes()
211    }
212
213    fn name_buf_opt(&self) -> Result<TokenStream> {
214        let (span, field_names) = (self.span(), self.field_names()?);
215        define_spanned_export!(span => _form);
216
217        Ok(field_names.first()
218            .map(|name| quote_spanned!(span => Some(#_form::NameBuf::from((__c.__parent, #name)))))
219            .unwrap_or_else(|| quote_spanned!(span => None::<#_form::NameBuf>)))
220    }
221}
222
/// Visitor state that records every `self.<member>` access in an expression.
#[derive(Default)]
struct RecordMemberAccesses {
    /// `true` while the visitor is inside a reference (`&..`) expression.
    reference: bool,
    /// Accessed members (as context identifiers) paired with the value of
    /// `reference` at the time the access was seen.
    accesses: HashSet<(syn::Ident, bool)>,
}
228
229impl<'a> Visit<'a> for RecordMemberAccesses {
230    fn visit_expr_reference(&mut self, i: &'a syn::ExprReference) {
231        self.reference = true;
232        syn::visit::visit_expr_reference(self, i);
233        self.reference = false;
234    }
235
236    fn visit_expr_field(&mut self, i: &syn::ExprField) {
237        if let syn::Expr::Path(e) = &*i.base {
238            if e.path.is_ident("self") {
239                let ident = member_to_ident(i.member.clone());
240                self.accesses.insert((ident, self.reference));
241            }
242        }
243
244        syn::visit::visit_expr_field(self, i);
245    }
246}
247
/// Rewrites a validation expression so it refers to context identifiers
/// instead of `self` (see the `VisitMut` impl in this file).
struct ValidationMutator<'a> {
    /// The field whose validation expression is being rewritten.
    field: Field<'a>,
    /// Set once the first call expression has had the field's accessor
    /// inserted as its first argument.
    visited: bool,
}
252
253impl ValidationMutator<'_> {
254    fn visit_token_stream(&mut self, tt: TokenStream) -> TokenStream {
255        use TokenTree::*;
256
257        let mut iter = tt.into_iter();
258        let mut stream = TokenStream::new();
259        while let Some(tt) = iter.next() {
260            match tt {
261                Ident(s3lf) if s3lf == "self" => {
262                    match (iter.next(), iter.next()) {
263                        (Some(Punct(p)), Some(Ident(i))) if p.as_char() == '.' => {
264                            let field = syn::parse_quote!(#s3lf #p #i);
265                            let mut expr = syn::Expr::Field(field);
266                            self.visit_expr_mut(&mut expr);
267                            expr.to_tokens(&mut stream);
268                        },
269                        (tt1, tt2) => stream.append_all(&[Some(Ident(s3lf)), tt1, tt2]),
270                    }
271                },
272                TokenTree::Group(group) => {
273                    let tt = self.visit_token_stream(group.stream());
274                    let mut new = proc_macro2::Group::new(group.delimiter(), tt);
275                    new.set_span(group.span());
276                    let group = TokenTree::Group(new);
277                    stream.append(group);
278                }
279                tt => stream.append(tt),
280            }
281        }
282
283        stream
284    }
285}
286
287impl VisitMut for ValidationMutator<'_> {
288    fn visit_expr_call_mut(&mut self, call: &mut syn::ExprCall) {
289        // Only modify the first call we see.
290        if self.visited {
291            return syn::visit_mut::visit_expr_call_mut(self, call);
292        }
293
294        self.visited = true;
295        let accessor = self.field.context_ident().with_span(self.field.ty.span());
296        call.args.insert(0, syn::parse_quote!(#accessor));
297        syn::visit_mut::visit_expr_call_mut(self, call);
298    }
299
300    fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
301        mac.tokens = self.visit_token_stream(mac.tokens.clone());
302        syn::visit_mut::visit_macro_mut(self, mac);
303    }
304
305    fn visit_ident_mut(&mut self, i: &mut syn::Ident) {
306        // replace `self` with the context ident
307        if i == "self" {
308            *i = self.field.context_ident().with_span(self.field.ty.span());
309        }
310    }
311
312    fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
313        fn inner_field(i: &syn::Expr) -> Option<syn::Expr> {
314            if let syn::Expr::Field(e) = i {
315                if let syn::Expr::Path(p) = &*e.base {
316                    if p.path.is_ident("self") {
317                        let member = &e.member;
318                        return Some(syn::parse_quote!(#member));
319                    }
320                }
321            }
322
323            None
324        }
325
326        // replace `self.field` and `&self.field` with `field`
327        if let syn::Expr::Reference(r) = i {
328            if let Some(expr) = inner_field(&r.expr) {
329                if let Some(ref m) = r.mutability {
330                    m.span()
331                        .warning("`mut` has no effect in FromForm` validation")
332                        .note("`mut` is being discarded")
333                        .emit_as_item_tokens();
334                }
335
336                *i = expr;
337            }
338        } else if let Some(expr) = inner_field(&i) {
339            *i = expr;
340        }
341
342        syn::visit_mut::visit_expr_mut(self, i);
343    }
344}
345
346pub fn validators<'v>(field: Field<'v>) -> Result<impl Iterator<Item = syn::Expr> + 'v> {
347    Ok(FieldAttr::from_attrs(FieldAttr::NAME, &field.attrs)?
348        .into_iter()
349        .chain(FieldAttr::from_attrs(FieldAttr::NAME, field.parent.attrs())?)
350        .filter_map(|a| a.validate)
351        .map(move |mut expr| {
352            let mut record = RecordMemberAccesses::default();
353            record.accesses.insert((field.context_ident(), true));
354            record.visit_expr(&expr);
355
356            let mut v = ValidationMutator { field, visited: false };
357            v.visit_expr_mut(&mut expr);
358
359            let span = expr.key_span.unwrap_or(field.ty.span());
360            let matchers = record.accesses.iter().map(|(member, _)| member);
361            let values = record.accesses.iter()
362                .map(|(member, is_ref)| {
363                    if *is_ref { quote_spanned!(span => &#member) }
364                    else { quote_spanned!(span => #member) }
365                });
366
367            let matchers = quote_spanned!(span => (#(Some(#matchers)),*));
368            let values = quote_spanned!(span => (#(#values),*));
369            let name_opt = field.name_buf_opt().unwrap();
370
371            define_spanned_export!(span => _form);
372            let expr: syn::Expr = syn::parse_quote_spanned!(span => {
373                #[allow(unused_parens)]
374                let __result: #_form::Result<'_, ()> = match #values {
375                    #matchers => #expr,
376                    _ => Ok(()),
377                };
378
379                let __e_name = #name_opt;
380                __result.map_err(|__e| match __e_name {
381                    Some(__name) => __e.with_name(__name),
382                    None => __e
383                })
384            });
385
386            expr
387        }))
388}
389
390/// Take an $expr in `default = $expr` and turn it into a `Some($expr.into())`.
391///
392/// As a result of calling `into()`, type inference fails for two common
393/// expressions: integer literals and the bare `None`. As a result, we cheat: if
394/// the expr matches either condition, we pass them through unchanged.
395fn default_expr(expr: &syn::Expr) -> TokenStream {
396    use syn::{Expr, Lit, ExprLit};
397
398    if matches!(expr, Expr::Path(e) if e.path.is_ident("None")) {
399        quote!(#expr)
400    } else if matches!(expr, Expr::Lit(ExprLit { lit: Lit::Int(_), .. })) {
401        quote_spanned!(expr.span() => Some(#expr))
402    } else {
403        quote_spanned!(expr.span() => Some({ #expr }.into()))
404    }
405}
406
407pub fn default<'v>(field: Field<'v>) -> Result<Option<TokenStream>> {
408    let field_attrs = FieldAttr::from_attrs(FieldAttr::NAME, &field.attrs)?;
409    let parent_attrs = FieldAttr::from_attrs(FieldAttr::NAME, field.parent.attrs())?;
410
411    // Expressions in `default = `, except for `None`, are wrapped in `Some()`.
412    let mut expr = field_attrs.iter()
413        .chain(parent_attrs.iter())
414        .filter_map(|a| a.default.as_ref()).map(default_expr);
415
416    // Expressions in `default_with` are passed through directly.
417    let mut expr_with = field_attrs.iter()
418        .chain(parent_attrs.iter())
419        .filter_map(|a| a.default_with.as_ref())
420        .map(|e| e.to_token_stream());
421
422    // Pull the first `default` and `default_with` expressions.
423    let (default, default_with) = (expr.next(), expr_with.next());
424
425    // If there are any more of either, emit an error.
426    if let (Some(e), _) | (_, Some(e)) = (expr.next(), expr_with.next()) {
427        return Err(e.span()
428            .error("duplicate default field expression")
429            .help("at most one `default` or `default_with` is allowed"));
430    }
431
432    // Emit the final expression of type `Option<#ty>` unless both `default` and
433    // `default_with` were provided in which case we error.
434    let ty = field.stripped_ty();
435    match (default, default_with) {
436        (Some(e1), Some(e2)) => {
437            Err(e1.span()
438                .error("duplicate default expressions")
439                .help("only one of `default` or `default_with` must be used")
440                .span_note(e2.span(), "other default expression is here"))
441        },
442        (Some(e), None) | (None, Some(e)) => {
443            Ok(Some(quote_spanned!(e.span() => {
444                let __default: Option<#ty> = if __opts.strict {
445                    None
446                } else {
447                    #e
448                };
449
450                __default
451            })))
452        },
453        (None, None) => Ok(None)
454    }
455}
456
457pub fn first_duplicate<K: Spanned, V: PartialEq + Spanned>(
458    keys: impl Iterator<Item = K> + Clone,
459    values: impl Fn(&K) -> Result<Vec<V>>,
460) -> Result<Option<((usize, Span, Span), (usize, Span, Span))>> {
461    let (mut all_values, mut key_map) = (vec![], vec![]);
462    for key in keys {
463        all_values.append(&mut values(&key)?);
464        key_map.push((all_values.len(), key));
465    }
466
467    // get the key corresponding to all_value index `k`.
468    let key = |k| key_map.iter().find(|(i, _)| k < *i).expect("k < *i");
469
470    for (i, a) in all_values.iter().enumerate() {
471        let mut rest = all_values.iter().enumerate().skip(i + 1);
472        if let Some((j, b)) = rest.find(|(_, b)| *b == a) {
473            let (a_i, key_a) = key(i);
474            let (b_i, key_b) = key(j);
475
476            let a = (*a_i, key_a.span(), a.span());
477            let b = (*b_i, key_b.span(), b.span());
478            return Ok(Some((a, b)));
479        }
480    }
481
482    Ok(None)
483}