shadow_rs/network/dns.rs

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3use std::fmt::Display;
4use std::fs::File;
5use std::io::Write;
6use std::net::Ipv4Addr;
7use std::os::fd::AsRawFd;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11// The memfd syscall is not supported in our miri test environment.
12#[cfg(not(miri))]
13use rustix::fs::MemfdFlags;
14use shadow_shim_helper_rs::HostId;
15
16#[derive(Debug)]
17struct Database {
18    // We can use `String` here because [`crate::core::configuration::HostName`] limits the
19    // configured host names to a subset of ascii, which are always valid utf-8.
20    name_index: HashMap<String, Arc<Record>>,
21    addr_index: HashMap<Ipv4Addr, Arc<Record>>,
22}
23
24#[derive(Debug)]
25struct Record {
26    id: HostId,
27    addr: Ipv4Addr,
28    name: String,
29}
30
/// Reasons a name/address mapping can be rejected during DNS registration.
#[derive(Debug, PartialEq)]
pub enum RegistrationError {
    /// The limited broadcast address (`255.255.255.255`) cannot be registered.
    BroadcastAddrInvalid,
    /// Loopback addresses cannot be registered; carries the rejected address.
    LoopbackAddrInvalid(Ipv4Addr),
    /// Multicast addresses cannot be registered; carries the rejected address.
    MulticastAddrInvalid(Ipv4Addr),
    /// The unspecified address (`0.0.0.0`) cannot be registered.
    UnspecifiedAddrInvalid,
    /// The name cannot be used in DNS (for example, `localhost` is reserved); carries the name.
    NameInvalid(String),
    /// Another mapping already uses this address.
    AddrExists(Ipv4Addr),
    /// Another mapping already uses this name.
    NameExists(String),
}

impl Display for RegistrationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnspecifiedAddrInvalid => {
                let addr = Ipv4Addr::UNSPECIFIED;
                write!(f, "unspecified address '{addr}' is invalid in DNS")
            }
            Self::LoopbackAddrInvalid(addr) => {
                write!(f, "loopback address '{addr}' is invalid in DNS")
            }
            Self::BroadcastAddrInvalid => {
                let addr = Ipv4Addr::BROADCAST;
                write!(f, "broadcast address '{addr}' is invalid in DNS")
            }
            Self::MulticastAddrInvalid(addr) => {
                write!(f, "multicast address '{addr}' is invalid in DNS")
            }
            Self::NameInvalid(name) => write!(f, "name '{name}' is invalid in DNS"),
            Self::AddrExists(addr) => write!(
                f,
                "a DNS registration record already exists for address '{addr}'"
            ),
            Self::NameExists(name) => write!(
                f,
                "a DNS registration record already exists for name '{name}'"
            ),
        }
    }
}

impl std::error::Error for RegistrationError {}
79
80#[derive(Debug)]
81pub struct DnsBuilder {
82    db: Database,
83}
84
85impl DnsBuilder {
86    pub fn new() -> Self {
87        Self {
88            db: Database {
89                name_index: HashMap::new(),
90                addr_index: HashMap::new(),
91            },
92        }
93    }
94
95    pub fn register(
96        &mut self,
97        id: HostId,
98        addr: Ipv4Addr,
99        name: String,
100    ) -> Result<(), RegistrationError> {
101        // Make sure we don't register reserved addresses or names.
102        if addr.is_unspecified() {
103            return Err(RegistrationError::UnspecifiedAddrInvalid);
104        } else if addr.is_loopback() {
105            return Err(RegistrationError::LoopbackAddrInvalid(addr));
106        } else if addr.is_broadcast() {
107            return Err(RegistrationError::BroadcastAddrInvalid);
108        } else if addr.is_multicast() {
109            return Err(RegistrationError::MulticastAddrInvalid(addr));
110        } else if name.eq_ignore_ascii_case("localhost") {
111            return Err(RegistrationError::NameInvalid(name));
112        }
113
114        // A single HostId is allowed to register multiple name/addr mappings,
115        // but only vacant addresses and names are allowed.
116        match self.db.addr_index.entry(addr) {
117            Entry::Occupied(_) => Err(RegistrationError::AddrExists(addr)),
118            Entry::Vacant(addr_entry) => match self.db.name_index.entry(name.clone()) {
119                Entry::Occupied(_) => Err(RegistrationError::NameExists(name)),
120                Entry::Vacant(name_entry) => {
121                    let record = Arc::new(Record { id, addr, name });
122                    addr_entry.insert(record.clone());
123                    name_entry.insert(record);
124                    Ok(())
125                }
126            },
127        }
128    }
129
130    pub fn into_dns(self) -> std::io::Result<Dns> {
131        // The memfd syscall is not supported in our miri test environment.
132        #[cfg(miri)]
133        let mut file = tempfile::tempfile()?;
134        #[cfg(not(miri))]
135        let mut file = {
136            let name = format!("shadow_dns_hosts_file_{}", std::process::id());
137            File::from(rustix::fs::memfd_create(name, MemfdFlags::CLOEXEC)?)
138        };
139
140        // Sort the records to produce deterministic ordering in the hosts file.
141        let mut records: Vec<&Arc<Record>> = self.db.addr_index.values().collect();
142        // records.sort_by(|a, b| a.addr.cmp(&b.addr));
143        records.sort_by_key(|x| x.addr);
144
145        writeln!(file, "127.0.0.1 localhost")?;
146        for record in records.iter() {
147            // Make it easier to debug if somehow we ever got a name with whitespace.
148            assert!(!record.name.as_bytes().iter().any(u8::is_ascii_whitespace));
149            writeln!(file, "{} {}", record.addr, record.name)?;
150        }
151
152        Ok(Dns {
153            db: self.db,
154            hosts_file: file,
155        })
156    }
157}
158
159impl Default for DnsBuilder {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165#[derive(Debug)]
166pub struct Dns {
167    db: Database,
168    // Keep this handle while Dns is valid to prevent closing the file
169    // containing the hosts database in /etc/hosts format.
170    hosts_file: File,
171}
172
173impl Dns {
174    pub fn addr_to_host_id(&self, addr: Ipv4Addr) -> Option<HostId> {
175        self.db.addr_index.get(&addr).map(|record| record.id)
176    }
177
178    #[cfg(test)]
179    fn addr_to_name(&self, addr: Ipv4Addr) -> Option<&str> {
180        self.db
181            .addr_index
182            .get(&addr)
183            .map(|record| record.name.as_str())
184    }
185
186    pub fn name_to_addr(&self, name: &str) -> Option<Ipv4Addr> {
187        self.db.name_index.get(name).map(|record| record.addr)
188    }
189
190    pub fn hosts_path(&self) -> PathBuf {
191        PathBuf::from(format!("/proc/self/fd/{}", self.hosts_file.as_raw_fd()))
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample host with the lower address: id 0, `100.1.2.3`, `myhost`.
    fn host_a() -> (HostId, Ipv4Addr, String) {
        let id = HostId::from(0);
        let addr = Ipv4Addr::new(100, 1, 2, 3);
        let name = String::from("myhost");
        (id, addr, name)
    }

    /// Sample host with the higher address: id 1, `200.3.2.1`, `theirhost`.
    fn host_b() -> (HostId, Ipv4Addr, String) {
        let id = HostId::from(1);
        let addr = Ipv4Addr::new(200, 3, 2, 1);
        let name = String::from("theirhost");
        (id, addr, name)
    }

    #[test]
    fn register() {
        let (id_a, addr_a, name_a) = host_a();
        let (id_b, addr_b, name_b) = host_b();

        let mut builder = DnsBuilder::new();

        assert!(builder.register(id_a, addr_a, name_a.clone()).is_ok());

        assert_eq!(
            builder.register(id_b, Ipv4Addr::UNSPECIFIED, name_b.clone()),
            Err(RegistrationError::UnspecifiedAddrInvalid)
        );
        assert_eq!(
            builder.register(id_b, Ipv4Addr::BROADCAST, name_b.clone()),
            Err(RegistrationError::BroadcastAddrInvalid)
        );
        let multicast_example_addr = Ipv4Addr::new(224, 0, 0, 1);
        assert_eq!(
            // Multicast addresses not allowed.
            builder.register(id_b, multicast_example_addr, name_b.clone()),
            Err(RegistrationError::MulticastAddrInvalid(
                multicast_example_addr
            ))
        );
        assert_eq!(
            builder.register(id_b, Ipv4Addr::LOCALHOST, name_b.clone()),
            Err(RegistrationError::LoopbackAddrInvalid(Ipv4Addr::LOCALHOST))
        );
        let localhost_string = String::from("localhost");
        assert_eq!(
            builder.register(id_b, addr_b, localhost_string.clone()),
            Err(RegistrationError::NameInvalid(localhost_string))
        );
        assert_eq!(
            builder.register(id_b, addr_a, name_b.clone()),
            Err(RegistrationError::AddrExists(addr_a))
        );
        assert_eq!(
            builder.register(id_b, addr_b, name_a.clone()),
            Err(RegistrationError::NameExists(name_a))
        );

        // The failed attempts above must not have reserved `addr_b` or `name_b`.
        assert!(builder.register(id_b, addr_b, name_b).is_ok());
    }

    #[test]
    fn lookups() {
        let (id_a, addr_a, name_a) = host_a();
        let (id_b, addr_b, name_b) = host_b();

        let mut builder = DnsBuilder::new();
        builder.register(id_a, addr_a, name_a.clone()).unwrap();
        builder.register(id_b, addr_b, name_b.clone()).unwrap();
        let dns = builder.into_dns().unwrap();

        assert_eq!(dns.addr_to_host_id(addr_a), Some(id_a));
        assert_eq!(dns.addr_to_host_id(addr_b), Some(id_b));
        assert_eq!(dns.addr_to_host_id(Ipv4Addr::new(1, 2, 3, 4)), None);

        assert_eq!(dns.addr_to_name(addr_a), Some(name_a.as_str()));
        assert_eq!(dns.addr_to_name(addr_b), Some(name_b.as_str()));
        assert_eq!(dns.addr_to_name(Ipv4Addr::new(1, 2, 3, 4)), None);

        assert_eq!(dns.name_to_addr(&name_a), Some(addr_a));
        assert_eq!(dns.name_to_addr(&name_b), Some(addr_b));
        assert_eq!(dns.name_to_addr("empty"), None);
        // `localhost` only appears in the hosts file, never in the lookup tables.
        assert_eq!(dns.name_to_addr("localhost"), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn hosts_file() {
        let (id_a, addr_a, name_a) = host_a();
        let (id_b, addr_b, name_b) = host_b();

        // Register in descending address order so that the assertion below checks that the
        // file is sorted by address rather than merely following registration order.
        let mut builder = DnsBuilder::new();
        builder.register(id_b, addr_b, name_b).unwrap();
        builder.register(id_a, addr_a, name_a).unwrap();
        let dns = builder.into_dns().unwrap();

        let contents = std::fs::read_to_string(dns.hosts_path()).unwrap();

        // The localhost entry always comes first, followed by records in ascending address order.
        let expected = "127.0.0.1 localhost\n100.1.2.3 myhost\n200.3.2.1 theirhost\n";
        assert_eq!(contents.as_str(), expected);
    }
}