num_enum_derive/
parsing.rs

1use crate::enum_attributes::ErrorTypeAttribute;
2use crate::utils::die;
3use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes};
4use proc_macro2::Span;
5use quote::{format_ident, ToTokens};
6use std::collections::BTreeSet;
7use syn::{
8    parse::{Parse, ParseStream},
9    parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit,
10    LitInt, Meta, Path, Result, UnOp,
11};
12
13pub(crate) struct EnumInfo {
14    pub(crate) name: Ident,
15    pub(crate) repr: Ident,
16    pub(crate) variants: Vec<VariantInfo>,
17    pub(crate) error_type_info: ErrorType,
18}
19
20impl EnumInfo {
21    /// Returns whether the number of variants (ignoring defaults, catch-alls, etc) is the same as
22    /// the capacity of the repr.
23    pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> {
24        let repr_str = self.repr.to_string();
25        if !repr_str.is_empty() {
26            let suffix = repr_str
27                .strip_prefix('i')
28                .or_else(|| repr_str.strip_prefix('u'));
29            if let Some(suffix) = suffix {
30                if suffix == "size" {
31                    return Ok(false);
32                } else if let Ok(bits) = suffix.parse::<u32>() {
33                    let variants = 1usize.checked_shl(bits);
34                    return Ok(variants.map_or(false, |v| {
35                        v == self
36                            .variants
37                            .iter()
38                            .map(|v| v.alternative_values.len() + 1)
39                            .sum()
40                    }));
41                }
42            }
43        }
44        die!(self.repr.clone() => "Failed to parse repr into bit size");
45    }
46
47    pub(crate) fn default(&self) -> Option<&Ident> {
48        self.variants
49            .iter()
50            .find(|info| info.is_default)
51            .map(|info| &info.ident)
52    }
53
54    pub(crate) fn catch_all(&self) -> Option<&Ident> {
55        self.variants
56            .iter()
57            .find(|info| info.is_catch_all)
58            .map(|info| &info.ident)
59    }
60
61    pub(crate) fn variant_idents(&self) -> Vec<Ident> {
62        self.variants
63            .iter()
64            .filter(|variant| !variant.is_catch_all)
65            .map(|variant| variant.ident.clone())
66            .collect()
67    }
68
69    pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> {
70        self.variants
71            .iter()
72            .filter(|variant| !variant.is_catch_all)
73            .map(|info| {
74                let indices = 0..(info.alternative_values.len() + 1);
75                indices
76                    .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
77                    .collect()
78            })
79            .collect()
80    }
81
82    pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> {
83        self.variants
84            .iter()
85            .filter(|variant| !variant.is_catch_all)
86            .map(|variant| variant.all_values().cloned().collect())
87            .collect()
88    }
89
90    fn parse_attrs<Attrs: Iterator<Item = Attribute>>(
91        attrs: Attrs,
92    ) -> Result<(Ident, Option<ErrorType>)> {
93        let mut maybe_repr = None;
94        let mut maybe_error_type = None;
95        for attr in attrs {
96            if let Meta::List(meta_list) = &attr.meta {
97                if let Some(ident) = meta_list.path.get_ident() {
98                    if ident == "repr" {
99                        let mut nested = meta_list.tokens.clone().into_iter();
100                        let repr_tree = match (nested.next(), nested.next()) {
101                            (Some(repr_tree), None) => repr_tree,
102                            _ => die!(attr =>
103                                "Expected exactly one `repr` argument"
104                            ),
105                        };
106                        let repr_ident: Ident = parse_quote! {
107                            #repr_tree
108                        };
109                        if repr_ident == "C" {
110                            die!(repr_ident =>
111                                "repr(C) doesn't have a well defined size"
112                            );
113                        } else {
114                            maybe_repr = Some(repr_ident);
115                        }
116                    } else if ident == "num_enum" {
117                        let attributes =
118                            attr.parse_args_with(crate::enum_attributes::Attributes::parse)?;
119                        if let Some(error_type) = attributes.error_type {
120                            if maybe_error_type.is_some() {
121                                die!(attr => "At most one num_enum error_type attribute may be specified");
122                            }
123                            maybe_error_type = Some(error_type.into());
124                        }
125                    }
126                }
127            }
128        }
129        if maybe_repr.is_none() {
130            die!("Missing `#[repr({Integer})]` attribute");
131        }
132        Ok((maybe_repr.unwrap(), maybe_error_type))
133    }
134}
135
136impl Parse for EnumInfo {
137    fn parse(input: ParseStream) -> Result<Self> {
138        Ok({
139            let input: DeriveInput = input.parse()?;
140            let name = input.ident;
141            let data = match input.data {
142                Data::Enum(data) => data,
143                Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
144                Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
145            };
146
147            let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?;
148
149            let mut variants: Vec<VariantInfo> = vec![];
150            let mut has_default_variant: bool = false;
151            let mut has_catch_all_variant: bool = false;
152
153            // Vec to keep track of the used discriminants and alt values.
154            let mut discriminant_int_val_set = BTreeSet::new();
155
156            let mut next_discriminant = literal(0);
157            for variant in data.variants.into_iter() {
158                let ident = variant.ident.clone();
159
160                let discriminant = match &variant.discriminant {
161                    Some(d) => d.1.clone(),
162                    None => next_discriminant.clone(),
163                };
164
165                let mut raw_alternative_values: Vec<Expr> = vec![];
166                // Keep the attribute around for better error reporting.
167                let mut alt_attr_ref: Vec<&Attribute> = vec![];
168
169                // `#[num_enum(default)]` is required by `#[derive(FromPrimitive)]`
170                // and forbidden by `#[derive(UnsafeFromPrimitive)]`, so we need to
171                // keep track of whether we encountered such an attribute:
172                let mut is_default: bool = false;
173                let mut is_catch_all: bool = false;
174
175                for attribute in &variant.attrs {
176                    if attribute.path().is_ident("default") {
177                        if has_default_variant {
178                            die!(attribute =>
179                                "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
180                            );
181                        } else if has_catch_all_variant {
182                            die!(attribute =>
183                                "Attribute `default` is mutually exclusive with `catch_all`"
184                            );
185                        }
186                        is_default = true;
187                        has_default_variant = true;
188                    }
189
190                    if attribute.path().is_ident("num_enum") {
191                        match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
192                            Ok(variant_attributes) => {
193                                for variant_attribute in variant_attributes.items {
194                                    match variant_attribute {
195                                        NumEnumVariantAttributeItem::Default(default) => {
196                                            if has_default_variant {
197                                                die!(default.keyword =>
198                                                    "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
199                                                );
200                                            } else if has_catch_all_variant {
201                                                die!(default.keyword =>
202                                                    "Attribute `default` is mutually exclusive with `catch_all`"
203                                                );
204                                            }
205                                            is_default = true;
206                                            has_default_variant = true;
207                                        }
208                                        NumEnumVariantAttributeItem::CatchAll(catch_all) => {
209                                            if has_catch_all_variant {
210                                                die!(catch_all.keyword =>
211                                                    "Multiple variants marked with `#[num_enum(catch_all)]`"
212                                                );
213                                            } else if has_default_variant {
214                                                die!(catch_all.keyword =>
215                                                    "Attribute `catch_all` is mutually exclusive with `default`"
216                                                );
217                                            }
218
219                                            match variant
220                                                .fields
221                                                .iter()
222                                                .collect::<Vec<_>>()
223                                                .as_slice()
224                                            {
225                                                [syn::Field {
226                                                    ty: syn::Type::Path(syn::TypePath { path, .. }),
227                                                    ..
228                                                }] if path.is_ident(&repr) => {
229                                                    is_catch_all = true;
230                                                    has_catch_all_variant = true;
231                                                }
232                                                _ => {
233                                                    die!(catch_all.keyword =>
234                                                        "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type"
235                                                    );
236                                                }
237                                            }
238                                        }
239                                        NumEnumVariantAttributeItem::Alternatives(alternatives) => {
240                                            raw_alternative_values.extend(alternatives.expressions);
241                                            alt_attr_ref.push(attribute);
242                                        }
243                                    }
244                                }
245                            }
246                            Err(err) => {
247                                if cfg!(not(feature = "complex-expressions")) {
248                                    let tokens = attribute.meta.to_token_stream();
249
250                                    let attribute_str = format!("{}", tokens);
251                                    if attribute_str.contains("alternatives")
252                                        && attribute_str.contains("..")
253                                    {
254                                        // Give a nice error message suggesting how to fix the problem.
255                                        die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
256                                    }
257                                }
258                                die!(attribute =>
259                                    format!("Invalid attribute: {}", err)
260                                );
261                            }
262                        }
263                    }
264                }
265
266                if !is_catch_all {
267                    match &variant.fields {
268                        Fields::Named(_) | Fields::Unnamed(_) => {
269                            die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
270                        }
271                        Fields::Unit => {}
272                    }
273                }
274
275                let discriminant_value = parse_discriminant(&discriminant)?;
276
277                // Check for collision.
278                // We can't do const evaluation, or even compare arbitrary Exprs,
279                // so unfortunately we can't check for duplicates.
280                // That's not the end of the world, just we'll end up with compile errors for
281                // matches with duplicate branches in generated code instead of nice friendly error messages.
282                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
283                    if discriminant_int_val_set.contains(&canonical_value_int) {
284                        die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
285                    }
286                }
287
288                // Deal with the alternative values.
289                let mut flattened_alternative_values = Vec::new();
290                let mut flattened_raw_alternative_values = Vec::new();
291                for raw_alternative_value in raw_alternative_values {
292                    let expanded_values = parse_alternative_values(&raw_alternative_value)?;
293                    for expanded_value in expanded_values {
294                        flattened_alternative_values.push(expanded_value);
295                        flattened_raw_alternative_values.push(raw_alternative_value.clone())
296                    }
297                }
298
299                if !flattened_alternative_values.is_empty() {
300                    let alternate_int_values = flattened_alternative_values
301                        .into_iter()
302                        .map(|v| {
303                            match v {
304                                DiscriminantValue::Literal(value) => Ok(value),
305                                DiscriminantValue::Expr(expr) => {
306                                    if let Expr::Range(_) = expr {
307                                        if cfg!(not(feature = "complex-expressions")) {
308                                            // Give a nice error message suggesting how to fix the problem.
309                                            die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
310                                        }
311                                    }
312                                    // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
313                                    // We could probably allow them, but there doesn't seem to be much of a use-case,
314                                    // and it's easier to give good error messages about duplicate values this way,
315                                    // rather than rustc errors on conflicting match branches.
316                                    die!(expr => "Only literals are allowed as num_enum alternate values".to_string())
317                                },
318                            }
319                        })
320                        .collect::<Result<Vec<i128>>>()?;
321                    let mut sorted_alternate_int_values = alternate_int_values.clone();
322                    sorted_alternate_int_values.sort_unstable();
323                    let sorted_alternate_int_values = sorted_alternate_int_values;
324
325                    // Check if the current discriminant is not in the alternative values.
326                    if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
327                        if let Some(index) = alternate_int_values
328                            .iter()
329                            .position(|&x| x == canonical_value_int)
330                        {
331                            die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
332                        }
333                    }
334
335                    // Search for duplicates, the vec is sorted. Warn about them.
336                    if (1..sorted_alternate_int_values.len()).any(|i| {
337                        sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
338                    }) {
339                        let attr = *alt_attr_ref.last().unwrap();
340                        die!(attr => "There is duplication in the alternative values");
341                    }
342                    // Search if those discriminant_int_val_set where already attributed.
343                    // (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
344                    if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
345                        if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
346                            for (index, val) in alternate_int_values.iter().enumerate() {
347                                if discriminant_int_val_set.contains(val) {
348                                    die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
349                                }
350                            }
351                        }
352                    }
353
354                    // Reconstruct the alternative_values vec of Expr but sorted.
355                    flattened_raw_alternative_values = sorted_alternate_int_values
356                        .iter()
357                        .map(|val| literal(val.to_owned()))
358                        .collect();
359
360                    // Add the alternative values to the the set to keep track.
361                    discriminant_int_val_set.extend(sorted_alternate_int_values);
362                }
363
364                // Add the current discriminant to the the set to keep track.
365                if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
366                    discriminant_int_val_set.insert(canonical_value_int);
367                }
368
369                variants.push(VariantInfo {
370                    ident,
371                    is_default,
372                    is_catch_all,
373                    canonical_value: discriminant,
374                    alternative_values: flattened_raw_alternative_values,
375                });
376
377                // Get the next value for the discriminant.
378                next_discriminant = match discriminant_value {
379                    DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
380                    DiscriminantValue::Expr(expr) => {
381                        parse_quote! {
382                            #repr::wrapping_add(#expr, 1)
383                        }
384                    }
385                }
386            }
387
388            let error_type_info = maybe_error_type.unwrap_or_else(|| {
389                let crate_name = Ident::new(&get_crate_name(), Span::call_site());
390                ErrorType {
391                    name: parse_quote! {
392                        ::#crate_name::TryFromPrimitiveError<Self>
393                    },
394                    constructor: parse_quote! {
395                        ::#crate_name::TryFromPrimitiveError::<Self>::new
396                    },
397                }
398            });
399
400            EnumInfo {
401                name,
402                repr,
403                variants,
404                error_type_info,
405            }
406        })
407    }
408}
409
410fn literal(i: i128) -> Expr {
411    Expr::Lit(ExprLit {
412        lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())),
413        attrs: vec![],
414    })
415}
416
417enum DiscriminantValue {
418    Literal(i128),
419    Expr(Expr),
420}
421
422fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
423    let mut sign = 1;
424    let mut unsigned_expr = val_exp;
425    if let Expr::Unary(ExprUnary {
426        op: UnOp::Neg(..),
427        expr,
428        ..
429    }) = val_exp
430    {
431        unsigned_expr = expr;
432        sign = -1;
433    }
434    if let Expr::Lit(ExprLit {
435        lit: Lit::Int(ref lit_int),
436        ..
437    }) = unsigned_expr
438    {
439        Ok(DiscriminantValue::Literal(
440            sign * lit_int.base10_parse::<i128>()?,
441        ))
442    } else {
443        Ok(DiscriminantValue::Expr(val_exp.clone()))
444    }
445}
446
/// Expands one alternative-value expression into the values it stands for.
///
/// A range expression (`a..b` or `a..=b`) is expanded into one literal per contained value;
/// any other expression is classified by `parse_discriminant` and returned as a single
/// element.
///
/// # Errors
///
/// Fails when a range bound is omitted or is not an integer literal, when the lower bound is
/// greater than the upper bound, or when `parse_discriminant` fails.
#[cfg(feature = "complex-expressions")]
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    /// Returns the literal value of one range bound, reporting errors on the whole range.
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: &Option<Box<Expr>>,
    ) -> Result<i128> {
        // Avoid needing to calculate what the lower and upper bound would be - these are type dependent,
        // and also may not be obvious in context (e.g. an omitted bound could reasonably mean "from the last discriminant" or "from the lower bound of the type").
        if let Some(range_bound_value) = range_bound_value {
            let range_bound_value = parse_discriminant(range_bound_value.as_ref())?;
            // If non-literals are used, we can't expand to the mapped values, so can't write a nice match statement or do exhaustiveness checking.
            // Require literals instead.
            if let DiscriminantValue::Literal(value) = range_bound_value {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds most be explicitly specified numeric literals")
    }

    if let Expr::Range(syn::ExprRange {
        start, end, limits, ..
    }) = val_expr
    {
        let lower = range_expr_value_to_number(val_expr, start)?;
        let upper = range_expr_value_to_number(val_expr, end)?;
        // While this is technically allowed in Rust, and results in an empty range, it's almost certainly a mistake in this context.
        if lower > upper {
            die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
        }
        // Let the standard range iterators do the stepping: unlike a manual `next += 1`
        // loop they cannot overflow when an inclusive range ends at `i128::MAX`, and no
        // capacity has to be computed from `upper - lower`.
        let values = match limits {
            syn::RangeLimits::HalfOpen(..) => (lower..upper)
                .map(DiscriminantValue::Literal)
                .collect(),
            syn::RangeLimits::Closed(..) => (lower..=upper)
                .map(DiscriminantValue::Literal)
                .collect(),
        };
        return Ok(values);
    }
    parse_discriminant(val_expr).map(|v| vec![v])
}
498
499#[cfg(not(feature = "complex-expressions"))]
500fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
501    parse_discriminant(val_expr).map(|v| vec![v])
502}
503
504pub(crate) struct VariantInfo {
505    ident: Ident,
506    is_default: bool,
507    is_catch_all: bool,
508    canonical_value: Expr,
509    alternative_values: Vec<Expr>,
510}
511
512impl VariantInfo {
513    fn all_values(&self) -> impl Iterator<Item = &Expr> {
514        ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
515    }
516}
517
518pub(crate) struct ErrorType {
519    pub(crate) name: Path,
520    pub(crate) constructor: Path,
521}
522
523impl From<ErrorTypeAttribute> for ErrorType {
524    fn from(attribute: ErrorTypeAttribute) -> Self {
525        Self {
526            name: attribute.name.path,
527            constructor: attribute.constructor.path,
528        }
529    }
530}
531
/// Returns the name under which the `num_enum` crate is available to the crate being
/// compiled, as reported by `proc_macro_crate`.
///
/// If the lookup fails, a warning is printed to stderr and `num_enum` is used.
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    use proc_macro_crate::FoundCrate;

    match proc_macro_crate::crate_name("num_enum") {
        Ok(FoundCrate::Name(name)) => name,
        Ok(FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n    => defaulting to `num_enum`", err,);
            String::from("num_enum")
        }
    }
}
544
// proc-macro-crate is deliberately optional: in no_std environments it causes an awkward
// dependency on serde with std.
//
// As a consequence, no_std users of num_enum cannot rename the num_enum crate when they
// depend on it. Sorry.
//
// See https://github.com/illicitonion/num_enum/issues/18
/// Returns the fixed crate name `num_enum`.
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}
554}