// devise_core/generator.rs

1use std::ops::Deref;
2
3use proc_macro2::TokenStream;
4use syn::{self, Token, punctuated::Punctuated, spanned::Spanned, parse::Parser};
5use proc_macro2_diagnostics::{SpanDiagnosticExt, Diagnostic};
6use quote::ToTokens;
7
8use crate::ext::{GenericExt, GenericParamExt, GenericsExt};
9use crate::support::Support;
10use crate::derived::{ItemInput, Input};
11use crate::mapper::Mapper;
12use crate::validator::Validator;
13
14pub type Result<T> = std::result::Result<T, Diagnostic>;
15
16pub struct TraitItem {
17    item: syn::ItemImpl,
18    pub path: syn::Path,
19    pub name: syn::Ident,
20}
21
22impl TraitItem {
23    fn parse<T: ToTokens>(raw: T) -> Self {
24        let item: syn::ItemImpl = syn::parse2(quote!(#raw for Foo {}))
25            .expect("invalid impl token stream");
26
27        let path = item.trait_.clone()
28            .expect("impl does not have trait")
29            .1;
30
31        let name = path.segments.last()
32            .map(|s| s.ident.clone())
33            .expect("trait to impl for is empty");
34
35        Self { item, path, name }
36    }
37}
38
39impl Deref for TraitItem {
40    type Target = syn::ItemImpl;
41
42    fn deref(&self) -> &Self::Target {
43        &self.item
44    }
45}
46
47pub struct DeriveGenerator {
48    pub input: ItemInput,
49    pub item: TraitItem,
50    pub support: Support,
51    pub validator: Option<Box<dyn Validator>>,
52    pub inner_mappers: Vec<Box<dyn Mapper>>,
53    pub outer_mappers: Vec<Box<dyn Mapper>>,
54    pub type_bound_mapper: Option<Box<dyn Mapper>>,
55    generic_replacements: Vec<(usize, usize)>,
56}
57
58impl DeriveGenerator {
59    pub fn build_for<I, T>(input: I, trait_impl: T) -> DeriveGenerator
60        where I: Into<TokenStream>, T: ToTokens
61    {
62        let item = TraitItem::parse(trait_impl);
63        let input: syn::DeriveInput = syn::parse2(input.into())
64            .expect("invalid derive input");
65
66        DeriveGenerator {
67            item,
68            input: input.into(),
69            support: Support::default(),
70            generic_replacements: vec![],
71            validator: None,
72            type_bound_mapper: None,
73            inner_mappers: vec![],
74            outer_mappers: vec![],
75        }
76    }
77
78    pub fn support(&mut self, support: Support) -> &mut Self {
79        self.support = support;
80        self
81    }
82
83    pub fn type_bound<B: ToTokens>(&mut self, bound: B) -> &mut Self {
84        let tokens = bound.to_token_stream();
85        self.type_bound_mapper(crate::MapperBuild::new()
86            .try_input_map(move |_, input| {
87                let tokens = tokens.clone();
88                let bounds = input.generics().parsed_bounded_types(tokens)?;
89                Ok(bounds.into_token_stream())
90            }))
91    }
92
93    /// Take the 0-indexed `trait_gen`th generic in the generics in impl<..>
94    /// being built and substitute those tokens in place of the 0-indexed
95    /// `impl_gen`th generic of the same kind in the input type.
96    pub fn replace_generic(&mut self, trait_gen: usize, impl_gen: usize) -> &mut Self {
97        self.generic_replacements.push((trait_gen, impl_gen));
98        self
99    }
100
101    pub fn validator<V: Validator + 'static>(&mut self, validator: V) -> &mut Self {
102        self.validator = Some(Box::new(validator));
103        self
104    }
105
106    pub fn type_bound_mapper<V: Mapper + 'static>(&mut self, mapper: V) -> &mut Self {
107        self.type_bound_mapper = Some(Box::new(mapper));
108        self
109    }
110
111    pub fn inner_mapper<V: Mapper + 'static>(&mut self, mapper: V) -> &mut Self {
112        self.inner_mappers.push(Box::new(mapper));
113        self
114    }
115
116    pub fn outer_mapper<V: Mapper + 'static>(&mut self, mapper: V) -> &mut Self {
117        self.outer_mappers.push(Box::new(mapper));
118        self
119    }
120
121    fn _to_tokens(&mut self) -> Result<TokenStream> {
122        // Step 1: Run all validators.
123        // Step 1a: First, check for data support.
124        let input = Input::from(&self.input);
125        let (span, support) = (input.span(), self.support);
126        match input {
127            Input::Struct(v) => {
128                if v.fields().are_named() && !support.contains(Support::NamedStruct) {
129                    return Err(span.error("named structs are not supported"));
130                }
131
132                if !v.fields().are_named() && !support.contains(Support::TupleStruct) {
133                    return Err(span.error("tuple structs are not supported"));
134                }
135            }
136            Input::Enum(..) if !support.contains(Support::Enum) => {
137                return Err(span.error("enums are not supported"));
138            }
139            Input::Union(..) if !support.contains(Support::Union) => {
140                return Err(span.error("unions are not supported"));
141            }
142            _ => { /* we're okay! */ }
143        }
144
145        // Step 1b: Second, check for generics support.
146        for generic in &input.generics().params {
147            use syn::GenericParam::*;
148
149            let span = generic.span();
150            match generic {
151                Type(..) if !support.contains(Support::Type) => {
152                    return Err(span.error("type generics are not supported"));
153                }
154                Lifetime(..) if !support.contains(Support::Lifetime) => {
155                    return Err(span.error("lifetime generics are not supported"));
156                }
157                Const(..) if !support.contains(Support::Const) => {
158                    return Err(span.error("const generics are not supported"));
159                }
160                _ => { /* we're okay! */ }
161            }
162        }
163
164        // Step 1c: Third, run the custom validator, if any.
165        if let Some(validator) = &mut self.validator {
166            validator.validate_input((&self.input).into())?;
167        }
168
169        // Step 2: Generate the code!
170
171        // Step 2a: Copy user's generics to mutate with bounds + replacements.
172        let mut type_generics = self.input.generics().clone();
173
174        // Step 2b: Perform generic replacememnt: replace generics in the input
175        // type with generics from the trait definition: 1) determine the
176        // identifer of the generic to be replaced in the type. 2) replace every
177        // identifer in the type with the same name with the identifer of the
178        // replacement trait generic. For example:
179        //   * replace: trait_i = 1, type_i = 0
180        //   * trait: impl<'_a, '_b: '_a> GenExample<'_a, '_b>
181        //   * type: GenFooAB<'x, 'y: 'x>
182        //   * new type: GenFooAB<'_b, 'y: 'b>
183        for (trait_i, type_i) in &self.generic_replacements {
184            let idents = self.item.generics.params.iter()
185                .nth(*trait_i)
186                .and_then(|trait_gen| type_generics.params.iter()
187                    .filter(|gen| gen.kind() == trait_gen.kind())
188                    .nth(*type_i)
189                    .map(|type_gen| (trait_gen.ident(), type_gen.ident().clone())));
190
191            if let Some((with, ref to_replace)) = idents {
192                type_generics.replace(to_replace, with);
193            }
194        }
195
196        // Step 2c.1: Generate the code for each function.
197        let mut function_code = vec![];
198        for mapper in &mut self.inner_mappers {
199            let tokens = mapper.map_input((&self.input).into())?;
200            function_code.push(tokens);
201        }
202
203        // Step 2c.2: Generate the code for each item.
204        let mut item_code = vec![];
205        for mapper in &mut self.outer_mappers {
206            let tokens = mapper.map_input((&self.input).into())?;
207            item_code.push(tokens);
208        }
209
210        // Step 2d: Add the requested type bounds.
211        if let Some(ref mut mapper) = self.type_bound_mapper {
212            let tokens = mapper.map_input((&self.input).into())?;
213            let bounds = Punctuated::<syn::WherePredicate, Token![,]>::parse_terminated
214                .parse2(tokens)
215                .map_err(|e| e.span().error(format!("invalid type bounds: {}", e)))?;
216
217            type_generics.add_where_predicates(bounds);
218        }
219
220        // Step 2e: Determine which generics from the type need to be added to
221        // the trait's `impl<>` generics. These are all of the generics in the
222        // type that aren't in the trait's `impl<>` already.
223        let mut type_generics_for_impl = self.item.generics.clone();
224        for type_gen in &type_generics.params {
225            let type_gen_in_trait_gens = type_generics_for_impl.params.iter()
226                .map(|gen| gen.ident())
227                .find(|g| g == &type_gen.ident())
228                .is_some();
229
230            if !type_gen_in_trait_gens {
231                type_generics_for_impl.params.push(type_gen.clone())
232            }
233        }
234
235        // Step 2f: Split the generics, but use the `impl_generics` from above.
236        let (impl_gen, _, _) = type_generics_for_impl.split_for_impl();
237        let (_, ty_gen, where_gen) = type_generics.split_for_impl();
238
239        // Step 2g: Generate the complete implementation.
240        let (target, trait_path) = (&self.input.ident(), &self.item.path);
241        Ok(quote! {
242            #[allow(non_snake_case)]
243            const _: () = {
244                #(#item_code)*
245
246                impl #impl_gen #trait_path for #target #ty_gen #where_gen {
247                    #(#function_code)*
248                }
249            };
250        })
251    }
252
253    pub fn debug(&mut self) -> &mut Self {
254        match self._to_tokens() {
255            Ok(tokens) => println!("Tokens produced: {}", tokens.to_string()),
256            Err(e) => println!("Error produced: {:?}", e)
257        }
258
259        self
260    }
261
262    pub fn to_tokens<T: From<TokenStream>>(&mut self) -> T {
263        self.try_to_tokens()
264            .unwrap_or_else(|diag| diag.emit_as_item_tokens())
265            .into()
266    }
267
268    pub fn try_to_tokens<T: From<TokenStream>>(&mut self) -> Result<T> {
269        // FIXME: Emit something like: Trait: msg.
270        self._to_tokens()
271            .map_err(|diag| {
272                if let Some(last) = self.item.path.segments.last() {
273                    use proc_macro2::Span;
274                    use proc_macro2_diagnostics::Level::*;
275
276                    let id = &last.ident;
277                    let msg = match diag.level() {
278                        Error => format!("error occurred while deriving `{}`", id),
279                        Warning => format!("warning issued by `{}` derive", id),
280                        Note => format!("note issued by `{}` derive", id),
281                        Help => format!("help provided by `{}` derive", id),
282                        _ => format!("while deriving `{}`", id)
283                    };
284
285                    diag.span_note(Span::call_site(), msg)
286                } else {
287                    diag
288                }
289            })
290            .map(|t| t.into())
291    }
292}