naked_function_macro/
naked.rs

1use proc_macro2::{Ident, Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    punctuated::Punctuated, Abi, AttrStyle, Attribute, Expr, ExprLit, ExprMacro, ForeignItem,
5    ForeignItemFn, Item, ItemFn, ItemForeignMod, ItemMacro, Lit, LitStr, Macro, MacroDelimiter,
6    Meta, MetaNameValue, Result, Signature, Token,
7};
8
9use crate::asm::{extract_asm, AsmOperand};
10
11/// Sanity checks the function signature.
12fn validate_sig(sig: &Signature) -> Result<()> {
13    if let Some(constness) = sig.constness {
14        bail!(constness, "#[naked] is not supported on const functions");
15    }
16    if let Some(asyncness) = sig.asyncness {
17        bail!(asyncness, "#[naked] is not supported on async functions");
18    }
19    if sig.unsafety.is_none() {
20        bail!(sig, "#[naked] can only be used on unsafe functions");
21    }
22    match &sig.abi {
23        Some(Abi {
24            extern_token: _,
25            name: Some(name),
26        }) if matches!(&*name.value(), "C" | "C-unwind") => {}
27        _ => bail!(
28            &sig.abi,
29            "#[naked] functions must be `extern \"C\"` or `extern \"C-unwind\"`"
30        ),
31    }
32    if !sig.generics.params.is_empty() {
33        bail!(
34            &sig.generics,
35            "#[naked] cannot be used with generic functions"
36        );
37    }
38    Ok(())
39}
40
41struct ParsedAttrs {
42    foreign_attrs: Vec<Attribute>,
43    cfg: Vec<Attribute>,
44    symbol: Expr,
45    link_section: Expr,
46    instruction_set: Option<TokenStream>,
47}
48
49/// Parses the attributes on the function and checks them against a whitelist
50/// of supported attributes.
51///
52/// The symbol name of the function and the linker section it will be placed in
53/// are computed here based on the function attributes.
54fn parse_attrs(ident: &Ident, attrs: &[Attribute]) -> Result<ParsedAttrs> {
55    let mut foreign_attrs = vec![];
56    let mut cfg = vec![];
57    let mut no_mangle = false;
58    let mut export_name = None;
59    let mut link_section = None;
60    let mut instruction_set = None;
61
62    // Attributes to forward to the foreign function declaration that we will
63    // generate.
64    let attr_whitelist = [
65        "doc",
66        "allow",
67        "warn",
68        "deny",
69        "forbid",
70        "deprecated",
71        "must_use",
72    ];
73
74    'outer: for attr in attrs {
75        if let AttrStyle::Inner(_) = attr.style {
76            bail!(attr, "unexpected inner attribute");
77        }
78
79        // Forward whitelisted attributes to the foreign item.
80        for whitelist in attr_whitelist {
81            if attr.path().is_ident(whitelist) {
82                foreign_attrs.push(attr.clone());
83                continue 'outer;
84            }
85        }
86
87        if attr
88            .path()
89            .segments
90            .first()
91            .map_or(false, |segment| segment.ident == "rustfmt")
92        {
93            // Ignore rustfmt attributes
94        } else if attr.path().is_ident("no_mangle") {
95            attr.meta.require_path_only()?;
96            no_mangle = true;
97        } else if attr.path().is_ident("export_name") {
98            // Pass the export_name attribute through as a #[link_section] on
99            // the foreign import declaration.
100            let name_value = attr.meta.require_name_value()?;
101            export_name = Some(name_value.value.clone());
102            let mut link_name = attr.clone();
103            link_name.meta = Meta::NameValue(MetaNameValue {
104                path: syn::parse2(quote!(link_name)).unwrap(),
105                eq_token: name_value.eq_token,
106                value: name_value.value.clone(),
107            });
108            foreign_attrs.push(link_name);
109        } else if attr.path().is_ident("link_section") {
110            let name_value = attr.meta.require_name_value()?;
111            link_section = Some(name_value.value.clone());
112        } else if attr.path().is_ident("cfg") {
113            cfg.push(attr.clone())
114        } else if attr.path().is_ident("instruction_set") {
115            instruction_set = Some(attr.meta.require_list()?.tokens.clone());
116        } else {
117            bail!(
118                attr,
119                "naked functions only support \
120                #[no_mangle], #[export_name] and #[link_section] attributes"
121            );
122        }
123    }
124
125    let symbol = if let Some(export_name) = &export_name {
126        export_name.clone()
127    } else {
128        let raw_symbol = if no_mangle {
129            ident.to_string()
130        } else {
131            format!("rust_naked_function_{}", ident.to_string())
132        };
133
134        Expr::Lit(ExprLit {
135            attrs: vec![],
136            lit: Lit::Str(LitStr::new(&raw_symbol, Span::call_site())),
137        })
138    };
139
140    // Add a #[link_name] attribute to the import pointing to our manually
141    // mangled symbol name.
142    if export_name.is_none() {
143        foreign_attrs.push(Attribute {
144            pound_token: Default::default(),
145            style: AttrStyle::Outer,
146            bracket_token: Default::default(),
147            meta: Meta::NameValue(MetaNameValue {
148                path: syn::parse2(quote!(link_name)).unwrap(),
149                eq_token: Default::default(),
150                value: symbol.clone(),
151            }),
152        });
153    }
154
155    // Use the given section if provided, otherwise use the platform
156    // default. This is usually .text.$SYMBOL, except on Mach-O targets
157    // which don't have per-symbol sections.
158    let link_section = if let Some(link_section) = link_section {
159        link_section
160    } else {
161        Expr::Macro(ExprMacro {
162            attrs: vec![],
163            mac: Macro {
164                path: syn::parse2(quote!(::naked_function::__asm_default_section)).unwrap(),
165                bang_token: Default::default(),
166                delimiter: MacroDelimiter::Paren(Default::default()),
167                tokens: symbol.to_token_stream(),
168            },
169        })
170    };
171
172    Ok(ParsedAttrs {
173        foreign_attrs,
174        cfg,
175        symbol,
176        link_section,
177        instruction_set,
178    })
179}
180
181fn emit_foreign_mod(func: &ItemFn, attrs: &ParsedAttrs) -> ItemForeignMod {
182    // Remove the ABI and unsafe from the function signature and move it to the
183    // `extern` block.
184    let sig = Signature {
185        abi: None,
186        unsafety: None,
187        ..func.sig.clone()
188    };
189    let foreign_fn = ForeignItem::Fn(ForeignItemFn {
190        attrs: {
191            let mut attrs_ = attrs.foreign_attrs.clone();
192            attrs_.extend_from_slice(&attrs.cfg[..]);
193            attrs_
194        },
195        vis: func.vis.clone(),
196        sig,
197        semi_token: Default::default(),
198    });
199    ItemForeignMod {
200        attrs: vec![],
201        unsafety: None,
202        abi: func.sig.abi.clone().unwrap(),
203        brace_token: Default::default(),
204        items: vec![foreign_fn],
205    }
206}
207
208fn emit_global_asm(attrs: &ParsedAttrs, mut asm: Punctuated<AsmOperand, Token![,]>) -> ItemMacro {
209    // Inject a prefix to the assembly code containing the necessary assembler
210    // directives to start a function.
211    let symbol = &attrs.symbol;
212    let link_section = &attrs.link_section;
213    let instruction_set = &attrs.instruction_set;
214    let prefix = syn::parse2(quote! {
215        ::naked_function::__asm_function_begin!(#symbol, #link_section, (#instruction_set))
216    })
217    .unwrap();
218    asm.insert(0, AsmOperand::Template(prefix));
219
220    // Inject a suffix at the end of the assembly code containing assembler
221    // directives to end a function.
222    let last_template = asm
223        .iter()
224        .rposition(|op| matches!(op, AsmOperand::Template(_)))
225        .unwrap();
226    let suffix = syn::parse2(quote! {
227        ::naked_function::__asm_function_end!(#symbol)
228    })
229    .unwrap();
230    asm.insert(last_template + 1, AsmOperand::Template(suffix));
231
232    let global_asm = Macro {
233        path: syn::parse2(quote!(::core::arch::global_asm)).unwrap(),
234        bang_token: Default::default(),
235        delimiter: MacroDelimiter::Paren(Default::default()),
236        tokens: asm.to_token_stream(),
237    };
238    ItemMacro {
239        attrs: attrs.cfg.clone(),
240        ident: None,
241        mac: global_asm,
242        semi_token: Some(Default::default()),
243    }
244}
245
246/// Entry point of the proc macro.
247pub fn naked_attribute(func: &ItemFn) -> Result<Vec<Item>> {
248    validate_sig(&func.sig)?;
249    let attrs = parse_attrs(&func.sig.ident, &func.attrs)?;
250    let asm = extract_asm(func)?;
251    let foreign_mod = emit_foreign_mod(func, &attrs);
252    let global_asm = emit_global_asm(&attrs, asm);
253    Ok(vec![Item::ForeignMod(foreign_mod), Item::Macro(global_asm)])
254}