1mod petgraph_wrapper;
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::error::Error;
6use std::hash::Hash;
7
8use anyhow::Context;
9use log::*;
10use petgraph::graph::NodeIndex;
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12
13use crate::core::configuration::{self, Compression, FileSource, GraphOptions, GraphSource};
14use crate::network::graph::petgraph_wrapper::GraphWrapper;
15use crate::utility::tilde_expansion;
16use crate::utility::units::{self, Unit};
17
/// Boxed, thread-safe error type returned by the graph loading and path computation
/// functions in this module.
type NetGraphError = Box<dyn Error + Sync + Send>;
19
/// A network graph node, built from a GML `node` entry by its `TryFrom` impl.
#[derive(Debug, PartialEq, Eq)]
pub struct ShadowNode {
    /// The node's GML `id`.
    pub id: u32,
    /// Parsed `host_bandwidth_down` attribute; `None` when the GML node omits it.
    pub bandwidth_down: Option<units::BitsPerSec<units::SiPrefixUpper>>,
    /// Parsed `host_bandwidth_up` attribute; `None` when the GML node omits it.
    pub bandwidth_up: Option<units::BitsPerSec<units::SiPrefixUpper>>,
}
27
28impl TryFrom<gml_parser::gml::Node<'_>> for ShadowNode {
29 type Error = String;
30
31 fn try_from(mut gml_node: gml_parser::gml::Node) -> Result<Self, Self::Error> {
32 Ok(Self {
33 id: gml_node.id.ok_or("Node 'id' was not provided")?,
34 bandwidth_down: gml_node
35 .other
36 .remove("host_bandwidth_down")
37 .map(|bandwidth| {
38 bandwidth
39 .as_str()
40 .ok_or("Node 'host_bandwidth_down' is not a string")?
41 .parse()
42 .map_err(|e| {
43 format!("Node 'host_bandwidth_down' is not a valid unit: {}", e)
44 })
45 })
46 .transpose()?,
47 bandwidth_up: gml_node
48 .other
49 .remove("host_bandwidth_up")
50 .map(|bandwidth| {
51 bandwidth
52 .as_str()
53 .ok_or("Node 'host_bandwidth_up' is not a string")?
54 .parse()
55 .map_err(|e| format!("Node 'host_bandwidth_up' is not a valid unit: {}", e))
56 })
57 .transpose()?,
58 })
59 }
60}
61
/// A network graph edge, built from a GML `edge` entry by its `TryFrom` impl.
#[derive(Debug, PartialEq)]
pub struct ShadowEdge {
    /// GML id of the edge's source node.
    pub source: u32,
    /// GML id of the edge's target node.
    pub target: u32,
    /// Required `latency` attribute; the `TryFrom` conversion rejects a zero value.
    pub latency: units::Time<units::TimePrefix>,
    /// Optional `jitter` attribute; defaults to 0 ms when omitted.
    pub jitter: units::Time<units::TimePrefix>,
    /// Optional `packet_loss` attribute; defaults to 0.0 and is range-checked against
    /// [0, 1] by the `TryFrom` conversion.
    pub packet_loss: f32,
}
71
72impl TryFrom<gml_parser::gml::Edge<'_>> for ShadowEdge {
73 type Error = String;
74
75 fn try_from(mut gml_edge: gml_parser::gml::Edge) -> Result<Self, Self::Error> {
76 let rv = Self {
77 source: gml_edge.source,
78 target: gml_edge.target,
79 latency: gml_edge
80 .other
81 .remove("latency")
82 .ok_or("Edge 'latency' was not provided")?
83 .as_str()
84 .ok_or("Edge 'latency' is not a string")?
85 .parse()
86 .map_err(|e| format!("Edge 'latency' is not a valid unit: {}", e))?,
87 jitter: match gml_edge.other.remove("jitter") {
88 Some(x) => x
89 .as_str()
90 .ok_or("Edge 'jitter' is not a string")?
91 .parse()
92 .map_err(|e| format!("Edge 'jitter' is not a valid unit: {}", e))?,
93 None => units::Time::new(0, units::TimePrefix::Milli),
94 },
95 packet_loss: match gml_edge.other.remove("packet_loss") {
96 Some(x) => x.as_float().ok_or("Edge 'packet_loss' is not a float")?,
97 None => 0.0,
98 },
99 };
100
101 if rv.packet_loss < 0f32 || rv.packet_loss > 1f32 {
102 return Err("Edge 'packet_loss' is not in the range [0,1]".into());
103 }
104
105 if rv.latency.value() == 0 {
106 return Err("Edge 'latency' must not be 0".into());
107 }
108
109 Ok(rv)
110 }
111}
112
/// A parsed network graph together with a lookup table from GML node ids to petgraph
/// node indices.
#[derive(Debug)]
pub struct NetworkGraph {
    /// The graph itself; directed or undirected depending on how it was parsed.
    graph: GraphWrapper<ShadowNode, ShadowEdge, u32>,
    /// Maps each node's GML `id` to the index petgraph assigned the node.
    node_id_to_index_map: HashMap<u32, NodeIndex>,
}
120
121impl NetworkGraph {
122 pub fn graph(&self) -> &GraphWrapper<ShadowNode, ShadowEdge, u32> {
123 &self.graph
124 }
125
126 pub fn node_id_to_index(&self, id: u32) -> Option<&NodeIndex> {
127 self.node_id_to_index_map.get(&id)
128 }
129
130 pub fn node_index_to_id(&self, index: NodeIndex) -> Option<u32> {
131 self.graph.node_weight(index).map(|w| w.id)
132 }
133
134 pub fn parse(graph_text: &str) -> Result<Self, NetGraphError> {
135 let gml_graph = gml_parser::parse(graph_text)?;
136
137 let mut g = match gml_graph.directed {
138 true => GraphWrapper::Directed(
139 petgraph::graph::Graph::<_, _, petgraph::Directed, _>::with_capacity(
140 gml_graph.nodes.len(),
141 gml_graph.edges.len(),
142 ),
143 ),
144 false => {
145 GraphWrapper::Undirected(
146 petgraph::graph::Graph::<_, _, petgraph::Undirected, _>::with_capacity(
147 gml_graph.nodes.len(),
148 gml_graph.edges.len(),
149 ),
150 )
151 }
152 };
153
154 let mut id_map = HashMap::new();
156
157 for x in gml_graph.nodes.into_iter() {
158 let x: ShadowNode = x.try_into()?;
159 let gml_id = x.id;
160 let petgraph_id = g.add_node(x);
161 id_map.insert(gml_id, petgraph_id);
162 }
163
164 for x in gml_graph.edges.into_iter() {
165 let x: ShadowEdge = x.try_into()?;
166
167 let source = *id_map
168 .get(&x.source)
169 .ok_or(format!("Edge source {} doesn't exist", x.source))?;
170 let target = *id_map
171 .get(&x.target)
172 .ok_or(format!("Edge target {} doesn't exist", x.target))?;
173
174 g.add_edge(source, target, x);
175 }
176
177 Ok(Self {
178 graph: g,
179 node_id_to_index_map: id_map,
180 })
181 }
182
183 pub fn compute_shortest_paths(
184 &self,
185 nodes: &[NodeIndex],
186 ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
187 let start = std::time::Instant::now();
188
189 let mut paths: HashMap<(_, _), PathProperties> = nodes
191 .into_par_iter()
192 .flat_map(|src| {
193 match &self.graph {
194 GraphWrapper::Directed(graph) => {
195 petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
196 }
197 GraphWrapper::Undirected(graph) => {
198 petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
199 }
200 }
201 .into_iter()
202 .filter(|(dst, _)| nodes.contains(dst))
204 .map(|(dst, path)| ((*src, dst), path))
206 .collect::<HashMap<(_, _), _>>()
207 })
208 .collect();
209
210 for node in nodes {
212 assert_eq!(paths[&(*node, *node)], PathProperties::default());
214
215 paths.insert((*node, *node), self.get_edge_weight(node, node)?.into());
217 }
218
219 assert_eq!(paths.len(), nodes.len().pow(2));
220
221 debug!(
222 "Finished computing shortest paths: {} seconds, {} entries",
223 (std::time::Instant::now() - start).as_secs(),
224 paths.len()
225 );
226
227 Ok(paths)
228 }
229
230 pub fn get_direct_paths(
231 &self,
232 nodes: &[NodeIndex],
233 ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
234 let start = std::time::Instant::now();
235
236 let paths: HashMap<_, _> = nodes
237 .iter()
238 .flat_map(|src| nodes.iter().map(move |dst| (*src, *dst)))
239 .map(|(src, dst)| Ok(((src, dst), self.get_edge_weight(&src, &dst)?.into())))
241 .collect::<Result<_, NetGraphError>>()?;
242
243 assert_eq!(paths.len(), nodes.len().pow(2));
244
245 debug!(
246 "Finished computing direct paths: {} seconds, {} entries",
247 (std::time::Instant::now() - start).as_secs(),
248 paths.len()
249 );
250
251 Ok(paths)
252 }
253
254 fn get_edge_weight(
257 &self,
258 src: &NodeIndex,
259 dst: &NodeIndex,
260 ) -> Result<&ShadowEdge, NetGraphError> {
261 let src_id = self.node_index_to_id(*src).unwrap();
262 let dst_id = self.node_index_to_id(*dst).unwrap();
263 match &self.graph {
264 GraphWrapper::Directed(graph) => {
265 let mut edges = graph.edges_connecting(*src, *dst);
266 let edge = edges
267 .next()
268 .ok_or(format!("No edge connecting node {} to {}", src_id, dst_id))?;
269 if edges.count() != 0 {
270 return Err(format!(
271 "More than one edge connecting node {} to {}",
272 src_id, dst_id
273 )
274 .into());
275 }
276 Ok(edge.weight())
277 }
278 GraphWrapper::Undirected(graph) => {
279 let mut edges = graph.edges_connecting(*src, *dst);
280 let edge = edges
281 .next()
282 .ok_or(format!("No edge connecting node {} to {}", src_id, dst_id))?;
283 if edges.count() != 0 {
284 return Err(format!(
285 "More than one edge connecting node {} to {}",
286 src_id, dst_id
287 )
288 .into());
289 }
290 Ok(edge.weight())
291 }
292 }
293 }
294}
295
/// Cost of a network path; used as the cost type for the shortest-path search.
///
/// The `Default` value (zero latency, zero loss) is the cost of an empty path.
#[derive(Debug, Default, Clone, Copy)]
pub struct PathProperties {
    /// Total latency along the path, in nanoseconds.
    pub latency_ns: u64,
    /// Probability that a packet is lost somewhere along the path.
    pub packet_loss: f32,
}

impl PartialOrd for PathProperties {
    /// Compares by latency first. Packet loss only breaks latency ties.
    ///
    /// Returns `None` only when the latencies are equal and the packet-loss values are
    /// unordered (a NaN is involved).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Tuples compare lexicographically, stopping at the first element that is not equal.
        let lhs = (self.latency_ns, self.packet_loss);
        let rhs = (other.latency_ns, other.packet_loss);
        lhs.partial_cmp(&rhs)
    }
}

impl PartialEq for PathProperties {
    /// Two paths are equal exactly when `partial_cmp` reports them equal. This keeps `==`
    /// consistent with the ordering, so a NaN loss is unequal even to itself.
    fn eq(&self, other: &Self) -> bool {
        matches!(self.partial_cmp(other), Some(std::cmp::Ordering::Equal))
    }
}

impl std::ops::Add for PathProperties {
    type Output = Self;

    /// Concatenates two path segments. Latencies add. A packet survives only if it
    /// survives both segments, with the losses treated as independent.
    fn add(self, other: Self) -> Self::Output {
        let survival = (1f32 - self.packet_loss) * (1f32 - other.packet_loss);
        Self {
            latency_ns: self.latency_ns + other.latency_ns,
            packet_loss: 1f32 - survival,
        }
    }
}
332
333impl std::convert::From<&ShadowEdge> for PathProperties {
334 fn from(e: &ShadowEdge) -> Self {
335 Self {
336 latency_ns: e.latency.convert(units::TimePrefix::Nano).unwrap().value(),
337 packet_loss: e.packet_loss,
338 }
339 }
340}
341
/// Error returned by `IpAssignment::assign_ip` when the requested address is already taken.
#[derive(Debug)]
pub struct IpPreviouslyAssignedError;

impl std::fmt::Display for IpPreviouslyAssignedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("IP address has already been assigned")
    }
}

impl std::error::Error for IpPreviouslyAssignedError {}
351
/// Tracks which node owns each IP address, and where automatic assignment left off.
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself. Naming the type
/// (in a field, a signature or another generic context) therefore does not require
/// repeating them.
#[derive(Debug)]
pub struct IpAssignment<T> {
    /// Every assigned address and the node that owns it.
    map: HashMap<std::net::IpAddr, T>,
    /// The most recent address reached by automatic assignment.
    last_assigned_addr: std::net::IpAddr,
}
360
361impl<T: Copy + Eq + Hash + std::fmt::Display> IpAssignment<T> {
362 pub fn new() -> Self {
363 Self {
364 map: HashMap::new(),
365 last_assigned_addr: std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 0)),
366 }
367 }
368
369 pub fn assign(&mut self, node_id: T) -> std::net::IpAddr {
371 loop {
373 let ip_addr = Self::increment_address(&self.last_assigned_addr);
374 self.last_assigned_addr = ip_addr;
375 if let std::collections::hash_map::Entry::Vacant(e) = self.map.entry(ip_addr) {
376 e.insert(node_id);
377 break ip_addr;
378 }
379 }
380 }
381
382 pub fn assign_ip(
384 &mut self,
385 node_id: T,
386 ip_addr: std::net::IpAddr,
387 ) -> Result<(), IpPreviouslyAssignedError> {
388 let entry = self.map.entry(ip_addr);
389 if let Entry::Occupied(_) = &entry {
390 return Err(IpPreviouslyAssignedError);
391 }
392 entry.or_insert(node_id);
393 Ok(())
394 }
395
396 pub fn get_node(&self, ip_addr: std::net::IpAddr) -> Option<T> {
398 self.map.get(&ip_addr).copied()
399 }
400
401 pub fn get_nodes(&self) -> std::collections::HashSet<T> {
403 self.map.values().copied().collect()
404 }
405
406 fn increment_address(addr: &std::net::IpAddr) -> std::net::IpAddr {
407 match addr {
408 std::net::IpAddr::V4(x) => {
409 let addr_bits = u32::from(*x);
410 let mut increment = 1;
411 loop {
412 let next_addr = std::net::Ipv4Addr::from(addr_bits + increment);
414 match next_addr.octets()[3] {
415 0 | 255 => increment += 1,
417 _ => break std::net::IpAddr::V4(next_addr),
418 }
419 }
420 }
421 std::net::IpAddr::V6(_) => unimplemented!(),
422 }
423 }
424}
425
426impl<T: Copy + Eq + Hash + std::fmt::Display> Default for IpAssignment<T> {
427 fn default() -> Self {
428 Self::new()
429 }
430}
431
432#[derive(Debug)]
434pub struct RoutingInfo<T: Eq + Hash + std::fmt::Display + Clone + Copy> {
435 paths: HashMap<(T, T), PathProperties>,
436 packet_counters: std::sync::RwLock<HashMap<(T, T), u64>>,
437}
438
439impl<T: Eq + Hash + std::fmt::Display + Clone + Copy> RoutingInfo<T> {
440 pub fn new(paths: HashMap<(T, T), PathProperties>) -> Self {
441 Self {
442 paths,
443 packet_counters: std::sync::RwLock::new(HashMap::new()),
444 }
445 }
446
447 pub fn path(&self, start: T, end: T) -> Option<PathProperties> {
449 self.paths.get(&(start, end)).copied()
450 }
451
452 pub fn increment_packet_count(&self, start: T, end: T) {
454 let key = (start, end);
455 let mut packet_counters = self.packet_counters.write().unwrap();
456 match packet_counters.get_mut(&key) {
457 Some(x) => *x = x.saturating_add(1),
458 None => assert!(packet_counters.insert(key, 1).is_none()),
459 }
460 }
461
462 pub fn log_packet_counts(&self) {
464 for ((start, end), count) in self.packet_counters.read().unwrap().iter() {
466 let path = self.paths.get(&(*start, *end)).unwrap();
467 log::debug!(
468 "Found path {}->{}: latency={}ns, packet_loss={}, packet_count={}",
469 start,
470 end,
471 path.latency_ns,
472 path.packet_loss,
473 count,
474 );
475 }
476 }
477
478 pub fn get_smallest_latency_ns(&self) -> Option<u64> {
479 self.paths.values().map(|x| x.latency_ns).min()
480 }
481}
482
483fn read_xz<P: AsRef<std::path::Path>>(path: P) -> Result<String, NetGraphError> {
485 let path = path.as_ref();
486
487 let mut f = std::io::BufReader::new(
488 std::fs::File::open(path).with_context(|| format!("Failed to open file: {path:?}"))?,
489 );
490
491 let mut decomp: Vec<u8> = Vec::new();
492 lzma_rs::xz_decompress(&mut f, &mut decomp).context("Failed to decompress file")?;
493 decomp.shrink_to_fit();
494
495 Ok(String::from_utf8(decomp)?)
496}
497
498pub fn load_network_graph(graph_options: &GraphOptions) -> Result<String, NetGraphError> {
500 Ok(match graph_options {
501 GraphOptions::Gml(GraphSource::File(FileSource {
502 compression: None,
503 path: f,
504 })) => std::fs::read_to_string(tilde_expansion(f))
505 .with_context(|| format!("Failed to read file: {f}"))?,
506 GraphOptions::Gml(GraphSource::File(FileSource {
507 compression: Some(Compression::Xz),
508 path: f,
509 })) => read_xz(tilde_expansion(f))?,
510 GraphOptions::Gml(GraphSource::Inline(s)) => s.clone(),
511 GraphOptions::OneGbitSwitch => configuration::ONE_GBIT_SWITCH_GRAPH.to_string(),
512 })
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding two paths sums their latencies and combines their losses as
    /// 1 - (1 - 0.35) * (1 - 0.85) = 0.9025.
    #[test]
    fn test_path_add() {
        let p1 = PathProperties {
            latency_ns: 23,
            packet_loss: 0.35,
        };
        let p2 = PathProperties {
            latency_ns: 11,
            packet_loss: 0.85,
        };

        let p3 = p1 + p2;
        assert_eq!(p3.latency_ns, 34);
        assert!((p3.packet_loss - 0.9025).abs() < 0.01);
    }

    /// Parsing must fail when an edge targets an undeclared node id (2). It must succeed when
    /// the target is a declared id (3).
    #[test]
    fn test_nonexistent_id() {
        for id in &[2, 3] {
            let graph = format!(
                r#"graph [
                  node [
                    id 1
                  ]
                  node [
                    id 3
                  ]
                  edge [
                    source 1
                    target {}
                    latency "1 ns"
                  ]
                ]"#,
                id
            );

            if *id == 3 {
                NetworkGraph::parse(&graph).unwrap();
            } else {
                NetworkGraph::parse(&graph).unwrap_err();
            }
        }
    }

    /// Checks all-pairs shortest-path latencies on a three-node graph. The same edge list
    /// is read once as directed and once as undirected.
    #[test]
    // Skipped under Miri; presumably because the parallel Dijkstra run is too slow there
    // (TODO confirm).
    #[cfg_attr(miri, ignore)]
    fn test_shortest_path() {
        for directed in &[true, false] {
            // Each node has exactly one self-loop, which is used as the node-to-itself path.
            let graph = format!(
                r#"graph [
                  directed {}
                  node [
                    id 0
                  ]
                  node [
                    id 1
                  ]
                  node [
                    id 2
                  ]
                  edge [
                    source 0
                    target 0
                    latency "3333 ns"
                  ]
                  edge [
                    source 1
                    target 1
                    latency "5555 ns"
                  ]
                  edge [
                    source 2
                    target 2
                    latency "7777 ns"
                  ]
                  edge [
                    source 0
                    target 1
                    latency "3 ns"
                  ]
                  edge [
                    source 1
                    target 0
                    latency "5 ns"
                  ]
                  edge [
                    source 0
                    target 2
                    latency "7 ns"
                  ]
                  edge [
                    source 2
                    target 1
                    latency "11 ns"
                  ]
                ]"#,
                if *directed { 1 } else { 0 }
            );
            let graph = NetworkGraph::parse(&graph).unwrap();
            let node_0 = *graph.node_id_to_index(0).unwrap();
            let node_1 = *graph.node_id_to_index(1).unwrap();
            let node_2 = *graph.node_id_to_index(2).unwrap();

            let shortest_paths = graph
                .compute_shortest_paths(&[node_0, node_1, node_2])
                .unwrap();

            let lookup_latency = |a, b| shortest_paths.get(&(a, b)).unwrap().latency_ns;

            if *directed {
                assert_eq!(lookup_latency(node_0, node_0), 3333);
                assert_eq!(lookup_latency(node_0, node_1), 3);
                assert_eq!(lookup_latency(node_0, node_2), 7);
                assert_eq!(lookup_latency(node_1, node_0), 5);
                assert_eq!(lookup_latency(node_1, node_1), 5555);
                // no direct 1 -> 2 edge: 1 -> 0 -> 2 = 5 + 7
                assert_eq!(lookup_latency(node_1, node_2), 12);
                // no direct 2 -> 0 edge: 2 -> 1 -> 0 = 11 + 5
                assert_eq!(lookup_latency(node_2, node_0), 16);
                assert_eq!(lookup_latency(node_2, node_1), 11);
                assert_eq!(lookup_latency(node_2, node_2), 7777);
            } else {
                assert_eq!(lookup_latency(node_0, node_0), 3333);
                // 0 and 1 are joined by two parallel edges (3 and 5); the cheaper one wins
                assert_eq!(lookup_latency(node_0, node_1), 3);
                assert_eq!(lookup_latency(node_0, node_2), 7);
                assert_eq!(lookup_latency(node_1, node_0), 3);
                assert_eq!(lookup_latency(node_1, node_1), 5555);
                // 1 - 0 - 2 = 3 + 7 beats the direct edge of 11
                assert_eq!(lookup_latency(node_1, node_2), 10);
                assert_eq!(lookup_latency(node_2, node_0), 7);
                assert_eq!(lookup_latency(node_2, node_1), 10);
                assert_eq!(lookup_latency(node_2, node_2), 7777);
            }
        }
    }

    /// The address after 11.0.0.254 must be larger, and must not be 11.0.0.255.
    /// `increment_address` skips last octets of 0 and 255.
    #[test]
    fn test_increment_address_skip_broadcast() {
        let addr = std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 254));
        let incremented = IpAssignment::<i32>::increment_address(&addr);
        assert!(incremented > addr);
        assert_ne!(
            incremented,
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 255))
        );
    }
}