1mod petgraph_wrapper;
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::error::Error;
6use std::hash::Hash;
7
8use anyhow::Context;
9use log::*;
10use petgraph::graph::NodeIndex;
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12
13use crate::core::configuration::{self, Compression, FileSource, GraphOptions, GraphSource};
14use crate::network::graph::petgraph_wrapper::GraphWrapper;
15use crate::utility::tilde_expansion;
16use crate::utility::units::{self, Unit};
17
/// Error type returned by the fallible functions in this module: any boxed error that can be
/// sent and shared between threads.
type NetGraphError = Box<dyn Error + Sync + Send>;
19
20#[derive(Debug, PartialEq, Eq)]
22pub struct ShadowNode {
23 pub id: u32,
24 pub bandwidth_down: Option<units::BitsPerSec<units::SiPrefixUpper>>,
25 pub bandwidth_up: Option<units::BitsPerSec<units::SiPrefixUpper>>,
26}
27
28impl TryFrom<gml_parser::gml::Node<'_>> for ShadowNode {
29 type Error = String;
30
31 fn try_from(mut gml_node: gml_parser::gml::Node) -> Result<Self, Self::Error> {
32 Ok(Self {
33 id: gml_node.id.ok_or("Node 'id' was not provided")?,
34 bandwidth_down: gml_node
35 .other
36 .remove("host_bandwidth_down")
37 .map(|bandwidth| {
38 bandwidth
39 .as_str()
40 .ok_or("Node 'host_bandwidth_down' is not a string")?
41 .parse()
42 .map_err(|e| format!("Node 'host_bandwidth_down' is not a valid unit: {e}"))
43 })
44 .transpose()?,
45 bandwidth_up: gml_node
46 .other
47 .remove("host_bandwidth_up")
48 .map(|bandwidth| {
49 bandwidth
50 .as_str()
51 .ok_or("Node 'host_bandwidth_up' is not a string")?
52 .parse()
53 .map_err(|e| format!("Node 'host_bandwidth_up' is not a valid unit: {e}"))
54 })
55 .transpose()?,
56 })
57 }
58}
59
60#[derive(Debug, PartialEq)]
62pub struct ShadowEdge {
63 pub source: u32,
64 pub target: u32,
65 pub latency: units::Time<units::TimePrefix>,
66 pub jitter: units::Time<units::TimePrefix>,
67 pub packet_loss: f32,
68}
69
70impl TryFrom<gml_parser::gml::Edge<'_>> for ShadowEdge {
71 type Error = String;
72
73 fn try_from(mut gml_edge: gml_parser::gml::Edge) -> Result<Self, Self::Error> {
74 let rv = Self {
75 source: gml_edge.source,
76 target: gml_edge.target,
77 latency: gml_edge
78 .other
79 .remove("latency")
80 .ok_or("Edge 'latency' was not provided")?
81 .as_str()
82 .ok_or("Edge 'latency' is not a string")?
83 .parse()
84 .map_err(|e| format!("Edge 'latency' is not a valid unit: {e}"))?,
85 jitter: match gml_edge.other.remove("jitter") {
86 Some(x) => x
87 .as_str()
88 .ok_or("Edge 'jitter' is not a string")?
89 .parse()
90 .map_err(|e| format!("Edge 'jitter' is not a valid unit: {e}"))?,
91 None => units::Time::new(0, units::TimePrefix::Milli),
92 },
93 packet_loss: match gml_edge.other.remove("packet_loss") {
94 Some(x) => x.as_float().ok_or("Edge 'packet_loss' is not a float")?,
95 None => 0.0,
96 },
97 };
98
99 if rv.packet_loss < 0f32 || rv.packet_loss > 1f32 {
100 return Err("Edge 'packet_loss' is not in the range [0,1]".into());
101 }
102
103 if rv.latency.value() == 0 {
104 return Err("Edge 'latency' must not be 0".into());
105 }
106
107 Ok(rv)
108 }
109}
110
111#[derive(Debug)]
114pub struct NetworkGraph {
115 graph: GraphWrapper<ShadowNode, ShadowEdge, u32>,
116 node_id_to_index_map: HashMap<u32, NodeIndex>,
117}
118
119impl NetworkGraph {
120 pub fn graph(&self) -> &GraphWrapper<ShadowNode, ShadowEdge, u32> {
121 &self.graph
122 }
123
124 pub fn node_id_to_index(&self, id: u32) -> Option<&NodeIndex> {
125 self.node_id_to_index_map.get(&id)
126 }
127
128 pub fn node_index_to_id(&self, index: NodeIndex) -> Option<u32> {
129 self.graph.node_weight(index).map(|w| w.id)
130 }
131
132 pub fn parse(graph_text: &str) -> Result<Self, NetGraphError> {
133 let gml_graph = gml_parser::parse(graph_text)?;
134
135 let mut g = match gml_graph.directed {
136 true => GraphWrapper::Directed(
137 petgraph::graph::Graph::<_, _, petgraph::Directed, _>::with_capacity(
138 gml_graph.nodes.len(),
139 gml_graph.edges.len(),
140 ),
141 ),
142 false => {
143 GraphWrapper::Undirected(
144 petgraph::graph::Graph::<_, _, petgraph::Undirected, _>::with_capacity(
145 gml_graph.nodes.len(),
146 gml_graph.edges.len(),
147 ),
148 )
149 }
150 };
151
152 let mut id_map = HashMap::new();
154
155 for x in gml_graph.nodes.into_iter() {
156 let x: ShadowNode = x.try_into()?;
157 let gml_id = x.id;
158 let petgraph_id = g.add_node(x);
159 id_map.insert(gml_id, petgraph_id);
160 }
161
162 for x in gml_graph.edges.into_iter() {
163 let x: ShadowEdge = x.try_into()?;
164
165 let source = *id_map
166 .get(&x.source)
167 .ok_or(format!("Edge source {} doesn't exist", x.source))?;
168 let target = *id_map
169 .get(&x.target)
170 .ok_or(format!("Edge target {} doesn't exist", x.target))?;
171
172 g.add_edge(source, target, x);
173 }
174
175 Ok(Self {
176 graph: g,
177 node_id_to_index_map: id_map,
178 })
179 }
180
181 pub fn compute_shortest_paths(
182 &self,
183 nodes: &[NodeIndex],
184 ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
185 let start = std::time::Instant::now();
186
187 let mut paths: HashMap<(_, _), PathProperties> = nodes
189 .into_par_iter()
190 .flat_map(|src| {
191 match &self.graph {
192 GraphWrapper::Directed(graph) => {
193 petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
194 }
195 GraphWrapper::Undirected(graph) => {
196 petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
197 }
198 }
199 .into_iter()
200 .filter(|(dst, _)| nodes.contains(dst))
202 .map(|(dst, path)| ((*src, dst), path))
204 .collect::<HashMap<(_, _), _>>()
205 })
206 .collect();
207
208 for node in nodes {
210 assert_eq!(paths[&(*node, *node)], PathProperties::default());
212
213 paths.insert((*node, *node), self.get_edge_weight(node, node)?.into());
215 }
216
217 assert_eq!(paths.len(), nodes.len().pow(2));
218
219 debug!(
220 "Finished computing shortest paths: {} seconds, {} entries",
221 (std::time::Instant::now() - start).as_secs(),
222 paths.len()
223 );
224
225 Ok(paths)
226 }
227
228 pub fn get_direct_paths(
229 &self,
230 nodes: &[NodeIndex],
231 ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
232 let start = std::time::Instant::now();
233
234 let paths: HashMap<_, _> = nodes
235 .iter()
236 .flat_map(|src| nodes.iter().map(move |dst| (*src, *dst)))
237 .map(|(src, dst)| Ok(((src, dst), self.get_edge_weight(&src, &dst)?.into())))
239 .collect::<Result<_, NetGraphError>>()?;
240
241 assert_eq!(paths.len(), nodes.len().pow(2));
242
243 debug!(
244 "Finished computing direct paths: {} seconds, {} entries",
245 (std::time::Instant::now() - start).as_secs(),
246 paths.len()
247 );
248
249 Ok(paths)
250 }
251
252 fn get_edge_weight(
255 &self,
256 src: &NodeIndex,
257 dst: &NodeIndex,
258 ) -> Result<&ShadowEdge, NetGraphError> {
259 let src_id = self.node_index_to_id(*src).unwrap();
260 let dst_id = self.node_index_to_id(*dst).unwrap();
261 match &self.graph {
262 GraphWrapper::Directed(graph) => {
263 let mut edges = graph.edges_connecting(*src, *dst);
264 let edge = edges
265 .next()
266 .ok_or(format!("No edge connecting node {src_id} to {dst_id}"))?;
267 if edges.count() != 0 {
268 return Err(
269 format!("More than one edge connecting node {src_id} to {dst_id}").into(),
270 );
271 }
272 Ok(edge.weight())
273 }
274 GraphWrapper::Undirected(graph) => {
275 let mut edges = graph.edges_connecting(*src, *dst);
276 let edge = edges
277 .next()
278 .ok_or(format!("No edge connecting node {src_id} to {dst_id}"))?;
279 if edges.count() != 0 {
280 return Err(
281 format!("More than one edge connecting node {src_id} to {dst_id}").into(),
282 );
283 }
284 Ok(edge.weight())
285 }
286 }
287 }
288}
289
/// Properties of a path through the network graph: its total latency and the probability that
/// a packet sent along it is lost.
///
/// Paths are ordered by latency first, with packet loss breaking ties.
#[derive(Clone, Copy, Debug, Default)]
pub struct PathProperties {
    /// Total latency of the path, in nanoseconds.
    pub latency_ns: u64,
    /// Probability that a packet is lost somewhere along the path.
    pub packet_loss: f32,
}

impl PartialOrd for PathProperties {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Tuples compare lexicographically. Latency decides unless it ties; packet loss then
        // decides, yielding `None` if either loss is NaN.
        (self.latency_ns, self.packet_loss).partial_cmp(&(other.latency_ns, other.packet_loss))
    }
}

impl PartialEq for PathProperties {
    /// Two paths are equal exactly when their ordering is `Equal`. This keeps equality
    /// consistent with `PartialOrd`: a NaN packet loss makes a path unequal to every path,
    /// including itself.
    fn eq(&self, other: &Self) -> bool {
        matches!(self.partial_cmp(other), Some(std::cmp::Ordering::Equal))
    }
}

impl core::ops::Add for PathProperties {
    type Output = Self;

    /// Concatenates two paths. Latencies add. A packet survives the combined path only if it
    /// survives both parts, with the two losses treated as independent.
    fn add(self, other: Self) -> Self::Output {
        let survival = (1f32 - self.packet_loss) * (1f32 - other.packet_loss);
        Self {
            latency_ns: self.latency_ns + other.latency_ns,
            packet_loss: 1f32 - survival,
        }
    }
}
326
327impl std::convert::From<&ShadowEdge> for PathProperties {
328 fn from(e: &ShadowEdge) -> Self {
329 Self {
330 latency_ns: e.latency.convert(units::TimePrefix::Nano).unwrap().value(),
331 packet_loss: e.packet_loss,
332 }
333 }
334}
335
/// Error returned when an IP address that is already taken is assigned again.
#[derive(Debug)]
pub struct IpPreviouslyAssignedError;

impl std::fmt::Display for IpPreviouslyAssignedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("IP address has already been assigned")
    }
}

impl std::error::Error for IpPreviouslyAssignedError {}
344}
345
346#[derive(Debug)]
348pub struct IpAssignment<T: Copy + Eq + Hash + std::fmt::Display> {
349 map: HashMap<std::net::IpAddr, T>,
351 last_assigned_addr: std::net::IpAddr,
353}
354
355impl<T: Copy + Eq + Hash + std::fmt::Display> IpAssignment<T> {
356 pub fn new() -> Self {
357 Self {
358 map: HashMap::new(),
359 last_assigned_addr: std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 0)),
360 }
361 }
362
363 pub fn assign(&mut self, node_id: T) -> std::net::IpAddr {
365 loop {
367 let ip_addr = Self::increment_address(&self.last_assigned_addr);
368 self.last_assigned_addr = ip_addr;
369 if let std::collections::hash_map::Entry::Vacant(e) = self.map.entry(ip_addr) {
370 e.insert(node_id);
371 break ip_addr;
372 }
373 }
374 }
375
376 pub fn assign_ip(
378 &mut self,
379 node_id: T,
380 ip_addr: std::net::IpAddr,
381 ) -> Result<(), IpPreviouslyAssignedError> {
382 let entry = self.map.entry(ip_addr);
383 if let Entry::Occupied(_) = &entry {
384 return Err(IpPreviouslyAssignedError);
385 }
386 entry.or_insert(node_id);
387 Ok(())
388 }
389
390 pub fn get_node(&self, ip_addr: std::net::IpAddr) -> Option<T> {
392 self.map.get(&ip_addr).copied()
393 }
394
395 pub fn get_nodes(&self) -> std::collections::HashSet<T> {
397 self.map.values().copied().collect()
398 }
399
400 fn increment_address(addr: &std::net::IpAddr) -> std::net::IpAddr {
401 match addr {
402 std::net::IpAddr::V4(x) => {
403 let addr_bits = u32::from(*x);
404 let mut increment = 1;
405 loop {
406 let next_addr = std::net::Ipv4Addr::from(addr_bits + increment);
408 match next_addr.octets()[3] {
409 0 | 255 => increment += 1,
411 _ => break std::net::IpAddr::V4(next_addr),
412 }
413 }
414 }
415 std::net::IpAddr::V6(_) => unimplemented!(),
416 }
417 }
418}
419
420impl<T: Copy + Eq + Hash + std::fmt::Display> Default for IpAssignment<T> {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
426#[derive(Debug)]
428pub struct RoutingInfo<T: Eq + Hash + std::fmt::Display + Clone + Copy> {
429 paths: HashMap<(T, T), PathProperties>,
430 packet_counters: std::sync::RwLock<HashMap<(T, T), u64>>,
431}
432
433impl<T: Eq + Hash + std::fmt::Display + Clone + Copy> RoutingInfo<T> {
434 pub fn new(paths: HashMap<(T, T), PathProperties>) -> Self {
435 Self {
436 paths,
437 packet_counters: std::sync::RwLock::new(HashMap::new()),
438 }
439 }
440
441 pub fn path(&self, start: T, end: T) -> Option<PathProperties> {
443 self.paths.get(&(start, end)).copied()
444 }
445
446 pub fn increment_packet_count(&self, start: T, end: T) {
448 let key = (start, end);
449 let mut packet_counters = self.packet_counters.write().unwrap();
450 match packet_counters.get_mut(&key) {
451 Some(x) => *x = x.saturating_add(1),
452 None => assert!(packet_counters.insert(key, 1).is_none()),
453 }
454 }
455
456 pub fn log_packet_counts(&self) {
458 for ((start, end), count) in self.packet_counters.read().unwrap().iter() {
460 let path = self.paths.get(&(*start, *end)).unwrap();
461 log::debug!(
462 "Found path {}->{}: latency={}ns, packet_loss={}, packet_count={}",
463 start,
464 end,
465 path.latency_ns,
466 path.packet_loss,
467 count,
468 );
469 }
470 }
471
472 pub fn get_smallest_latency_ns(&self) -> Option<u64> {
473 self.paths.values().map(|x| x.latency_ns).min()
474 }
475}
476
477fn read_xz<P: AsRef<std::path::Path>>(path: P) -> Result<String, NetGraphError> {
479 let path = path.as_ref();
480
481 let mut f = std::io::BufReader::new(
482 std::fs::File::open(path).with_context(|| format!("Failed to open file: {path:?}"))?,
483 );
484
485 let mut decomp: Vec<u8> = Vec::new();
486 lzma_rs::xz_decompress(&mut f, &mut decomp).context("Failed to decompress file")?;
487 decomp.shrink_to_fit();
488
489 Ok(String::from_utf8(decomp)?)
490}
491
492pub fn load_network_graph(graph_options: &GraphOptions) -> Result<String, NetGraphError> {
494 Ok(match graph_options {
495 GraphOptions::Gml(GraphSource::File(FileSource {
496 compression: None,
497 path: f,
498 })) => std::fs::read_to_string(tilde_expansion(f))
499 .with_context(|| format!("Failed to read file: {f}"))?,
500 GraphOptions::Gml(GraphSource::File(FileSource {
501 compression: Some(Compression::Xz),
502 path: f,
503 })) => read_xz(tilde_expansion(f))?,
504 GraphOptions::Gml(GraphSource::Inline(s)) => s.clone(),
505 GraphOptions::OneGbitSwitch => configuration::ONE_GBIT_SWITCH_GRAPH.to_string(),
506 })
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    // Combining two paths adds their latencies and compounds their loss rates.
    #[test]
    fn test_path_add() {
        let p1 = PathProperties {
            latency_ns: 23,
            packet_loss: 0.35,
        };
        let p2 = PathProperties {
            latency_ns: 11,
            packet_loss: 0.85,
        };

        let p3 = p1 + p2;
        assert_eq!(p3.latency_ns, 34);
        // 1 - (1 - 0.35) * (1 - 0.85) = 0.9025
        assert!((p3.packet_loss - 0.9025).abs() < 0.01);
    }

    // The graph defines nodes 1 and 3. An edge to target 3 parses successfully, while an edge
    // to target 2 (which names no node) must be rejected.
    #[test]
    fn test_nonexistent_id() {
        for id in &[2, 3] {
            let graph = format!(
                r#"graph [
                  node [
                    id 1
                  ]
                  node [
                    id 3
                  ]
                  edge [
                    source 1
                    target {id}
                    latency "1 ns"
                  ]
                ]"#,
            );

            if *id == 3 {
                NetworkGraph::parse(&graph).unwrap();
            } else {
                NetworkGraph::parse(&graph).unwrap_err();
            }
        }
    }

    // Checks shortest-path latencies for the same edge list interpreted as a directed and as an
    // undirected graph. Each node has a self-loop, whose latency is expected as the node's path
    // to itself.
    #[test]
    // NOTE(review): presumably skipped under miri because it is too slow or uses operations miri
    // does not support (e.g. rayon threads) — confirm.
    #[cfg_attr(miri, ignore)]
    fn test_shortest_path() {
        for directed in &[true, false] {
            let graph = format!(
                r#"graph [
                  directed {}
                  node [
                    id 0
                  ]
                  node [
                    id 1
                  ]
                  node [
                    id 2
                  ]
                  edge [
                    source 0
                    target 0
                    latency "3333 ns"
                  ]
                  edge [
                    source 1
                    target 1
                    latency "5555 ns"
                  ]
                  edge [
                    source 2
                    target 2
                    latency "7777 ns"
                  ]
                  edge [
                    source 0
                    target 1
                    latency "3 ns"
                  ]
                  edge [
                    source 1
                    target 0
                    latency "5 ns"
                  ]
                  edge [
                    source 0
                    target 2
                    latency "7 ns"
                  ]
                  edge [
                    source 2
                    target 1
                    latency "11 ns"
                  ]
                ]"#,
                if *directed { 1 } else { 0 }
            );
            let graph = NetworkGraph::parse(&graph).unwrap();
            let node_0 = *graph.node_id_to_index(0).unwrap();
            let node_1 = *graph.node_id_to_index(1).unwrap();
            let node_2 = *graph.node_id_to_index(2).unwrap();

            let shortest_paths = graph
                .compute_shortest_paths(&[node_0, node_1, node_2])
                .unwrap();

            let lookup_latency = |a, b| shortest_paths.get(&(a, b)).unwrap().latency_ns;

            if *directed {
                // Directed: 1->2 goes via 0 (5 + 7 = 12), 2->0 goes via 1 (11 + 5 = 16).
                assert_eq!(lookup_latency(node_0, node_0), 3333);
                assert_eq!(lookup_latency(node_0, node_1), 3);
                assert_eq!(lookup_latency(node_0, node_2), 7);
                assert_eq!(lookup_latency(node_1, node_0), 5);
                assert_eq!(lookup_latency(node_1, node_1), 5555);
                assert_eq!(lookup_latency(node_1, node_2), 12);
                assert_eq!(lookup_latency(node_2, node_0), 16);
                assert_eq!(lookup_latency(node_2, node_1), 11);
                assert_eq!(lookup_latency(node_2, node_2), 7777);
            } else {
                // Undirected: the cheaper 0-1 edge (3) is used both ways, and 1<->2 goes via 0
                // (3 + 7 = 10), beating the direct 11 ns edge.
                assert_eq!(lookup_latency(node_0, node_0), 3333);
                assert_eq!(lookup_latency(node_0, node_1), 3);
                assert_eq!(lookup_latency(node_0, node_2), 7);
                assert_eq!(lookup_latency(node_1, node_0), 3);
                assert_eq!(lookup_latency(node_1, node_1), 5555);
                assert_eq!(lookup_latency(node_1, node_2), 10);
                assert_eq!(lookup_latency(node_2, node_0), 7);
                assert_eq!(lookup_latency(node_2, node_1), 10);
                assert_eq!(lookup_latency(node_2, node_2), 7777);
            }
        }
    }

    // Incrementing 11.0.0.254 must skip 11.0.0.255 (last octet 255).
    #[test]
    fn test_increment_address_skip_broadcast() {
        let addr = std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 254));
        let incremented = IpAssignment::<i32>::increment_address(&addr);
        assert!(incremented > addr);
        assert_ne!(
            incremented,
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 255))
        );
    }
}