1use crate::enum_attributes::ErrorTypeAttribute;
2use crate::utils::die;
3use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes};
4use proc_macro2::Span;
5use quote::{format_ident, ToTokens};
6use std::collections::BTreeSet;
7use syn::{
8 parse::{Parse, ParseStream},
9 parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit,
10 LitInt, Meta, Path, Result, UnOp,
11};
12
/// Description of an enum, built from a derive input by the `Parse` impl in
/// this module.
pub(crate) struct EnumInfo {
    /// The enum's identifier.
    pub(crate) name: Ident,
    /// The single identifier given in the enum's `#[repr(...)]` attribute
    /// (`repr(C)` is rejected during parsing).
    pub(crate) repr: Ident,
    /// One entry per variant, in declaration order.
    pub(crate) variants: Vec<VariantInfo>,
    /// Error type name and constructor path; taken from a `num_enum` `error_type`
    /// attribute, or defaulting to `TryFromPrimitiveError<Self>`.
    pub(crate) error_type_info: ErrorType,
}
19
20impl EnumInfo {
21 pub(crate) fn is_naturally_exhaustive(&self) -> Result<bool> {
24 let repr_str = self.repr.to_string();
25 if !repr_str.is_empty() {
26 let suffix = repr_str
27 .strip_prefix('i')
28 .or_else(|| repr_str.strip_prefix('u'));
29 if let Some(suffix) = suffix {
30 if suffix == "size" {
31 return Ok(false);
32 } else if let Ok(bits) = suffix.parse::<u32>() {
33 let variants = 1usize.checked_shl(bits);
34 return Ok(variants.map_or(false, |v| {
35 v == self
36 .variants
37 .iter()
38 .map(|v| v.alternative_values.len() + 1)
39 .sum()
40 }));
41 }
42 }
43 }
44 die!(self.repr.clone() => "Failed to parse repr into bit size");
45 }
46
47 pub(crate) fn default(&self) -> Option<&Ident> {
48 self.variants
49 .iter()
50 .find(|info| info.is_default)
51 .map(|info| &info.ident)
52 }
53
54 pub(crate) fn catch_all(&self) -> Option<&Ident> {
55 self.variants
56 .iter()
57 .find(|info| info.is_catch_all)
58 .map(|info| &info.ident)
59 }
60
61 pub(crate) fn variant_idents(&self) -> Vec<Ident> {
62 self.variants
63 .iter()
64 .filter(|variant| !variant.is_catch_all)
65 .map(|variant| variant.ident.clone())
66 .collect()
67 }
68
69 pub(crate) fn expression_idents(&self) -> Vec<Vec<Ident>> {
70 self.variants
71 .iter()
72 .filter(|variant| !variant.is_catch_all)
73 .map(|info| {
74 let indices = 0..(info.alternative_values.len() + 1);
75 indices
76 .map(|index| format_ident!("{}__num_enum_{}__", info.ident, index))
77 .collect()
78 })
79 .collect()
80 }
81
82 pub(crate) fn variant_expressions(&self) -> Vec<Vec<Expr>> {
83 self.variants
84 .iter()
85 .filter(|variant| !variant.is_catch_all)
86 .map(|variant| variant.all_values().cloned().collect())
87 .collect()
88 }
89
90 fn parse_attrs<Attrs: Iterator<Item = Attribute>>(
91 attrs: Attrs,
92 ) -> Result<(Ident, Option<ErrorType>)> {
93 let mut maybe_repr = None;
94 let mut maybe_error_type = None;
95 for attr in attrs {
96 if let Meta::List(meta_list) = &attr.meta {
97 if let Some(ident) = meta_list.path.get_ident() {
98 if ident == "repr" {
99 let mut nested = meta_list.tokens.clone().into_iter();
100 let repr_tree = match (nested.next(), nested.next()) {
101 (Some(repr_tree), None) => repr_tree,
102 _ => die!(attr =>
103 "Expected exactly one `repr` argument"
104 ),
105 };
106 let repr_ident: Ident = parse_quote! {
107 #repr_tree
108 };
109 if repr_ident == "C" {
110 die!(repr_ident =>
111 "repr(C) doesn't have a well defined size"
112 );
113 } else {
114 maybe_repr = Some(repr_ident);
115 }
116 } else if ident == "num_enum" {
117 let attributes =
118 attr.parse_args_with(crate::enum_attributes::Attributes::parse)?;
119 if let Some(error_type) = attributes.error_type {
120 if maybe_error_type.is_some() {
121 die!(attr => "At most one num_enum error_type attribute may be specified");
122 }
123 maybe_error_type = Some(error_type.into());
124 }
125 }
126 }
127 }
128 }
129 if maybe_repr.is_none() {
130 die!("Missing `#[repr({Integer})]` attribute");
131 }
132 Ok((maybe_repr.unwrap(), maybe_error_type))
133 }
134}
135
136impl Parse for EnumInfo {
137 fn parse(input: ParseStream) -> Result<Self> {
138 Ok({
139 let input: DeriveInput = input.parse()?;
140 let name = input.ident;
141 let data = match input.data {
142 Data::Enum(data) => data,
143 Data::Union(data) => die!(data.union_token => "Expected enum but found union"),
144 Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"),
145 };
146
147 let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?;
148
149 let mut variants: Vec<VariantInfo> = vec![];
150 let mut has_default_variant: bool = false;
151 let mut has_catch_all_variant: bool = false;
152
153 let mut discriminant_int_val_set = BTreeSet::new();
155
156 let mut next_discriminant = literal(0);
157 for variant in data.variants.into_iter() {
158 let ident = variant.ident.clone();
159
160 let discriminant = match &variant.discriminant {
161 Some(d) => d.1.clone(),
162 None => next_discriminant.clone(),
163 };
164
165 let mut raw_alternative_values: Vec<Expr> = vec![];
166 let mut alt_attr_ref: Vec<&Attribute> = vec![];
168
169 let mut is_default: bool = false;
173 let mut is_catch_all: bool = false;
174
175 for attribute in &variant.attrs {
176 if attribute.path().is_ident("default") {
177 if has_default_variant {
178 die!(attribute =>
179 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
180 );
181 } else if has_catch_all_variant {
182 die!(attribute =>
183 "Attribute `default` is mutually exclusive with `catch_all`"
184 );
185 }
186 is_default = true;
187 has_default_variant = true;
188 }
189
190 if attribute.path().is_ident("num_enum") {
191 match attribute.parse_args_with(NumEnumVariantAttributes::parse) {
192 Ok(variant_attributes) => {
193 for variant_attribute in variant_attributes.items {
194 match variant_attribute {
195 NumEnumVariantAttributeItem::Default(default) => {
196 if has_default_variant {
197 die!(default.keyword =>
198 "Multiple variants marked `#[default]` or `#[num_enum(default)]` found"
199 );
200 } else if has_catch_all_variant {
201 die!(default.keyword =>
202 "Attribute `default` is mutually exclusive with `catch_all`"
203 );
204 }
205 is_default = true;
206 has_default_variant = true;
207 }
208 NumEnumVariantAttributeItem::CatchAll(catch_all) => {
209 if has_catch_all_variant {
210 die!(catch_all.keyword =>
211 "Multiple variants marked with `#[num_enum(catch_all)]`"
212 );
213 } else if has_default_variant {
214 die!(catch_all.keyword =>
215 "Attribute `catch_all` is mutually exclusive with `default`"
216 );
217 }
218
219 match variant
220 .fields
221 .iter()
222 .collect::<Vec<_>>()
223 .as_slice()
224 {
225 [syn::Field {
226 ty: syn::Type::Path(syn::TypePath { path, .. }),
227 ..
228 }] if path.is_ident(&repr) => {
229 is_catch_all = true;
230 has_catch_all_variant = true;
231 }
232 _ => {
233 die!(catch_all.keyword =>
234 "Variant with `catch_all` must be a tuple with exactly 1 field matching the repr type"
235 );
236 }
237 }
238 }
239 NumEnumVariantAttributeItem::Alternatives(alternatives) => {
240 raw_alternative_values.extend(alternatives.expressions);
241 alt_attr_ref.push(attribute);
242 }
243 }
244 }
245 }
246 Err(err) => {
247 if cfg!(not(feature = "complex-expressions")) {
248 let tokens = attribute.meta.to_token_stream();
249
250 let attribute_str = format!("{}", tokens);
251 if attribute_str.contains("alternatives")
252 && attribute_str.contains("..")
253 {
254 die!(attribute => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
256 }
257 }
258 die!(attribute =>
259 format!("Invalid attribute: {}", err)
260 );
261 }
262 }
263 }
264 }
265
266 if !is_catch_all {
267 match &variant.fields {
268 Fields::Named(_) | Fields::Unnamed(_) => {
269 die!(variant => format!("`{}` only supports unit variants (with no associated data), but `{}::{}` was not a unit variant.", get_crate_name(), name, ident));
270 }
271 Fields::Unit => {}
272 }
273 }
274
275 let discriminant_value = parse_discriminant(&discriminant)?;
276
277 if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
283 if discriminant_int_val_set.contains(&canonical_value_int) {
284 die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
285 }
286 }
287
288 let mut flattened_alternative_values = Vec::new();
290 let mut flattened_raw_alternative_values = Vec::new();
291 for raw_alternative_value in raw_alternative_values {
292 let expanded_values = parse_alternative_values(&raw_alternative_value)?;
293 for expanded_value in expanded_values {
294 flattened_alternative_values.push(expanded_value);
295 flattened_raw_alternative_values.push(raw_alternative_value.clone())
296 }
297 }
298
299 if !flattened_alternative_values.is_empty() {
300 let alternate_int_values = flattened_alternative_values
301 .into_iter()
302 .map(|v| {
303 match v {
304 DiscriminantValue::Literal(value) => Ok(value),
305 DiscriminantValue::Expr(expr) => {
306 if let Expr::Range(_) = expr {
307 if cfg!(not(feature = "complex-expressions")) {
308 die!(expr => "Ranges are only supported as num_enum alternate values if the `complex-expressions` feature of the crate `num_enum` is enabled".to_string())
310 }
311 }
312 die!(expr => "Only literals are allowed as num_enum alternate values".to_string())
317 },
318 }
319 })
320 .collect::<Result<Vec<i128>>>()?;
321 let mut sorted_alternate_int_values = alternate_int_values.clone();
322 sorted_alternate_int_values.sort_unstable();
323 let sorted_alternate_int_values = sorted_alternate_int_values;
324
325 if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
327 if let Some(index) = alternate_int_values
328 .iter()
329 .position(|&x| x == canonical_value_int)
330 {
331 die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
332 }
333 }
334
335 if (1..sorted_alternate_int_values.len()).any(|i| {
337 sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
338 }) {
339 let attr = *alt_attr_ref.last().unwrap();
340 die!(attr => "There is duplication in the alternative values");
341 }
342 if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
345 if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
346 for (index, val) in alternate_int_values.iter().enumerate() {
347 if discriminant_int_val_set.contains(val) {
348 die!(&flattened_raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
349 }
350 }
351 }
352 }
353
354 flattened_raw_alternative_values = sorted_alternate_int_values
356 .iter()
357 .map(|val| literal(val.to_owned()))
358 .collect();
359
360 discriminant_int_val_set.extend(sorted_alternate_int_values);
362 }
363
364 if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
366 discriminant_int_val_set.insert(canonical_value_int);
367 }
368
369 variants.push(VariantInfo {
370 ident,
371 is_default,
372 is_catch_all,
373 canonical_value: discriminant,
374 alternative_values: flattened_raw_alternative_values,
375 });
376
377 next_discriminant = match discriminant_value {
379 DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
380 DiscriminantValue::Expr(expr) => {
381 parse_quote! {
382 #repr::wrapping_add(#expr, 1)
383 }
384 }
385 }
386 }
387
388 let error_type_info = maybe_error_type.unwrap_or_else(|| {
389 let crate_name = Ident::new(&get_crate_name(), Span::call_site());
390 ErrorType {
391 name: parse_quote! {
392 ::#crate_name::TryFromPrimitiveError<Self>
393 },
394 constructor: parse_quote! {
395 ::#crate_name::TryFromPrimitiveError::<Self>::new
396 },
397 }
398 });
399
400 EnumInfo {
401 name,
402 repr,
403 variants,
404 error_type_info,
405 }
406 })
407 }
408}
409
410fn literal(i: i128) -> Expr {
411 Expr::Lit(ExprLit {
412 lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())),
413 attrs: vec![],
414 })
415}
416
/// A discriminant or alternative value as classified by `parse_discriminant`.
enum DiscriminantValue {
    /// An integer literal, optionally negated (e.g. `3` or `-1`), evaluated to its value.
    Literal(i128),
    /// Any other expression, kept as written.
    Expr(Expr),
}
421
422fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
423 let mut sign = 1;
424 let mut unsigned_expr = val_exp;
425 if let Expr::Unary(ExprUnary {
426 op: UnOp::Neg(..),
427 expr,
428 ..
429 }) = val_exp
430 {
431 unsigned_expr = expr;
432 sign = -1;
433 }
434 if let Expr::Lit(ExprLit {
435 lit: Lit::Int(ref lit_int),
436 ..
437 }) = unsigned_expr
438 {
439 Ok(DiscriminantValue::Literal(
440 sign * lit_int.base10_parse::<i128>()?,
441 ))
442 } else {
443 Ok(DiscriminantValue::Expr(val_exp.clone()))
444 }
445}
446
/// Expands one `alternatives` entry into the discriminant values it denotes.
///
/// With the `complex-expressions` feature, a range whose bounds are both
/// (possibly negated) integer literals expands to every value it contains:
/// `a..b` excludes `b` and `a..=b` includes it. Any other expression yields
/// exactly one value via `parse_discriminant`.
///
/// # Errors
/// Fails if a range bound is missing or not an integer literal, if the lower
/// bound exceeds the upper bound, or if `parse_discriminant` fails.
#[cfg(feature = "complex-expressions")]
fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
    /// Evaluates one range bound; `parent_range_expr` provides the error span.
    fn range_expr_value_to_number(
        parent_range_expr: &Expr,
        range_bound_value: Option<&Expr>,
    ) -> Result<i128> {
        if let Some(range_bound_value) = range_bound_value {
            if let DiscriminantValue::Literal(value) = parse_discriminant(range_bound_value)? {
                return Ok(value);
            }
        }
        die!(parent_range_expr => "When ranges are used for alternate values, both bounds most be explicitly specified numeric literals")
    }

    let range = match val_expr {
        Expr::Range(range) => range,
        // Not a range: a single alternative value.
        _ => return parse_discriminant(val_expr).map(|v| vec![v]),
    };

    let lower = range_expr_value_to_number(val_expr, range.start.as_deref())?;
    let upper = range_expr_value_to_number(val_expr, range.end.as_deref())?;
    if lower > upper {
        die!(val_expr => "When using ranges for alternate values, upper bound must not be less than lower bound");
    }

    // Std ranges avoid the `upper - lower` subtraction, which can overflow
    // `i128`. They also avoid incrementing past `i128::MAX` when an inclusive
    // range ends at the maximum.
    let values: Vec<DiscriminantValue> = match range.limits {
        syn::RangeLimits::HalfOpen(..) => (lower..upper).map(DiscriminantValue::Literal).collect(),
        syn::RangeLimits::Closed(..) => (lower..=upper).map(DiscriminantValue::Literal).collect(),
    };
    Ok(values)
}
498
499#[cfg(not(feature = "complex-expressions"))]
500fn parse_alternative_values(val_expr: &Expr) -> Result<Vec<DiscriminantValue>> {
501 parse_discriminant(val_expr).map(|v| vec![v])
502}
503
/// Parsed information about a single enum variant.
pub(crate) struct VariantInfo {
    /// The variant's identifier.
    ident: Ident,
    /// Whether the variant is marked `#[default]` or `#[num_enum(default)]`.
    is_default: bool,
    /// Whether the variant is marked `#[num_enum(catch_all)]`.
    is_catch_all: bool,
    /// The explicit discriminant expression, or the implicit one derived from the
    /// previous variant.
    canonical_value: Expr,
    /// Alternative values, stored as sorted integer-literal expressions (empty
    /// when the variant declares none).
    alternative_values: Vec<Expr>,
}
511
512impl VariantInfo {
513 fn all_values(&self) -> impl Iterator<Item = &Expr> {
514 ::core::iter::once(&self.canonical_value).chain(self.alternative_values.iter())
515 }
516}
517
/// The error type to report, as paths to be emitted into generated code.
pub(crate) struct ErrorType {
    /// Path of the error type itself (default: `TryFromPrimitiveError<Self>`).
    pub(crate) name: Path,
    /// Path of the function used to construct it
    /// (default: `TryFromPrimitiveError::<Self>::new`).
    pub(crate) constructor: Path,
}
522
523impl From<ErrorTypeAttribute> for ErrorType {
524 fn from(attribute: ErrorTypeAttribute) -> Self {
525 Self {
526 name: attribute.name.path,
527 constructor: attribute.constructor.path,
528 }
529 }
530}
531
/// Name under which the `num_enum` crate is available to the code being expanded.
///
/// Looks the crate up via `proc_macro_crate`. If the lookup fails, a warning is
/// printed to stderr and `num_enum` is used.
#[cfg(feature = "proc-macro-crate")]
pub(crate) fn get_crate_name() -> String {
    use proc_macro_crate::FoundCrate;

    match proc_macro_crate::crate_name("num_enum") {
        Ok(FoundCrate::Name(name)) => name,
        Ok(FoundCrate::Itself) => String::from("num_enum"),
        Err(err) => {
            eprintln!("Warning: {}\n => defaulting to `num_enum`", err);
            String::from("num_enum")
        }
    }
}
544
/// Name of the `num_enum` crate; without `proc-macro-crate` it is always `num_enum`.
#[cfg(not(feature = "proc-macro-crate"))]
pub(crate) fn get_crate_name() -> String {
    "num_enum".to_owned()
}