1use syn::spanned::Spanned;
2use syn::{punctuated::Punctuated, Token};
3use syn::parse::{Parse as SynParse, ParseStream as SynParseStream};
4use proc_macro2::{Span, Delimiter};
5use proc_macro2_diagnostics::{Diagnostic, SpanDiagnosticExt};
6
7pub type PResult<T> = Result<T, Diagnostic>;
8
9pub trait Parse: Sized {
10 fn parse(input: syn::parse::ParseStream) -> PResult<Self>;
11
12 fn syn_parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
13 Self::parse(input).map_err(|e| e.into())
14 }
15}
16
17trait ParseStreamExt {
18 fn parse_group<F, G>(self, delimiter: Delimiter, parser: F) -> syn::parse::Result<G>
19 where F: FnOnce(SynParseStream) -> syn::parse::Result<G>;
20
21 fn try_parse<F, G>(self, parser: F) -> syn::parse::Result<G>
22 where F: Fn(SynParseStream) -> syn::parse::Result<G>;
23}
24
25impl<'a> ParseStreamExt for SynParseStream<'a> {
26 fn parse_group<F, G>(self, delimiter: Delimiter, parser: F) -> syn::parse::Result<G>
27 where F: FnOnce(SynParseStream) -> syn::parse::Result<G>
28 {
29 let content;
30 match delimiter {
31 Delimiter::Brace => { syn::braced!(content in self); },
32 Delimiter::Bracket => { syn::bracketed!(content in self); },
33 Delimiter::Parenthesis => { syn::parenthesized!(content in self); },
34 Delimiter::None => return parser(self),
35 }
36
37 parser(&content)
38 }
39
40 fn try_parse<F, G>(self, parser: F) -> syn::parse::Result<G>
41 where F: Fn(SynParseStream) -> syn::parse::Result<G>
42 {
43 let input = self.fork();
44 parser(&input)?;
45 parser(self)
46 }
47}
48
49#[derive(Debug)]
50pub struct CallPattern {
51 pub name: Option<syn::Ident>,
52 pub at: Option<Token![@]>,
53 pub expr: syn::ExprCall,
54}
55
56impl syn::parse::Parse for CallPattern {
57 fn parse(input: syn::parse::ParseStream) -> syn::parse::Result<Self> {
58 Self::syn_parse(input)
59 }
60}
61
62impl quote::ToTokens for CallPattern {
63 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
64 let (expr, at) = (&self.expr, &self.at);
65 match self.name {
66 Some(ref name) => quote!(#name #at #expr).to_tokens(tokens),
67 None => expr.to_tokens(tokens)
68 }
69 }
70}
71
72impl quote::ToTokens for Guard {
73 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
74 self.expr.to_tokens(tokens)
75 }
76}
77
78type CallPatterns = Punctuated<CallPattern, Token![|]>;
79
80#[derive(Debug)]
81pub enum Pattern {
82 Wild(Token![_]),
83 Calls(CallPatterns),
84}
85
86#[derive(Debug)]
87pub struct Guard {
88 pub _if: Token![if],
89 pub expr: syn::Expr,
90}
91
92#[derive(Debug)]
93pub struct Case {
94 pub pattern: Pattern,
95 pub expr: syn::Expr,
96 pub guard: Option<Guard>,
97 pub span: Span,
98}
99
100#[derive(Debug)]
101pub struct Switch {
102 pub context: Context,
103 pub cases: Punctuated<Case, Token![,]>
104}
105
106fn parse_expr_call(input: SynParseStream) -> syn::parse::Result<syn::ExprCall> {
108 let path: syn::ExprPath = input.parse()?;
109 let paren_span = input.cursor().span();
110 let args = input.parse_group(Delimiter::Parenthesis, |i| {
111 i.parse_terminated(syn::Expr::parse, Token![,])
112 })?;
113
114 Ok(syn::ExprCall {
115 attrs: vec![],
116 func: Box::new(syn::Expr::Path(path)),
117 paren_token: syn::token::Paren(paren_span),
118 args
119 })
120}
121
122impl Parse for CallPattern {
123 fn parse(input: SynParseStream) -> PResult<Self> {
124 let name_at = input.try_parse(|input| {
125 let ident: syn::Ident = input.parse()?;
126 let at = input.parse::<Token![@]>()?;
127 Ok((ident, at))
128 }).ok();
129
130 let (name, at) = match name_at {
131 Some((name, at)) => (Some(name), Some(at)),
132 None => (None, None)
133 };
134
135 Ok(CallPattern { name, at, expr: parse_expr_call(input)? })
136 }
137}
138
139impl Parse for Guard {
140 fn parse(input: SynParseStream) -> PResult<Self> {
141 Ok(Guard {
142 _if: input.parse()?,
143 expr: input.parse()?,
144 })
145 }
146}
147
148impl Parse for Pattern {
149 fn parse(input: SynParseStream) -> PResult<Self> {
150 type CallPatterns = Punctuated<CallPattern, Token![|]>;
151
152 let pattern = match input.parse::<Token![_]>() {
154 Ok(wild) => Pattern::Wild(wild),
155 Err(_) => Pattern::Calls(input.call(CallPatterns::parse_separated_nonempty)?)
156 };
157
158 if let Pattern::Calls(ref calls) = pattern {
160 let first_name = calls.first().and_then(|call| call.name.clone());
161 for call in calls.iter() {
162 if first_name != call.name {
163 let mut err = if let Some(ref ident) = call.name {
164 ident.span()
165 .error("captured name differs from declaration")
166 } else {
167 call.expr.span()
168 .error("expected capture name due to previous declaration")
169 };
170
171 err = match first_name {
172 Some(p) => err.span_note(p.span(), "declared here"),
173 None => err
174 };
175
176 return Err(err);
177 }
178 }
179 }
180
181 Ok(pattern)
182 }
183}
184
185impl Parse for Case {
186 fn parse(input: SynParseStream) -> PResult<Self> {
187 let case_span_start = input.cursor().span();
188 let pattern = Pattern::parse(input)?;
189 let guard = match input.peek(Token![if]) {
190 true => Some(Guard::parse(input)?),
191 false => None,
192 };
193
194 input.parse::<Token![=>]>()?;
195 let expr: syn::Expr = input.parse()?;
196 let span = case_span_start
197 .join(input.cursor().span())
198 .unwrap_or(case_span_start);
199
200 Ok(Case { pattern, expr, guard, span, })
201 }
202}
203
204#[derive(Debug)]
205pub struct Context {
206 pub info: syn::Ident,
207 pub input: syn::Expr,
208 pub marker: syn::Expr,
209 pub output: syn::Type,
210}
211
212impl Parse for Context {
213 fn parse(stream: SynParseStream) -> PResult<Context> {
214 let (info, input, marker, output) = stream.parse_group(Delimiter::Bracket, |inner| {
215 let info: syn::Ident = inner.parse()?;
216 inner.parse::<Token![;]>()?;
217 let input: syn::Expr = inner.parse()?;
218 inner.parse::<Token![;]>()?;
219 let marker: syn::Expr = inner.parse()?;
220 inner.parse::<Token![;]>()?;
221 let output: syn::Type = inner.parse()?;
222 Ok((info, input, marker, output))
223 })?;
224
225 Ok(Context { info, input, marker, output })
226 }
227}
228
229impl Parse for Switch {
230 fn parse(stream: SynParseStream) -> PResult<Switch> {
231 let context = stream.try_parse(Context::syn_parse)?;
232 let cases = stream.parse_terminated(Case::syn_parse, Token![,])?;
233 if !stream.is_empty() {
234 Err(stream.error("trailing characters; expected eof"))?;
235 }
236
237 if cases.is_empty() {
238 Err(stream.error("switch cannot be empty"))?;
239 }
240
241 for case in cases.iter().take(cases.len() - 1) {
242 if let Pattern::Wild(..) = case.pattern {
243 if case.guard.is_none() {
244 Err(case.span.error("unguarded `_` can only appear as the last case"))?;
245 }
246 }
247 }
248
249 Ok(Switch { context, cases })
250 }
251}
252
253#[derive(Debug, Clone)]
254pub struct AttrArgs {
255 pub raw: Option<Span>,
256 pub rewind: Option<Span>,
257 pub peek: Option<Span>,
258}
259
260impl Parse for AttrArgs {
261 fn parse(input: SynParseStream) -> PResult<Self> {
262 let args = input.call(<Punctuated<syn::Ident, Token![,]>>::parse_terminated)?;
263 let (mut raw, mut rewind, mut peek) = Default::default();
264 for case in args.iter() {
265 if case == "raw" {
266 raw = Some(case.span());
267 } else if case == "rewind" {
268 rewind = Some(case.span());
269 } else if case == "peek" {
270 peek = Some(case.span());
271 } else {
272 return Err(case.span()
273 .error(format!("unknown attribute argument `{}`", case))
274 .help("supported arguments are: `rewind`, `peek`"));
275 }
276 }
277
278 Ok(AttrArgs { raw, rewind, peek })
279 }
280}