enum_dispatch/
supported_generics.rs

1//! Utilities for dealing with generic arguments and parameters.
2
3/// Represents any single generic argument from e.g. `#[enum_dispatch(Ty<...>)]` that can be
4/// supported by `enum_dispatch`.
5pub enum SupportedGenericArg {
6    /// A `_` type.
7    Inferred,
8    /// A named generic argument, e.g. `T`.
9    Identifier(proc_macro2::Ident),
10    /// A const generic char, e.g. `'a'`.
11    ConstChar(syn::LitChar),
12    /// A const generic byte, e.g. `b'a'`.
13    ConstByte(syn::LitByte),
14    /// A const generic integer, e.g. `9`.
15    ConstInt(syn::LitInt),
16    /// A const generic integer, e.g. `true`.
17    ConstBool(syn::LitBool),
18}
19
/// Represents any single generic argument from `#[enum_dispatch(Ty<...>)]` that can _not_ be
/// supported by `enum_dispatch`.
///
/// Each variant names the reason an argument was rejected by `convert_to_supported_generic`;
/// the `Display` impl turns it into the user-facing error message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnsupportedGenericArg {
    /// A type argument that is not a bare identifier or `_`, e.g. a qualified path, a path
    /// with its own generic arguments, or a reference type.
    NonIdentifierType,
    /// A const argument that is not a `char`, byte, integer, or `bool` literal.
    NonIntegralConstGenericType,
    /// A lifetime argument, e.g. `'a`.
    Lifetime,
    /// An associated type bound, e.g. the `Item: Display` in `Iterator<Item: Display>`.
    Constraint,
    /// An associated type binding, e.g. the `Item = u8` in `Iterator<Item = u8>`.
    AssocType,
    /// An associated const binding, e.g. `N = 5`.
    AssocConst,
    /// Any other generic argument syntax not covered above.
    Unknown,
}
31
32impl std::fmt::Display for UnsupportedGenericArg {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            Self::NonIdentifierType => write!(f, "Generic types in #[enum_dispatch(...)] must be identifiers"),
36            Self::NonIntegralConstGenericType => write!(f, "Non-integral const generic types in #[enum_dispatch(...)] are not supported"),
37            Self::Lifetime => write!(f, "Lifetime generics in #[enum_dispatch(...)] are not supported"),
38            Self::AssocType => write!(f, "Generic associated types in #[enum_dispatch(...)] are not supported"),
39            Self::AssocConst => write!(f, "Generic associated constants in #[enum_dispatch(...)] are not supported"),
40            Self::Constraint => write!(f, "Generic trait constraints in #[enum_dispatch(...)] are not supported"),
41            Self::Unknown => write!(f, "Unsupported generic argument syntax in #[enum_dispatch(...)]"),
42        }
43    }
44}
45
/// Names of the primitive types accepted as the declared type of a const generic parameter:
/// every fixed-width and pointer-sized integer, plus `char` and `bool`.
const SUPPORTED_CONST_GENERIC_TYPES: &[&str] = &[
    // Fixed-width integers, unsigned before signed at each width.
    "u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128",
    // Pointer-sized integers.
    "usize", "isize",
    // Non-integer primitives that are also valid const generic types.
    "char", "bool",
];
63
64/// Counts the number of supported generic parameters from an enum or trait definition.
65pub fn num_supported_generics(g: &syn::Generics) -> usize {
66    let type_generics = g.type_params().count();
67    let const_generics = g.const_params().filter(|p| {
68        if let syn::Type::Path(syn::TypePath { qself: None, path }) = &p.ty {
69            for supported_type in SUPPORTED_CONST_GENERIC_TYPES {
70                if path.is_ident(supported_type) {
71                    return true;
72                }
73            }
74        }
75        false
76    }).count();
77
78    type_generics + const_generics
79}
80
81/// Converts a `syn::GenericArgument` to a `SupportedGenericArg`, or an `UnsupportedGenericArg` if
82/// it is not supported.
83pub fn convert_to_supported_generic(generic_arg: &syn::GenericArgument) -> Result<SupportedGenericArg, (UnsupportedGenericArg, proc_macro2::Span)> {
84    use syn::spanned::Spanned as _;
85    let span = generic_arg.span();
86
87    match generic_arg {
88        syn::GenericArgument::Type(syn::Type::Path(t)) if t.qself.is_none() => {
89            if let Some(ident) = t.path.get_ident() {
90                Ok(SupportedGenericArg::Identifier(ident.clone()))
91            } else {
92                Err((UnsupportedGenericArg::NonIdentifierType, span))
93            }
94        }
95        syn::GenericArgument::Type(syn::Type::Infer(_)) => Ok(SupportedGenericArg::Inferred),
96        syn::GenericArgument::Type(_) => Err((UnsupportedGenericArg::NonIdentifierType, span)),
97        syn::GenericArgument::Const(syn::Expr::Lit(syn::ExprLit { attrs: _, lit })) => {
98            match lit {
99                syn::Lit::Byte(b) => Ok(SupportedGenericArg::ConstByte(b.clone())),
100                syn::Lit::Char(c) => Ok(SupportedGenericArg::ConstChar(c.clone())),
101                syn::Lit::Int(i) => Ok(SupportedGenericArg::ConstInt(i.clone())),
102                syn::Lit::Bool(b) => Ok(SupportedGenericArg::ConstBool(b.clone())),
103                _ => Err((UnsupportedGenericArg::NonIntegralConstGenericType, span)),
104            }
105        }
106        syn::GenericArgument::Const(_) => Err((UnsupportedGenericArg::NonIntegralConstGenericType, span)),
107        syn::GenericArgument::Lifetime(_) => Err((UnsupportedGenericArg::Lifetime, span)),
108        syn::GenericArgument::Constraint(_) => Err((UnsupportedGenericArg::Constraint, span)),
109        syn::GenericArgument::AssocType(_) => Err((UnsupportedGenericArg::AssocType, span)),
110        syn::GenericArgument::AssocConst(_) => Err((UnsupportedGenericArg::AssocConst, span)),
111        _ => Err((UnsupportedGenericArg::Unknown, span)),
112    }
113}