1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3use std::fmt::Display;
4use std::fs::File;
5use std::io::Write;
6use std::net::Ipv4Addr;
7use std::os::fd::AsRawFd;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11#[cfg(not(miri))]
13use rustix::fs::MemfdFlags;
14use shadow_shim_helper_rs::HostId;
15
/// The set of registered hosts, indexed both by name and by address.
///
/// `DnsBuilder::register` inserts the same `Arc<Record>` into both maps, so each host
/// record is stored once and shared by the two indexes.
#[derive(Debug)]
struct Database {
    // Lookup by host name (exact, case-sensitive `String` key).
    name_index: HashMap<String, Arc<Record>>,
    // Lookup by IPv4 address.
    addr_index: HashMap<Ipv4Addr, Arc<Record>>,
}
23
/// A single registered host.
#[derive(Debug)]
struct Record {
    // Host id, returned by `Dns::addr_to_host_id`.
    id: HostId,
    // The host's IPv4 address; written as the first field of its hosts-file line.
    addr: Ipv4Addr,
    // The host's name; written as the second field of its hosts-file line.
    name: String,
}
30
/// Reasons a host registration can be rejected by `DnsBuilder::register`.
#[derive(Debug, PartialEq)]
pub enum RegistrationError {
    /// The broadcast address (`255.255.255.255`) cannot be registered.
    BroadcastAddrInvalid,
    /// A loopback address cannot be registered.
    LoopbackAddrInvalid(Ipv4Addr),
    /// A multicast address cannot be registered.
    MulticastAddrInvalid(Ipv4Addr),
    /// The unspecified address (`0.0.0.0`) cannot be registered.
    UnspecifiedAddrInvalid,
    /// The host name is not allowed (for example `localhost`).
    NameInvalid(String),
    /// Another host is already registered at this address.
    AddrExists(Ipv4Addr),
    /// Another host is already registered under this name.
    NameExists(String),
}

impl Display for RegistrationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use RegistrationError as E;

        match self {
            // Address classes that can never be assigned to a simulated host.
            E::UnspecifiedAddrInvalid => write!(
                f,
                "unspecified address '{}' is invalid in DNS",
                Ipv4Addr::UNSPECIFIED
            ),
            E::BroadcastAddrInvalid => write!(
                f,
                "broadcast address '{}' is invalid in DNS",
                Ipv4Addr::BROADCAST
            ),
            E::LoopbackAddrInvalid(addr) => {
                write!(f, "loopback address '{addr}' is invalid in DNS")
            }
            E::MulticastAddrInvalid(addr) => {
                write!(f, "multicast address '{addr}' is invalid in DNS")
            }
            // Names that are reserved or unusable.
            E::NameInvalid(name) => write!(f, "name '{name}' is invalid in DNS"),
            // Collisions with an earlier successful registration.
            E::AddrExists(addr) => write!(
                f,
                "a DNS registration record already exists for address '{addr}'"
            ),
            E::NameExists(name) => write!(
                f,
                "a DNS registration record already exists for name '{name}'"
            ),
        }
    }
}

impl std::error::Error for RegistrationError {}
79
/// Collects host registrations and then produces an immutable `Dns` (with its hosts
/// file) via `DnsBuilder::into_dns`.
#[derive(Debug)]
pub struct DnsBuilder {
    db: Database,
}
84
85impl DnsBuilder {
86 pub fn new() -> Self {
87 Self {
88 db: Database {
89 name_index: HashMap::new(),
90 addr_index: HashMap::new(),
91 },
92 }
93 }
94
95 pub fn register(
96 &mut self,
97 id: HostId,
98 addr: Ipv4Addr,
99 name: String,
100 ) -> Result<(), RegistrationError> {
101 if addr.is_unspecified() {
103 return Err(RegistrationError::UnspecifiedAddrInvalid);
104 } else if addr.is_loopback() {
105 return Err(RegistrationError::LoopbackAddrInvalid(addr));
106 } else if addr.is_broadcast() {
107 return Err(RegistrationError::BroadcastAddrInvalid);
108 } else if addr.is_multicast() {
109 return Err(RegistrationError::MulticastAddrInvalid(addr));
110 } else if name.eq_ignore_ascii_case("localhost") {
111 return Err(RegistrationError::NameInvalid(name));
112 }
113
114 match self.db.addr_index.entry(addr) {
117 Entry::Occupied(_) => Err(RegistrationError::AddrExists(addr)),
118 Entry::Vacant(addr_entry) => match self.db.name_index.entry(name.clone()) {
119 Entry::Occupied(_) => Err(RegistrationError::NameExists(name)),
120 Entry::Vacant(name_entry) => {
121 let record = Arc::new(Record { id, addr, name });
122 addr_entry.insert(record.clone());
123 name_entry.insert(record);
124 Ok(())
125 }
126 },
127 }
128 }
129
130 pub fn into_dns(self) -> std::io::Result<Dns> {
131 #[cfg(miri)]
133 let mut file = tempfile::tempfile()?;
134 #[cfg(not(miri))]
135 let mut file = {
136 let name = format!("shadow_dns_hosts_file_{}", std::process::id());
137 File::from(rustix::fs::memfd_create(name, MemfdFlags::CLOEXEC)?)
138 };
139
140 let mut records: Vec<&Arc<Record>> = self.db.addr_index.values().collect();
142 records.sort_by_key(|x| x.addr);
144
145 writeln!(file, "127.0.0.1 localhost")?;
146 for record in records.iter() {
147 assert!(!record.name.as_bytes().iter().any(u8::is_ascii_whitespace));
149 writeln!(file, "{} {}", record.addr, record.name)?;
150 }
151
152 Ok(Dns {
153 db: self.db,
154 hosts_file: file,
155 })
156 }
157}
158
159impl Default for DnsBuilder {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
/// A finalized DNS database together with its generated hosts file.
#[derive(Debug)]
pub struct Dns {
    db: Database,
    // Must stay open for as long as `Dns` lives: `hosts_path` names this file through its
    // raw file descriptor, so closing it would invalidate that path.
    hosts_file: File,
}
172
173impl Dns {
174 pub fn addr_to_host_id(&self, addr: Ipv4Addr) -> Option<HostId> {
175 self.db.addr_index.get(&addr).map(|record| record.id)
176 }
177
178 #[cfg(test)]
179 fn addr_to_name(&self, addr: Ipv4Addr) -> Option<&str> {
180 self.db
181 .addr_index
182 .get(&addr)
183 .map(|record| record.name.as_str())
184 }
185
186 pub fn name_to_addr(&self, name: &str) -> Option<Ipv4Addr> {
187 self.db.name_index.get(name).map(|record| record.addr)
188 }
189
190 pub fn hosts_path(&self) -> PathBuf {
191 PathBuf::from(format!("/proc/self/fd/{}", self.hosts_file.as_raw_fd()))
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// First sample host: id 0 at `100.1.2.3`, named `myhost`.
    fn host_a() -> (HostId, Ipv4Addr, String) {
        (HostId::from(0), Ipv4Addr::new(100, 1, 2, 3), String::from("myhost"))
    }

    /// Second sample host: id 1 at `200.3.2.1`, named `theirhost`.
    fn host_b() -> (HostId, Ipv4Addr, String) {
        (HostId::from(1), Ipv4Addr::new(200, 3, 2, 1), String::from("theirhost"))
    }

    /// Registers host A then host B, and finalizes the database.
    fn build_sample_dns() -> Dns {
        let mut builder = DnsBuilder::new();
        for (id, addr, name) in [host_a(), host_b()] {
            builder.register(id, addr, name).unwrap();
        }
        builder.into_dns().unwrap()
    }

    #[test]
    fn register() {
        let (id_a, addr_a, name_a) = host_a();
        let (id_b, addr_b, name_b) = host_b();

        let mut builder = DnsBuilder::new();
        assert!(builder.register(id_a, addr_a, name_a.clone()).is_ok());

        let multicast = Ipv4Addr::new(224, 0, 0, 1);
        let localhost = String::from("localhost");

        // Each attempt must fail with the listed error; failures do not modify the builder.
        let rejected = [
            (
                Ipv4Addr::UNSPECIFIED,
                name_b.clone(),
                RegistrationError::UnspecifiedAddrInvalid,
            ),
            (
                Ipv4Addr::BROADCAST,
                name_b.clone(),
                RegistrationError::BroadcastAddrInvalid,
            ),
            (
                multicast,
                name_b.clone(),
                RegistrationError::MulticastAddrInvalid(multicast),
            ),
            (
                Ipv4Addr::LOCALHOST,
                name_b.clone(),
                RegistrationError::LoopbackAddrInvalid(Ipv4Addr::LOCALHOST),
            ),
            (
                addr_b,
                localhost.clone(),
                RegistrationError::NameInvalid(localhost),
            ),
            (
                addr_a,
                name_b.clone(),
                RegistrationError::AddrExists(addr_a),
            ),
            (
                addr_b,
                name_a.clone(),
                RegistrationError::NameExists(name_a),
            ),
        ];
        for (addr, name, expected) in rejected {
            assert_eq!(builder.register(id_b, addr, name), Err(expected));
        }

        assert!(builder.register(id_b, addr_b, name_b).is_ok());
    }

    #[test]
    fn lookups() {
        let (id_a, addr_a, name_a) = host_a();
        let (id_b, addr_b, name_b) = host_b();
        let unknown = Ipv4Addr::new(1, 2, 3, 4);

        let dns = build_sample_dns();

        assert_eq!(dns.addr_to_host_id(addr_a), Some(id_a));
        assert_eq!(dns.addr_to_host_id(addr_b), Some(id_b));
        assert_eq!(dns.addr_to_host_id(unknown), None);

        assert_eq!(dns.addr_to_name(addr_a), Some(name_a.as_str()));
        assert_eq!(dns.addr_to_name(addr_b), Some(name_b.as_str()));
        assert_eq!(dns.addr_to_name(unknown), None);

        assert_eq!(dns.name_to_addr(&name_a), Some(addr_a));
        assert_eq!(dns.name_to_addr(&name_b), Some(addr_b));
        // Unregistered names resolve to nothing. This includes `localhost`, which appears
        // only in the hosts file.
        for missing in ["empty", "localhost"] {
            assert_eq!(dns.name_to_addr(missing), None);
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn hosts_file() {
        let dns = build_sample_dns();
        let contents = std::fs::read_to_string(dns.hosts_path()).unwrap();

        // The localhost line comes first, then hosts sorted by address.
        assert_eq!(
            contents,
            "127.0.0.1 localhost\n100.1.2.3 myhost\n200.3.2.1 theirhost\n"
        );
        assert_ne!(
            contents,
            "127.0.0.1 localhost\n200.3.2.1 theirhost\n100.1.2.3 myhost\n"
        );
    }
}