1use std::ops::Deref;
2
3use proc_macro2::TokenStream;
4use syn::{self, Token, punctuated::Punctuated, spanned::Spanned, parse::Parser};
5use proc_macro2_diagnostics::{SpanDiagnosticExt, Diagnostic};
6use quote::ToTokens;
7
8use crate::ext::{GenericExt, GenericParamExt, GenericsExt};
9use crate::support::Support;
10use crate::derived::{ItemInput, Input};
11use crate::mapper::Mapper;
12use crate::validator::Validator;
13
14pub type Result<T> = std::result::Result<T, Diagnostic>;
15
16pub struct TraitItem {
17 item: syn::ItemImpl,
18 pub path: syn::Path,
19 pub name: syn::Ident,
20}
21
22impl TraitItem {
23 fn parse<T: ToTokens>(raw: T) -> Self {
24 let item: syn::ItemImpl = syn::parse2(quote!(#raw for Foo {}))
25 .expect("invalid impl token stream");
26
27 let path = item.trait_.clone()
28 .expect("impl does not have trait")
29 .1;
30
31 let name = path.segments.last()
32 .map(|s| s.ident.clone())
33 .expect("trait to impl for is empty");
34
35 Self { item, path, name }
36 }
37}
38
39impl Deref for TraitItem {
40 type Target = syn::ItemImpl;
41
42 fn deref(&self) -> &Self::Target {
43 &self.item
44 }
45}
46
47pub struct DeriveGenerator {
48 pub input: ItemInput,
49 pub item: TraitItem,
50 pub support: Support,
51 pub validator: Option<Box<dyn Validator>>,
52 pub inner_mappers: Vec<Box<dyn Mapper>>,
53 pub outer_mappers: Vec<Box<dyn Mapper>>,
54 pub type_bound_mapper: Option<Box<dyn Mapper>>,
55 generic_replacements: Vec<(usize, usize)>,
56}
57
58impl DeriveGenerator {
59 pub fn build_for<I, T>(input: I, trait_impl: T) -> DeriveGenerator
60 where I: Into<TokenStream>, T: ToTokens
61 {
62 let item = TraitItem::parse(trait_impl);
63 let input: syn::DeriveInput = syn::parse2(input.into())
64 .expect("invalid derive input");
65
66 DeriveGenerator {
67 item,
68 input: input.into(),
69 support: Support::default(),
70 generic_replacements: vec![],
71 validator: None,
72 type_bound_mapper: None,
73 inner_mappers: vec![],
74 outer_mappers: vec![],
75 }
76 }
77
78 pub fn support(&mut self, support: Support) -> &mut Self {
79 self.support = support;
80 self
81 }
82
83 pub fn type_bound<B: ToTokens>(&mut self, bound: B) -> &mut Self {
84 let tokens = bound.to_token_stream();
85 self.type_bound_mapper(crate::MapperBuild::new()
86 .try_input_map(move |_, input| {
87 let tokens = tokens.clone();
88 let bounds = input.generics().parsed_bounded_types(tokens)?;
89 Ok(bounds.into_token_stream())
90 }))
91 }
92
93 pub fn replace_generic(&mut self, trait_gen: usize, impl_gen: usize) -> &mut Self {
97 self.generic_replacements.push((trait_gen, impl_gen));
98 self
99 }
100
101 pub fn validator<V: Validator + 'static>(&mut self, validator: V) -> &mut Self {
102 self.validator = Some(Box::new(validator));
103 self
104 }
105
106 pub fn type_bound_mapper<V: Mapper + 'static>(&mut self, mapper: V) -> &mut Self {
107 self.type_bound_mapper = Some(Box::new(mapper));
108 self
109 }
110
111 pub fn inner_mapper<V: Mapper + 'static>(&mut self, mapper: V) -> &mut Self {
112 self.inner_mappers.push(Box::new(mapper));
113 self
114 }
115
116 pub fn outer_mapper<V: Mapper + 'static>(&mut self, mapper: V) -> &mut Self {
117 self.outer_mappers.push(Box::new(mapper));
118 self
119 }
120
121 fn _to_tokens(&mut self) -> Result<TokenStream> {
122 let input = Input::from(&self.input);
125 let (span, support) = (input.span(), self.support);
126 match input {
127 Input::Struct(v) => {
128 if v.fields().are_named() && !support.contains(Support::NamedStruct) {
129 return Err(span.error("named structs are not supported"));
130 }
131
132 if !v.fields().are_named() && !support.contains(Support::TupleStruct) {
133 return Err(span.error("tuple structs are not supported"));
134 }
135 }
136 Input::Enum(..) if !support.contains(Support::Enum) => {
137 return Err(span.error("enums are not supported"));
138 }
139 Input::Union(..) if !support.contains(Support::Union) => {
140 return Err(span.error("unions are not supported"));
141 }
142 _ => { }
143 }
144
145 for generic in &input.generics().params {
147 use syn::GenericParam::*;
148
149 let span = generic.span();
150 match generic {
151 Type(..) if !support.contains(Support::Type) => {
152 return Err(span.error("type generics are not supported"));
153 }
154 Lifetime(..) if !support.contains(Support::Lifetime) => {
155 return Err(span.error("lifetime generics are not supported"));
156 }
157 Const(..) if !support.contains(Support::Const) => {
158 return Err(span.error("const generics are not supported"));
159 }
160 _ => { }
161 }
162 }
163
164 if let Some(validator) = &mut self.validator {
166 validator.validate_input((&self.input).into())?;
167 }
168
169 let mut type_generics = self.input.generics().clone();
173
174 for (trait_i, type_i) in &self.generic_replacements {
184 let idents = self.item.generics.params.iter()
185 .nth(*trait_i)
186 .and_then(|trait_gen| type_generics.params.iter()
187 .filter(|gen| gen.kind() == trait_gen.kind())
188 .nth(*type_i)
189 .map(|type_gen| (trait_gen.ident(), type_gen.ident().clone())));
190
191 if let Some((with, ref to_replace)) = idents {
192 type_generics.replace(to_replace, with);
193 }
194 }
195
196 let mut function_code = vec![];
198 for mapper in &mut self.inner_mappers {
199 let tokens = mapper.map_input((&self.input).into())?;
200 function_code.push(tokens);
201 }
202
203 let mut item_code = vec![];
205 for mapper in &mut self.outer_mappers {
206 let tokens = mapper.map_input((&self.input).into())?;
207 item_code.push(tokens);
208 }
209
210 if let Some(ref mut mapper) = self.type_bound_mapper {
212 let tokens = mapper.map_input((&self.input).into())?;
213 let bounds = Punctuated::<syn::WherePredicate, Token![,]>::parse_terminated
214 .parse2(tokens)
215 .map_err(|e| e.span().error(format!("invalid type bounds: {}", e)))?;
216
217 type_generics.add_where_predicates(bounds);
218 }
219
220 let mut type_generics_for_impl = self.item.generics.clone();
224 for type_gen in &type_generics.params {
225 let type_gen_in_trait_gens = type_generics_for_impl.params.iter()
226 .map(|gen| gen.ident())
227 .find(|g| g == &type_gen.ident())
228 .is_some();
229
230 if !type_gen_in_trait_gens {
231 type_generics_for_impl.params.push(type_gen.clone())
232 }
233 }
234
235 let (impl_gen, _, _) = type_generics_for_impl.split_for_impl();
237 let (_, ty_gen, where_gen) = type_generics.split_for_impl();
238
239 let (target, trait_path) = (&self.input.ident(), &self.item.path);
241 Ok(quote! {
242 #[allow(non_snake_case)]
243 const _: () = {
244 #(#item_code)*
245
246 impl #impl_gen #trait_path for #target #ty_gen #where_gen {
247 #(#function_code)*
248 }
249 };
250 })
251 }
252
253 pub fn debug(&mut self) -> &mut Self {
254 match self._to_tokens() {
255 Ok(tokens) => println!("Tokens produced: {}", tokens.to_string()),
256 Err(e) => println!("Error produced: {:?}", e)
257 }
258
259 self
260 }
261
262 pub fn to_tokens<T: From<TokenStream>>(&mut self) -> T {
263 self.try_to_tokens()
264 .unwrap_or_else(|diag| diag.emit_as_item_tokens())
265 .into()
266 }
267
268 pub fn try_to_tokens<T: From<TokenStream>>(&mut self) -> Result<T> {
269 self._to_tokens()
271 .map_err(|diag| {
272 if let Some(last) = self.item.path.segments.last() {
273 use proc_macro2::Span;
274 use proc_macro2_diagnostics::Level::*;
275
276 let id = &last.ident;
277 let msg = match diag.level() {
278 Error => format!("error occurred while deriving `{}`", id),
279 Warning => format!("warning issued by `{}` derive", id),
280 Note => format!("note issued by `{}` derive", id),
281 Help => format!("help provided by `{}` derive", id),
282 _ => format!("while deriving `{}`", id)
283 };
284
285 diag.span_note(Span::call_site(), msg)
286 } else {
287 diag
288 }
289 })
290 .map(|t| t.into())
291 }
292}