shadow_rs/network/graph/
mod.rs

1mod petgraph_wrapper;
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::error::Error;
6use std::hash::Hash;
7
8use anyhow::Context;
9use log::*;
10use petgraph::graph::NodeIndex;
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12
13use crate::core::configuration::{self, Compression, FileSource, GraphOptions, GraphSource};
14use crate::network::graph::petgraph_wrapper::GraphWrapper;
15use crate::utility::tilde_expansion;
16use crate::utility::units::{self, Unit};
17
/// Boxed, thread-safe error type returned by the fallible functions in this module.
type NetGraphError = Box<dyn Error + Sync + Send + 'static>;
19
20/// A graph node.
21#[derive(Debug, PartialEq, Eq)]
22pub struct ShadowNode {
23    pub id: u32,
24    pub bandwidth_down: Option<units::BitsPerSec<units::SiPrefixUpper>>,
25    pub bandwidth_up: Option<units::BitsPerSec<units::SiPrefixUpper>>,
26}
27
28impl TryFrom<gml_parser::gml::Node<'_>> for ShadowNode {
29    type Error = String;
30
31    fn try_from(mut gml_node: gml_parser::gml::Node) -> Result<Self, Self::Error> {
32        Ok(Self {
33            id: gml_node.id.ok_or("Node 'id' was not provided")?,
34            bandwidth_down: gml_node
35                .other
36                .remove("host_bandwidth_down")
37                .map(|bandwidth| {
38                    bandwidth
39                        .as_str()
40                        .ok_or("Node 'host_bandwidth_down' is not a string")?
41                        .parse()
42                        .map_err(|e| {
43                            format!("Node 'host_bandwidth_down' is not a valid unit: {}", e)
44                        })
45                })
46                .transpose()?,
47            bandwidth_up: gml_node
48                .other
49                .remove("host_bandwidth_up")
50                .map(|bandwidth| {
51                    bandwidth
52                        .as_str()
53                        .ok_or("Node 'host_bandwidth_up' is not a string")?
54                        .parse()
55                        .map_err(|e| format!("Node 'host_bandwidth_up' is not a valid unit: {}", e))
56                })
57                .transpose()?,
58        })
59    }
60}
61
62/// A graph edge.
63#[derive(Debug, PartialEq)]
64pub struct ShadowEdge {
65    pub source: u32,
66    pub target: u32,
67    pub latency: units::Time<units::TimePrefix>,
68    pub jitter: units::Time<units::TimePrefix>,
69    pub packet_loss: f32,
70}
71
72impl TryFrom<gml_parser::gml::Edge<'_>> for ShadowEdge {
73    type Error = String;
74
75    fn try_from(mut gml_edge: gml_parser::gml::Edge) -> Result<Self, Self::Error> {
76        let rv = Self {
77            source: gml_edge.source,
78            target: gml_edge.target,
79            latency: gml_edge
80                .other
81                .remove("latency")
82                .ok_or("Edge 'latency' was not provided")?
83                .as_str()
84                .ok_or("Edge 'latency' is not a string")?
85                .parse()
86                .map_err(|e| format!("Edge 'latency' is not a valid unit: {}", e))?,
87            jitter: match gml_edge.other.remove("jitter") {
88                Some(x) => x
89                    .as_str()
90                    .ok_or("Edge 'jitter' is not a string")?
91                    .parse()
92                    .map_err(|e| format!("Edge 'jitter' is not a valid unit: {}", e))?,
93                None => units::Time::new(0, units::TimePrefix::Milli),
94            },
95            packet_loss: match gml_edge.other.remove("packet_loss") {
96                Some(x) => x.as_float().ok_or("Edge 'packet_loss' is not a float")?,
97                None => 0.0,
98            },
99        };
100
101        if rv.packet_loss < 0f32 || rv.packet_loss > 1f32 {
102            return Err("Edge 'packet_loss' is not in the range [0,1]".into());
103        }
104
105        if rv.latency.value() == 0 {
106            return Err("Edge 'latency' must not be 0".into());
107        }
108
109        Ok(rv)
110    }
111}
112
113/// A network graph containing the petgraph graph and a map from gml node ids to petgraph node
114/// indexes.
115#[derive(Debug)]
116pub struct NetworkGraph {
117    graph: GraphWrapper<ShadowNode, ShadowEdge, u32>,
118    node_id_to_index_map: HashMap<u32, NodeIndex>,
119}
120
121impl NetworkGraph {
122    pub fn graph(&self) -> &GraphWrapper<ShadowNode, ShadowEdge, u32> {
123        &self.graph
124    }
125
126    pub fn node_id_to_index(&self, id: u32) -> Option<&NodeIndex> {
127        self.node_id_to_index_map.get(&id)
128    }
129
130    pub fn node_index_to_id(&self, index: NodeIndex) -> Option<u32> {
131        self.graph.node_weight(index).map(|w| w.id)
132    }
133
134    pub fn parse(graph_text: &str) -> Result<Self, NetGraphError> {
135        let gml_graph = gml_parser::parse(graph_text)?;
136
137        let mut g = match gml_graph.directed {
138            true => GraphWrapper::Directed(
139                petgraph::graph::Graph::<_, _, petgraph::Directed, _>::with_capacity(
140                    gml_graph.nodes.len(),
141                    gml_graph.edges.len(),
142                ),
143            ),
144            false => {
145                GraphWrapper::Undirected(
146                    petgraph::graph::Graph::<_, _, petgraph::Undirected, _>::with_capacity(
147                        gml_graph.nodes.len(),
148                        gml_graph.edges.len(),
149                    ),
150                )
151            }
152        };
153
154        // map from GML id to petgraph id
155        let mut id_map = HashMap::new();
156
157        for x in gml_graph.nodes.into_iter() {
158            let x: ShadowNode = x.try_into()?;
159            let gml_id = x.id;
160            let petgraph_id = g.add_node(x);
161            id_map.insert(gml_id, petgraph_id);
162        }
163
164        for x in gml_graph.edges.into_iter() {
165            let x: ShadowEdge = x.try_into()?;
166
167            let source = *id_map
168                .get(&x.source)
169                .ok_or(format!("Edge source {} doesn't exist", x.source))?;
170            let target = *id_map
171                .get(&x.target)
172                .ok_or(format!("Edge target {} doesn't exist", x.target))?;
173
174            g.add_edge(source, target, x);
175        }
176
177        Ok(Self {
178            graph: g,
179            node_id_to_index_map: id_map,
180        })
181    }
182
183    pub fn compute_shortest_paths(
184        &self,
185        nodes: &[NodeIndex],
186    ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
187        let start = std::time::Instant::now();
188
189        // calculate shortest paths
190        let mut paths: HashMap<(_, _), PathProperties> = nodes
191            .into_par_iter()
192            .flat_map(|src| {
193                match &self.graph {
194                    GraphWrapper::Directed(graph) => {
195                        petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
196                    }
197                    GraphWrapper::Undirected(graph) => {
198                        petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
199                    }
200                }
201                .into_iter()
202                // ignore nodes that aren't in use
203                .filter(|(dst, _)| nodes.contains(dst))
204                // include the src node
205                .map(|(dst, path)| ((*src, dst), path))
206                .collect::<HashMap<(_, _), _>>()
207            })
208            .collect();
209
210        // use the self-loop for paths from a node to itself
211        for node in nodes {
212            // the dijkstra shortest path from node -> node will always be 0
213            assert_eq!(paths[&(*node, *node)], PathProperties::default());
214
215            // there must be a single self-loop for each node
216            paths.insert((*node, *node), self.get_edge_weight(node, node)?.into());
217        }
218
219        assert_eq!(paths.len(), nodes.len().pow(2));
220
221        debug!(
222            "Finished computing shortest paths: {} seconds, {} entries",
223            (std::time::Instant::now() - start).as_secs(),
224            paths.len()
225        );
226
227        Ok(paths)
228    }
229
230    pub fn get_direct_paths(
231        &self,
232        nodes: &[NodeIndex],
233    ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
234        let start = std::time::Instant::now();
235
236        let paths: HashMap<_, _> = nodes
237            .iter()
238            .flat_map(|src| nodes.iter().map(move |dst| (*src, *dst)))
239            // we require the graph to be connected with exactly one edge between any two nodes
240            .map(|(src, dst)| Ok(((src, dst), self.get_edge_weight(&src, &dst)?.into())))
241            .collect::<Result<_, NetGraphError>>()?;
242
243        assert_eq!(paths.len(), nodes.len().pow(2));
244
245        debug!(
246            "Finished computing direct paths: {} seconds, {} entries",
247            (std::time::Instant::now() - start).as_secs(),
248            paths.len()
249        );
250
251        Ok(paths)
252    }
253
254    /// Get the weight for the edge between two nodes. Returns an error if there
255    /// is not exactly one edge between them.
256    fn get_edge_weight(
257        &self,
258        src: &NodeIndex,
259        dst: &NodeIndex,
260    ) -> Result<&ShadowEdge, NetGraphError> {
261        let src_id = self.node_index_to_id(*src).unwrap();
262        let dst_id = self.node_index_to_id(*dst).unwrap();
263        match &self.graph {
264            GraphWrapper::Directed(graph) => {
265                let mut edges = graph.edges_connecting(*src, *dst);
266                let edge = edges
267                    .next()
268                    .ok_or(format!("No edge connecting node {} to {}", src_id, dst_id))?;
269                if edges.count() != 0 {
270                    return Err(format!(
271                        "More than one edge connecting node {} to {}",
272                        src_id, dst_id
273                    )
274                    .into());
275                }
276                Ok(edge.weight())
277            }
278            GraphWrapper::Undirected(graph) => {
279                let mut edges = graph.edges_connecting(*src, *dst);
280                let edge = edges
281                    .next()
282                    .ok_or(format!("No edge connecting node {} to {}", src_id, dst_id))?;
283                if edges.count() != 0 {
284                    return Err(format!(
285                        "More than one edge connecting node {} to {}",
286                        src_id, dst_id
287                    )
288                    .into());
289                }
290                Ok(edge.weight())
291            }
292        }
293    }
294}
295
/// Network characteristics for a path between two nodes.
///
/// Paths are ordered by latency first, with packet loss breaking ties. Adding two paths
/// combines two segments traversed one after the other.
#[derive(Clone, Copy, Debug, Default)]
pub struct PathProperties {
    /// Latency in nanoseconds.
    pub latency_ns: u64,
    /// Packet loss as fraction.
    pub packet_loss: f32,
}

impl PartialOrd for PathProperties {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // lower latency wins outright; packet loss only matters when latencies are equal
        let by_latency = self.latency_ns.cmp(&other.latency_ns);
        if by_latency != std::cmp::Ordering::Equal {
            return Some(by_latency);
        }
        self.packet_loss.partial_cmp(&other.packet_loss)
    }
}

impl PartialEq for PathProperties {
    fn eq(&self, other: &Self) -> bool {
        // defined through `partial_cmp` so equality and ordering can never disagree
        matches!(self.partial_cmp(other), Some(std::cmp::Ordering::Equal))
    }
}

impl std::ops::Add for PathProperties {
    type Output = Self;

    /// Combine two consecutive segments.
    ///
    /// Latencies add. A packet survives the combined path only if it survives both
    /// segments.
    fn add(self, other: Self) -> Self::Output {
        let survives_both = (1f32 - self.packet_loss) * (1f32 - other.packet_loss);
        Self {
            latency_ns: self.latency_ns + other.latency_ns,
            packet_loss: 1f32 - survives_both,
        }
    }
}
332
333impl std::convert::From<&ShadowEdge> for PathProperties {
334    fn from(e: &ShadowEdge) -> Self {
335        Self {
336            latency_ns: e.latency.convert(units::TimePrefix::Nano).unwrap().value(),
337            packet_loss: e.packet_loss,
338        }
339    }
340}
341
/// Error returned when an IP address that already belongs to a node is assigned again.
#[derive(Debug)]
pub struct IpPreviouslyAssignedError;

impl std::fmt::Display for IpPreviouslyAssignedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("IP address has already been assigned")
    }
}

impl std::error::Error for IpPreviouslyAssignedError {}

/// Tool for assigning IP addresses to graph nodes.
///
/// Addresses can be chosen explicitly with [`IpAssignment::assign_ip`]. Alternatively,
/// [`IpAssignment::assign`] hands them out sequentially, starting after `11.0.0.0`.
#[derive(Debug)]
pub struct IpAssignment<T: Copy + Eq + Hash + std::fmt::Display> {
    /// Which node each assigned address belongs to.
    map: HashMap<std::net::IpAddr, T>,
    /// The most recent address examined by [`IpAssignment::assign`]. The next search starts
    /// just after it.
    last_assigned_addr: std::net::IpAddr,
}

impl<T: Copy + Eq + Hash + std::fmt::Display> IpAssignment<T> {
    /// Create an empty assignment table.
    pub fn new() -> Self {
        let search_start = std::net::Ipv4Addr::new(11, 0, 0, 0);
        Self {
            map: HashMap::new(),
            last_assigned_addr: std::net::IpAddr::V4(search_start),
        }
    }

    /// Assign the next unused address to `node_id` and return it.
    ///
    /// Addresses that are already taken, for example through [`IpAssignment::assign_ip`],
    /// are skipped.
    pub fn assign(&mut self, node_id: T) -> std::net::IpAddr {
        let mut candidate = self.last_assigned_addr;
        loop {
            candidate = Self::increment_address(&candidate);
            self.last_assigned_addr = candidate;
            if let Entry::Vacant(slot) = self.map.entry(candidate) {
                slot.insert(node_id);
                return candidate;
            }
        }
    }

    /// Assign the specific address `ip_addr` to `node_id`.
    ///
    /// # Errors
    ///
    /// Returns [`IpPreviouslyAssignedError`] if the address already belongs to a node. The
    /// existing assignment is left unchanged.
    pub fn assign_ip(
        &mut self,
        node_id: T,
        ip_addr: std::net::IpAddr,
    ) -> Result<(), IpPreviouslyAssignedError> {
        match self.map.entry(ip_addr) {
            Entry::Occupied(_) => Err(IpPreviouslyAssignedError),
            Entry::Vacant(slot) => {
                slot.insert(node_id);
                Ok(())
            }
        }
    }

    /// Return the node that `ip_addr` is assigned to, if any.
    pub fn get_node(&self, ip_addr: std::net::IpAddr) -> Option<T> {
        let node = self.map.get(&ip_addr)?;
        Some(*node)
    }

    /// Return the set of nodes that have at least one assigned address.
    pub fn get_nodes(&self) -> std::collections::HashSet<T> {
        let mut nodes = std::collections::HashSet::new();
        nodes.extend(self.map.values().copied());
        nodes
    }

    /// Return the IPv4 address after `addr`, skipping any address whose last octet is 0 or
    /// 255.
    ///
    /// # Panics
    ///
    /// Panics on IPv6 addresses, which are not supported.
    fn increment_address(addr: &std::net::IpAddr) -> std::net::IpAddr {
        let v4 = match addr {
            std::net::IpAddr::V4(v4) => v4,
            std::net::IpAddr::V6(_) => unimplemented!(),
        };
        let base = u32::from(*v4);
        let mut step = 1;
        loop {
            let candidate = std::net::Ipv4Addr::from(base + step);
            // skip addresses ending in ".0" or ".255" (broadcast)
            if !matches!(candidate.octets()[3], 0 | 255) {
                return std::net::IpAddr::V4(candidate);
            }
            step += 1;
        }
    }
}

impl<T: Copy + Eq + Hash + std::fmt::Display> Default for IpAssignment<T> {
    /// Equivalent to [`IpAssignment::new`].
    fn default() -> Self {
        Self::new()
    }
}
431
432/// Routing information for paths between nodes.
433#[derive(Debug)]
434pub struct RoutingInfo<T: Eq + Hash + std::fmt::Display + Clone + Copy> {
435    paths: HashMap<(T, T), PathProperties>,
436    packet_counters: std::sync::RwLock<HashMap<(T, T), u64>>,
437}
438
439impl<T: Eq + Hash + std::fmt::Display + Clone + Copy> RoutingInfo<T> {
440    pub fn new(paths: HashMap<(T, T), PathProperties>) -> Self {
441        Self {
442            paths,
443            packet_counters: std::sync::RwLock::new(HashMap::new()),
444        }
445    }
446
447    /// Get properties for the path from one node to another.
448    pub fn path(&self, start: T, end: T) -> Option<PathProperties> {
449        self.paths.get(&(start, end)).copied()
450    }
451
452    /// Increment the number of packets sent from one node to another.
453    pub fn increment_packet_count(&self, start: T, end: T) {
454        let key = (start, end);
455        let mut packet_counters = self.packet_counters.write().unwrap();
456        match packet_counters.get_mut(&key) {
457            Some(x) => *x = x.saturating_add(1),
458            None => assert!(packet_counters.insert(key, 1).is_none()),
459        }
460    }
461
462    /// Log the number of packets sent between nodes.
463    pub fn log_packet_counts(&self) {
464        // only logs paths that have transmitted at least one packet
465        for ((start, end), count) in self.packet_counters.read().unwrap().iter() {
466            let path = self.paths.get(&(*start, *end)).unwrap();
467            log::debug!(
468                "Found path {}->{}: latency={}ns, packet_loss={}, packet_count={}",
469                start,
470                end,
471                path.latency_ns,
472                path.packet_loss,
473                count,
474            );
475        }
476    }
477
478    pub fn get_smallest_latency_ns(&self) -> Option<u64> {
479        self.paths.values().map(|x| x.latency_ns).min()
480    }
481}
482
483/// Read and decompress a file.
484fn read_xz<P: AsRef<std::path::Path>>(path: P) -> Result<String, NetGraphError> {
485    let path = path.as_ref();
486
487    let mut f = std::io::BufReader::new(
488        std::fs::File::open(path).with_context(|| format!("Failed to open file: {path:?}"))?,
489    );
490
491    let mut decomp: Vec<u8> = Vec::new();
492    lzma_rs::xz_decompress(&mut f, &mut decomp).context("Failed to decompress file")?;
493    decomp.shrink_to_fit();
494
495    Ok(String::from_utf8(decomp)?)
496}
497
498/// Get the network graph as a string.
499pub fn load_network_graph(graph_options: &GraphOptions) -> Result<String, NetGraphError> {
500    Ok(match graph_options {
501        GraphOptions::Gml(GraphSource::File(FileSource {
502            compression: None,
503            path: f,
504        })) => std::fs::read_to_string(tilde_expansion(f))
505            .with_context(|| format!("Failed to read file: {f}"))?,
506        GraphOptions::Gml(GraphSource::File(FileSource {
507            compression: Some(Compression::Xz),
508            path: f,
509        })) => read_xz(tilde_expansion(f))?,
510        GraphOptions::Gml(GraphSource::Inline(s)) => s.clone(),
511        GraphOptions::OneGbitSwitch => configuration::ONE_GBIT_SWITCH_GRAPH.to_string(),
512    })
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Latencies add, and packet loss compounds across consecutive segments.
    #[test]
    fn test_path_add() {
        let first = PathProperties {
            latency_ns: 23,
            packet_loss: 0.35,
        };
        let second = PathProperties {
            latency_ns: 11,
            packet_loss: 0.85,
        };

        let combined = first + second;
        assert_eq!(combined.latency_ns, 34);
        assert!((combined.packet_loss - 0.9025).abs() < 0.01);
    }

    /// An edge may only target node ids that exist in the graph.
    #[test]
    fn test_nonexistent_id() {
        for &target in &[2, 3] {
            let graph = format!(
                r#"graph [
                node [
                  id 1
                ]
                node [
                  id 3
                ]
                edge [
                  source 1
                  target {}
                  latency "1 ns"
                ]
            ]"#,
                target
            );

            let parsed = NetworkGraph::parse(&graph);
            assert_eq!(parsed.is_ok(), target == 3);
        }
    }

    // disabled under miri due to https://github.com/rayon-rs/rayon/issues/952
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_shortest_path() {
        for &directed in &[true, false] {
            let graph = format!(
                r#"graph [
                  directed {}
                  node [
                    id 0
                  ]
                  node [
                    id 1
                  ]
                  node [
                    id 2
                  ]
                  edge [
                    source 0
                    target 0
                    latency "3333 ns"
                  ]
                  edge [
                    source 1
                    target 1
                    latency "5555 ns"
                  ]
                  edge [
                    source 2
                    target 2
                    latency "7777 ns"
                  ]
                  edge [
                    source 0
                    target 1
                    latency "3 ns"
                  ]
                  edge [
                    source 1
                    target 0
                    latency "5 ns"
                  ]
                  edge [
                    source 0
                    target 2
                    latency "7 ns"
                  ]
                  edge [
                    source 2
                    target 1
                    latency "11 ns"
                  ]
                ]"#,
                u8::from(directed)
            );
            let graph = NetworkGraph::parse(&graph).unwrap();
            let nodes: Vec<NodeIndex> = (0..3)
                .map(|id| *graph.node_id_to_index(id).unwrap())
                .collect();

            let shortest_paths = graph.compute_shortest_paths(&nodes).unwrap();

            // expected[src][dst] latency in ns; diagonal entries come from the self-loops
            let expected: [[u64; 3]; 3] = if directed {
                [[3333, 3, 7], [5, 5555, 12], [16, 11, 7777]]
            } else {
                [[3333, 3, 7], [3, 5555, 10], [7, 10, 7777]]
            };

            for (i, src) in nodes.iter().enumerate() {
                for (j, dst) in nodes.iter().enumerate() {
                    assert_eq!(shortest_paths[&(*src, *dst)].latency_ns, expected[i][j]);
                }
            }
        }
    }

    /// Incrementing past ".254" must not land on the ".255" broadcast address.
    #[test]
    fn test_increment_address_skip_broadcast() {
        let before = std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 254));
        let broadcast = std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 255));
        let after = IpAssignment::<i32>::increment_address(&before);
        assert!(after > before);
        assert_ne!(after, broadcast);
    }
}