shadow_rs/network/graph/
mod.rs

1mod petgraph_wrapper;
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::error::Error;
6use std::hash::Hash;
7
8use anyhow::Context;
9use log::*;
10use petgraph::graph::NodeIndex;
11use rayon::iter::{IntoParallelIterator, ParallelIterator};
12
13use crate::core::configuration::{self, Compression, FileSource, GraphOptions, GraphSource};
14use crate::network::graph::petgraph_wrapper::GraphWrapper;
15use crate::utility::tilde_expansion;
16use crate::utility::units::{self, Unit};
17
/// Boxed, thread-safe error type returned by the fallible functions in this module.
type NetGraphError = Box<dyn Error + Send + Sync>;
19
20/// A graph node.
21#[derive(Debug, PartialEq, Eq)]
22pub struct ShadowNode {
23    pub id: u32,
24    pub bandwidth_down: Option<units::BitsPerSec<units::SiPrefixUpper>>,
25    pub bandwidth_up: Option<units::BitsPerSec<units::SiPrefixUpper>>,
26}
27
28impl TryFrom<gml_parser::gml::Node<'_>> for ShadowNode {
29    type Error = String;
30
31    fn try_from(mut gml_node: gml_parser::gml::Node) -> Result<Self, Self::Error> {
32        Ok(Self {
33            id: gml_node.id.ok_or("Node 'id' was not provided")?,
34            bandwidth_down: gml_node
35                .other
36                .remove("host_bandwidth_down")
37                .map(|bandwidth| {
38                    bandwidth
39                        .as_str()
40                        .ok_or("Node 'host_bandwidth_down' is not a string")?
41                        .parse()
42                        .map_err(|e| format!("Node 'host_bandwidth_down' is not a valid unit: {e}"))
43                })
44                .transpose()?,
45            bandwidth_up: gml_node
46                .other
47                .remove("host_bandwidth_up")
48                .map(|bandwidth| {
49                    bandwidth
50                        .as_str()
51                        .ok_or("Node 'host_bandwidth_up' is not a string")?
52                        .parse()
53                        .map_err(|e| format!("Node 'host_bandwidth_up' is not a valid unit: {e}"))
54                })
55                .transpose()?,
56        })
57    }
58}
59
60/// A graph edge.
61#[derive(Debug, PartialEq)]
62pub struct ShadowEdge {
63    pub source: u32,
64    pub target: u32,
65    pub latency: units::Time<units::TimePrefix>,
66    pub jitter: units::Time<units::TimePrefix>,
67    pub packet_loss: f32,
68}
69
70impl TryFrom<gml_parser::gml::Edge<'_>> for ShadowEdge {
71    type Error = String;
72
73    fn try_from(mut gml_edge: gml_parser::gml::Edge) -> Result<Self, Self::Error> {
74        let rv = Self {
75            source: gml_edge.source,
76            target: gml_edge.target,
77            latency: gml_edge
78                .other
79                .remove("latency")
80                .ok_or("Edge 'latency' was not provided")?
81                .as_str()
82                .ok_or("Edge 'latency' is not a string")?
83                .parse()
84                .map_err(|e| format!("Edge 'latency' is not a valid unit: {e}"))?,
85            jitter: match gml_edge.other.remove("jitter") {
86                Some(x) => x
87                    .as_str()
88                    .ok_or("Edge 'jitter' is not a string")?
89                    .parse()
90                    .map_err(|e| format!("Edge 'jitter' is not a valid unit: {e}"))?,
91                None => units::Time::new(0, units::TimePrefix::Milli),
92            },
93            packet_loss: match gml_edge.other.remove("packet_loss") {
94                Some(x) => x.as_float().ok_or("Edge 'packet_loss' is not a float")?,
95                None => 0.0,
96            },
97        };
98
99        if rv.packet_loss < 0f32 || rv.packet_loss > 1f32 {
100            return Err("Edge 'packet_loss' is not in the range [0,1]".into());
101        }
102
103        if rv.latency.value() == 0 {
104            return Err("Edge 'latency' must not be 0".into());
105        }
106
107        Ok(rv)
108    }
109}
110
111/// A network graph containing the petgraph graph and a map from gml node ids to petgraph node
112/// indexes.
113#[derive(Debug)]
114pub struct NetworkGraph {
115    graph: GraphWrapper<ShadowNode, ShadowEdge, u32>,
116    node_id_to_index_map: HashMap<u32, NodeIndex>,
117}
118
119impl NetworkGraph {
120    pub fn graph(&self) -> &GraphWrapper<ShadowNode, ShadowEdge, u32> {
121        &self.graph
122    }
123
124    pub fn node_id_to_index(&self, id: u32) -> Option<&NodeIndex> {
125        self.node_id_to_index_map.get(&id)
126    }
127
128    pub fn node_index_to_id(&self, index: NodeIndex) -> Option<u32> {
129        self.graph.node_weight(index).map(|w| w.id)
130    }
131
132    pub fn parse(graph_text: &str) -> Result<Self, NetGraphError> {
133        let gml_graph = gml_parser::parse(graph_text)?;
134
135        let mut g = match gml_graph.directed {
136            true => GraphWrapper::Directed(
137                petgraph::graph::Graph::<_, _, petgraph::Directed, _>::with_capacity(
138                    gml_graph.nodes.len(),
139                    gml_graph.edges.len(),
140                ),
141            ),
142            false => {
143                GraphWrapper::Undirected(
144                    petgraph::graph::Graph::<_, _, petgraph::Undirected, _>::with_capacity(
145                        gml_graph.nodes.len(),
146                        gml_graph.edges.len(),
147                    ),
148                )
149            }
150        };
151
152        // map from GML id to petgraph id
153        let mut id_map = HashMap::new();
154
155        for x in gml_graph.nodes.into_iter() {
156            let x: ShadowNode = x.try_into()?;
157            let gml_id = x.id;
158            let petgraph_id = g.add_node(x);
159            id_map.insert(gml_id, petgraph_id);
160        }
161
162        for x in gml_graph.edges.into_iter() {
163            let x: ShadowEdge = x.try_into()?;
164
165            let source = *id_map
166                .get(&x.source)
167                .ok_or(format!("Edge source {} doesn't exist", x.source))?;
168            let target = *id_map
169                .get(&x.target)
170                .ok_or(format!("Edge target {} doesn't exist", x.target))?;
171
172            g.add_edge(source, target, x);
173        }
174
175        Ok(Self {
176            graph: g,
177            node_id_to_index_map: id_map,
178        })
179    }
180
181    pub fn compute_shortest_paths(
182        &self,
183        nodes: &[NodeIndex],
184    ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
185        let start = std::time::Instant::now();
186
187        // calculate shortest paths
188        let mut paths: HashMap<(_, _), PathProperties> = nodes
189            .into_par_iter()
190            .flat_map(|src| {
191                match &self.graph {
192                    GraphWrapper::Directed(graph) => {
193                        petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
194                    }
195                    GraphWrapper::Undirected(graph) => {
196                        petgraph::algo::dijkstra(&graph, *src, None, |e| e.weight().into())
197                    }
198                }
199                .into_iter()
200                // ignore nodes that aren't in use
201                .filter(|(dst, _)| nodes.contains(dst))
202                // include the src node
203                .map(|(dst, path)| ((*src, dst), path))
204                .collect::<HashMap<(_, _), _>>()
205            })
206            .collect();
207
208        // use the self-loop for paths from a node to itself
209        for node in nodes {
210            // the dijkstra shortest path from node -> node will always be 0
211            assert_eq!(paths[&(*node, *node)], PathProperties::default());
212
213            // there must be a single self-loop for each node
214            paths.insert((*node, *node), self.get_edge_weight(node, node)?.into());
215        }
216
217        assert_eq!(paths.len(), nodes.len().pow(2));
218
219        debug!(
220            "Finished computing shortest paths: {} seconds, {} entries",
221            (std::time::Instant::now() - start).as_secs(),
222            paths.len()
223        );
224
225        Ok(paths)
226    }
227
228    pub fn get_direct_paths(
229        &self,
230        nodes: &[NodeIndex],
231    ) -> Result<HashMap<(NodeIndex, NodeIndex), PathProperties>, NetGraphError> {
232        let start = std::time::Instant::now();
233
234        let paths: HashMap<_, _> = nodes
235            .iter()
236            .flat_map(|src| nodes.iter().map(move |dst| (*src, *dst)))
237            // we require the graph to be connected with exactly one edge between any two nodes
238            .map(|(src, dst)| Ok(((src, dst), self.get_edge_weight(&src, &dst)?.into())))
239            .collect::<Result<_, NetGraphError>>()?;
240
241        assert_eq!(paths.len(), nodes.len().pow(2));
242
243        debug!(
244            "Finished computing direct paths: {} seconds, {} entries",
245            (std::time::Instant::now() - start).as_secs(),
246            paths.len()
247        );
248
249        Ok(paths)
250    }
251
252    /// Get the weight for the edge between two nodes. Returns an error if there
253    /// is not exactly one edge between them.
254    fn get_edge_weight(
255        &self,
256        src: &NodeIndex,
257        dst: &NodeIndex,
258    ) -> Result<&ShadowEdge, NetGraphError> {
259        let src_id = self.node_index_to_id(*src).unwrap();
260        let dst_id = self.node_index_to_id(*dst).unwrap();
261        match &self.graph {
262            GraphWrapper::Directed(graph) => {
263                let mut edges = graph.edges_connecting(*src, *dst);
264                let edge = edges
265                    .next()
266                    .ok_or(format!("No edge connecting node {src_id} to {dst_id}"))?;
267                if edges.count() != 0 {
268                    return Err(
269                        format!("More than one edge connecting node {src_id} to {dst_id}").into(),
270                    );
271                }
272                Ok(edge.weight())
273            }
274            GraphWrapper::Undirected(graph) => {
275                let mut edges = graph.edges_connecting(*src, *dst);
276                let edge = edges
277                    .next()
278                    .ok_or(format!("No edge connecting node {src_id} to {dst_id}"))?;
279                if edges.count() != 0 {
280                    return Err(
281                        format!("More than one edge connecting node {src_id} to {dst_id}").into(),
282                    );
283                }
284                Ok(edge.weight())
285            }
286        }
287    }
288}
289
/// Network characteristics for a path between two nodes.
#[derive(Debug, Default, Clone, Copy)]
pub struct PathProperties {
    /// Latency in nanoseconds.
    pub latency_ns: u64,
    /// Packet loss as fraction.
    pub packet_loss: f32,
}

impl PartialOrd for PathProperties {
    /// Order by lowest latency first, then by lowest packet loss. Returns `None` only when the
    /// latencies are equal and the packet losses are not comparable (NaN).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        match self.latency_ns.cmp(&other.latency_ns) {
            std::cmp::Ordering::Equal => self.packet_loss.partial_cmp(&other.packet_loss),
            x => Some(x),
        }
    }
}

impl PartialEq for PathProperties {
    fn eq(&self, other: &Self) -> bool {
        // PartialEq must be consistent with PartialOrd
        self.partial_cmp(other) == Some(std::cmp::Ordering::Equal)
    }
}

impl core::ops::Add for PathProperties {
    type Output = Self;

    /// Concatenate two path segments.
    ///
    /// Latencies add, saturating at `u64::MAX` instead of overflowing. Plain `+` would panic in
    /// debug builds and wrap to a tiny, falsely "short" latency in release builds. Packet loss
    /// combines as independent events: a packet survives only if it survives both segments.
    fn add(self, other: Self) -> Self::Output {
        Self {
            latency_ns: self.latency_ns.saturating_add(other.latency_ns),
            packet_loss: 1f32 - (1f32 - self.packet_loss) * (1f32 - other.packet_loss),
        }
    }
}
326
327impl std::convert::From<&ShadowEdge> for PathProperties {
328    fn from(e: &ShadowEdge) -> Self {
329        Self {
330            latency_ns: e.latency.convert(units::TimePrefix::Nano).unwrap().value(),
331            packet_loss: e.packet_loss,
332        }
333    }
334}
335
/// Error returned by [`IpAssignment::assign_ip`] when the requested address is already in use.
#[derive(Debug)]
pub struct IpPreviouslyAssignedError;
impl std::error::Error for IpPreviouslyAssignedError {}

impl std::fmt::Display for IpPreviouslyAssignedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "IP address has already been assigned")
    }
}

/// Tool for assigning IP addresses to graph nodes.
#[derive(Debug)]
pub struct IpAssignment<T: Copy + Eq + Hash + std::fmt::Display> {
    /// A map of host IP addresses to node ids.
    map: HashMap<std::net::IpAddr, T>,
    /// The last dynamically assigned address.
    last_assigned_addr: std::net::IpAddr,
}

impl<T: Copy + Eq + Hash + std::fmt::Display> IpAssignment<T> {
    /// Create an empty assignment. Dynamically assigned addresses start after `11.0.0.0`.
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
            last_assigned_addr: std::net::IpAddr::V4(std::net::Ipv4Addr::new(11, 0, 0, 0)),
        }
    }

    /// Get an unused address and assign it to a node.
    ///
    /// Addresses are handed out in increasing order after the last dynamically assigned one.
    /// Addresses whose last octet is `0` or `255` are skipped, as are addresses already taken
    /// (for example by [`assign_ip`](Self::assign_ip)).
    ///
    /// # Panics
    ///
    /// Panics if the IPv4 address space runs out.
    pub fn assign(&mut self, node_id: T) -> std::net::IpAddr {
        // loop until we find an unused address
        loop {
            let ip_addr = Self::increment_address(&self.last_assigned_addr);
            self.last_assigned_addr = ip_addr;
            if let Entry::Vacant(e) = self.map.entry(ip_addr) {
                e.insert(node_id);
                break ip_addr;
            }
        }
    }

    /// Assign an address to a node.
    ///
    /// # Errors
    ///
    /// Returns [`IpPreviouslyAssignedError`] if `ip_addr` is already assigned. The existing
    /// assignment is left untouched.
    pub fn assign_ip(
        &mut self,
        node_id: T,
        ip_addr: std::net::IpAddr,
    ) -> Result<(), IpPreviouslyAssignedError> {
        match self.map.entry(ip_addr) {
            Entry::Occupied(_) => Err(IpPreviouslyAssignedError),
            Entry::Vacant(e) => {
                e.insert(node_id);
                Ok(())
            }
        }
    }

    /// Get the node that an address is assigned to.
    pub fn get_node(&self, ip_addr: std::net::IpAddr) -> Option<T> {
        self.map.get(&ip_addr).copied()
    }

    /// Get all nodes with assigned addresses.
    pub fn get_nodes(&self) -> std::collections::HashSet<T> {
        self.map.values().copied().collect()
    }

    /// Return the next IPv4 address after `addr` whose last octet is neither `0` nor `255`.
    ///
    /// # Panics
    ///
    /// Panics if no such address exists below `255.255.255.255`. Without the checked add this
    /// was a debug-only overflow panic and a silent wraparound in release builds. Also panics
    /// (unimplemented) for IPv6.
    fn increment_address(addr: &std::net::IpAddr) -> std::net::IpAddr {
        match addr {
            std::net::IpAddr::V4(x) => {
                let mut bits = u32::from(*x);
                loop {
                    bits = bits
                        .checked_add(1)
                        .expect("ran out of IPv4 addresses to assign");
                    let next_addr = std::net::Ipv4Addr::from(bits);
                    // if the address ends in ".0" or ".255" (broadcast), try the next
                    if !matches!(next_addr.octets()[3], 0 | 255) {
                        break std::net::IpAddr::V4(next_addr);
                    }
                }
            }
            std::net::IpAddr::V6(_) => unimplemented!(),
        }
    }
}

impl<T: Copy + Eq + Hash + std::fmt::Display> Default for IpAssignment<T> {
    fn default() -> Self {
        Self::new()
    }
}
425
426/// Routing information for paths between nodes.
427#[derive(Debug)]
428pub struct RoutingInfo<T: Eq + Hash + std::fmt::Display + Clone + Copy> {
429    paths: HashMap<(T, T), PathProperties>,
430    packet_counters: std::sync::RwLock<HashMap<(T, T), u64>>,
431}
432
433impl<T: Eq + Hash + std::fmt::Display + Clone + Copy> RoutingInfo<T> {
434    pub fn new(paths: HashMap<(T, T), PathProperties>) -> Self {
435        Self {
436            paths,
437            packet_counters: std::sync::RwLock::new(HashMap::new()),
438        }
439    }
440
441    /// Get properties for the path from one node to another.
442    pub fn path(&self, start: T, end: T) -> Option<PathProperties> {
443        self.paths.get(&(start, end)).copied()
444    }
445
446    /// Increment the number of packets sent from one node to another.
447    pub fn increment_packet_count(&self, start: T, end: T) {
448        let key = (start, end);
449        let mut packet_counters = self.packet_counters.write().unwrap();
450        match packet_counters.get_mut(&key) {
451            Some(x) => *x = x.saturating_add(1),
452            None => assert!(packet_counters.insert(key, 1).is_none()),
453        }
454    }
455
456    /// Log the number of packets sent between nodes.
457    pub fn log_packet_counts(&self) {
458        // only logs paths that have transmitted at least one packet
459        for ((start, end), count) in self.packet_counters.read().unwrap().iter() {
460            let path = self.paths.get(&(*start, *end)).unwrap();
461            log::debug!(
462                "Found path {}->{}: latency={}ns, packet_loss={}, packet_count={}",
463                start,
464                end,
465                path.latency_ns,
466                path.packet_loss,
467                count,
468            );
469        }
470    }
471
472    pub fn get_smallest_latency_ns(&self) -> Option<u64> {
473        self.paths.values().map(|x| x.latency_ns).min()
474    }
475}
476
477/// Read and decompress a file.
478fn read_xz<P: AsRef<std::path::Path>>(path: P) -> Result<String, NetGraphError> {
479    let path = path.as_ref();
480
481    let mut f = std::io::BufReader::new(
482        std::fs::File::open(path).with_context(|| format!("Failed to open file: {path:?}"))?,
483    );
484
485    let mut decomp: Vec<u8> = Vec::new();
486    lzma_rs::xz_decompress(&mut f, &mut decomp).context("Failed to decompress file")?;
487    decomp.shrink_to_fit();
488
489    Ok(String::from_utf8(decomp)?)
490}
491
492/// Get the network graph as a string.
493pub fn load_network_graph(graph_options: &GraphOptions) -> Result<String, NetGraphError> {
494    Ok(match graph_options {
495        GraphOptions::Gml(GraphSource::File(FileSource {
496            compression: None,
497            path: f,
498        })) => std::fs::read_to_string(tilde_expansion(f))
499            .with_context(|| format!("Failed to read file: {f}"))?,
500        GraphOptions::Gml(GraphSource::File(FileSource {
501            compression: Some(Compression::Xz),
502            path: f,
503        })) => read_xz(tilde_expansion(f))?,
504        GraphOptions::Gml(GraphSource::Inline(s)) => s.clone(),
505        GraphOptions::OneGbitSwitch => configuration::ONE_GBIT_SWITCH_GRAPH.to_string(),
506    })
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_add() {
        let lhs = PathProperties {
            latency_ns: 23,
            packet_loss: 0.35,
        };
        let rhs = PathProperties {
            latency_ns: 11,
            packet_loss: 0.85,
        };

        let sum = lhs + rhs;
        assert_eq!(sum.latency_ns, 34);
        // 1 - (1 - 0.35) * (1 - 0.85) = 0.9025
        assert!((sum.packet_loss - 0.9025).abs() < 0.01);
    }

    /// An edge may only reference node ids that were declared in the graph.
    #[test]
    fn test_nonexistent_id() {
        for target in [2, 3] {
            let graph = format!(
                r#"graph [
                node [
                  id 1
                ]
                node [
                  id 3
                ]
                edge [
                  source 1
                  target {target}
                  latency "1 ns"
                ]
            ]"#,
            );

            let result = NetworkGraph::parse(&graph);
            if target == 3 {
                result.unwrap();
            } else {
                result.unwrap_err();
            }
        }
    }

    // disabled under miri due to https://github.com/rayon-rs/rayon/issues/952
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_shortest_path() {
        for directed in [true, false] {
            let graph = format!(
                r#"graph [
                  directed {}
                  node [
                    id 0
                  ]
                  node [
                    id 1
                  ]
                  node [
                    id 2
                  ]
                  edge [
                    source 0
                    target 0
                    latency "3333 ns"
                  ]
                  edge [
                    source 1
                    target 1
                    latency "5555 ns"
                  ]
                  edge [
                    source 2
                    target 2
                    latency "7777 ns"
                  ]
                  edge [
                    source 0
                    target 1
                    latency "3 ns"
                  ]
                  edge [
                    source 1
                    target 0
                    latency "5 ns"
                  ]
                  edge [
                    source 0
                    target 2
                    latency "7 ns"
                  ]
                  edge [
                    source 2
                    target 1
                    latency "11 ns"
                  ]
                ]"#,
                if directed { 1 } else { 0 }
            );
            let graph = NetworkGraph::parse(&graph).unwrap();
            let nodes: Vec<_> = (0..3)
                .map(|id| *graph.node_id_to_index(id).unwrap())
                .collect();

            let shortest_paths = graph.compute_shortest_paths(&nodes).unwrap();

            // expected[src][dst] is the expected latency in ns from node `src` to node `dst`
            let expected: [[u64; 3]; 3] = if directed {
                [[3333, 3, 7], [5, 5555, 12], [16, 11, 7777]]
            } else {
                [[3333, 3, 7], [3, 5555, 10], [7, 10, 7777]]
            };

            for (src, row) in nodes.iter().zip(expected.iter()) {
                for (dst, &latency) in nodes.iter().zip(row.iter()) {
                    let path = shortest_paths.get(&(*src, *dst)).unwrap();
                    assert_eq!(path.latency_ns, latency);
                }
            }
        }
    }

    #[test]
    fn test_increment_address_skip_broadcast() {
        use std::net::{IpAddr, Ipv4Addr};

        let before = IpAddr::V4(Ipv4Addr::new(11, 0, 0, 254));
        let after = IpAssignment::<i32>::increment_address(&before);
        assert!(after > before);
        assert_ne!(after, IpAddr::V4(Ipv4Addr::new(11, 0, 0, 255)));
    }
}