petgraph/algo/floyd_warshall.rs

1use alloc::{vec, vec::Vec};
2use core::hash::Hash;
3
4use hashbrown::HashMap;
5
6use crate::algo::{BoundedMeasure, NegativeCycle};
7use crate::visit::{
8    EdgeRef, GraphProp, IntoEdgeReferences, IntoNodeIdentifiers, NodeCompactIndexable,
9};
10
11#[allow(clippy::type_complexity, clippy::needless_range_loop)]
12/// \[Generic\] [Floyd–Warshall algorithm](https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm) is an algorithm for all pairs shortest path problem
13///
14/// Compute the length of each shortest path in a weighted graph with positive or negative edge weights (but with no negative cycles).
15///
16/// # Arguments
17/// * `graph`: graph with no negative cycle
18/// * `edge_cost`: closure that returns cost of a particular edge
19///
20/// # Returns
21/// * `Ok`: (if graph contains no negative cycle) a hashmap containing all pairs shortest paths
22/// * `Err`: if graph contains negative cycle.
23///
24/// # Examples
25/// ```rust
26/// use petgraph::{prelude::*, Graph, Directed};
27/// use petgraph::algo::floyd_warshall;
28/// use hashbrown::HashMap;
29///
30/// let mut graph: Graph<(), (), Directed> = Graph::new();
31/// let a = graph.add_node(());
32/// let b = graph.add_node(());
33/// let c = graph.add_node(());
34/// let d = graph.add_node(());
35///
36/// graph.extend_with_edges(&[
37///    (a, b),
38///    (a, c),
39///    (a, d),
40///    (b, c),
41///    (b, d),
42///    (c, d)
43/// ]);
44///
45/// let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
46///    ((a, a), 0), ((a, b), 1), ((a, c), 4), ((a, d), 10),
47///    ((b, b), 0), ((b, c), 2), ((b, d), 2),
48///    ((c, c), 0), ((c, d), 2)
49/// ].iter().cloned().collect();
50/// //     ----- b --------
51/// //    |      ^         | 2
52/// //    |    1 |    4    v
53/// //  2 |      a ------> c
54/// //    |   10 |         | 2
55/// //    |      v         v
56/// //     --->  d <-------
57///
58/// let inf = core::i32::MAX;
59/// let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [
60///    ((a, a), 0), ((a, b), 1), ((a, c), 3), ((a, d), 3),
61///    ((b, a), inf), ((b, b), 0), ((b, c), 2), ((b, d), 2),
62///    ((c, a), inf), ((c, b), inf), ((c, c), 0), ((c, d), 2),
63///    ((d, a), inf), ((d, b), inf), ((d, c), inf), ((d, d), 0),
64/// ].iter().cloned().collect();
65///
66///
67/// let res = floyd_warshall(&graph, |edge| {
68///     if let Some(weight) = weight_map.get(&(edge.source(), edge.target())) {
69///         *weight
70///     } else {
71///         inf
72///     }
73/// }).unwrap();
74///
75/// let nodes = [a, b, c, d];
76/// for node1 in &nodes {
77///     for node2 in &nodes {
78///         assert_eq!(res.get(&(*node1, *node2)).unwrap(), expected_res.get(&(*node1, *node2)).unwrap());
79///     }
80/// }
81/// ```
82pub fn floyd_warshall<G, F, K>(
83    graph: G,
84    edge_cost: F,
85) -> Result<HashMap<(G::NodeId, G::NodeId), K>, NegativeCycle>
86where
87    G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers + GraphProp,
88    G::NodeId: Eq + Hash,
89    F: FnMut(G::EdgeRef) -> K,
90    K: BoundedMeasure + Copy,
91{
92    let num_of_nodes = graph.node_count();
93
94    // |V|x|V| matrix
95    let mut m_dist = Some(vec![vec![K::max(); num_of_nodes]; num_of_nodes]);
96
97    _floyd_warshall_path(graph, edge_cost, &mut m_dist, &mut None)?;
98
99    let mut distance_map: HashMap<(G::NodeId, G::NodeId), K> =
100        HashMap::with_capacity(num_of_nodes * num_of_nodes);
101
102    if let Some(dist) = m_dist {
103        for i in 0..num_of_nodes {
104            for j in 0..num_of_nodes {
105                distance_map.insert((graph.from_index(i), graph.from_index(j)), dist[i][j]);
106            }
107        }
108    }
109
110    Ok(distance_map)
111}
112
113#[allow(clippy::type_complexity, clippy::needless_range_loop)]
114/// \[Generic\] [Floyd–Warshall algorithm](https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm) is an algorithm for all pairs shortest path problem
115///
116/// Compute all pairs shortest paths in a weighted graph with positive or negative edge weights (but with no negative cycles).
117/// Returns HashMap of shortest path lengths. Additionally, returns HashMap of intermediate nodes along shortest path for indicated edges.
118///
119/// # Arguments
120/// * `graph`: graph with no negative cycle
121/// * `edge_cost`: closure that returns cost of a particular edge
122///
123/// # Returns
124/// * `Ok`: (if graph contains no negative cycle) a hashmap containing all pairs shortest path distances and a hashmap for all pairs shortest paths
125/// * `Err`: if graph contains negative cycle.
126///
127/// # Examples
128/// ```rust
129/// use petgraph::{prelude::*, Graph, Directed};
130/// use petgraph::algo::floyd_warshall::floyd_warshall_path;
131/// use std::collections::HashMap;
132///
133/// let mut graph: Graph<(), (), Directed> = Graph::new();
134/// let a = graph.add_node(());
135/// let b = graph.add_node(());
136/// let c = graph.add_node(());
137/// let d = graph.add_node(());
138///
139/// graph.extend_with_edges(&[
140///    (a, b),
141///    (a, c),
142///    (a, d),
143///    (b, c),
144///    (b, d),
145///    (c, d)
146/// ]);
147///
148/// let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
149///    ((a, a), 0), ((a, b), 1), ((a, c), 4), ((a, d), 10),
150///    ((b, b), 0), ((b, c), 2), ((b, d), 2),
151///    ((c, c), 0), ((c, d), 2)
152/// ].iter().cloned().collect();
153/// //     ----- b --------
154/// //    |      ^         | 2
155/// //    |    1 |    4    v
156/// //  2 |      a ------> c
157/// //    |   10 |         | 2
158/// //    |      v         v
159/// //     --->  d <-------
160///
161/// let inf = std::i32::MAX;
162/// let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [
163///    ((a, a), 0), ((a, b), 1), ((a, c), 3), ((a, d), 3),
164///    ((b, a), inf), ((b, b), 0), ((b, c), 2), ((b, d), 2),
165///    ((c, a), inf), ((c, b), inf), ((c, c), 0), ((c, d), 2),
166///    ((d, a), inf), ((d, b), inf), ((d, c), inf), ((d, d), 0),
167/// ].iter().cloned().collect();
168///
169///
170/// let (res, prev) = floyd_warshall_path(&graph, |edge| {
171///     if let Some(weight) = weight_map.get(&(edge.source(), edge.target())) {
172///         *weight
173///     } else {
174///         inf
175///     }
176/// }).unwrap();
177///
178/// assert_eq!(prev[0][2], Some(1));
179///
180/// let nodes = [a, b, c, d];
181/// for node1 in &nodes {
182///     for node2 in &nodes {
183///         assert_eq!(res.get(&(*node1, *node2)).unwrap(), expected_res.get(&(*node1, *node2)).unwrap());
184///     }
185/// }
186///
187/// ```
188pub fn floyd_warshall_path<G, F, K>(
189    graph: G,
190    edge_cost: F,
191) -> Result<(HashMap<(G::NodeId, G::NodeId), K>, Vec<Vec<Option<usize>>>), NegativeCycle>
192where
193    G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers + GraphProp,
194    G::NodeId: Eq + Hash,
195    F: FnMut(G::EdgeRef) -> K,
196    K: BoundedMeasure + Copy,
197{
198    let num_of_nodes = graph.node_count();
199
200    // |V|x|V| matrix
201    let mut m_dist = Some(vec![vec![K::max(); num_of_nodes]; num_of_nodes]);
202    // `prev[source][target]` holds the penultimate vertex on path from `source` to `target`, except `prev[source][source]`, which always stores `source`.
203    let mut m_prev = Some(vec![vec![None; num_of_nodes]; num_of_nodes]);
204
205    _floyd_warshall_path(graph, edge_cost, &mut m_dist, &mut m_prev)?;
206
207    let mut distance_map = HashMap::with_capacity(num_of_nodes * num_of_nodes);
208
209    if let Some(dist) = m_dist {
210        for i in 0..num_of_nodes {
211            for j in 0..num_of_nodes {
212                distance_map.insert((graph.from_index(i), graph.from_index(j)), dist[i][j]);
213            }
214        }
215    }
216
217    Ok((distance_map, m_prev.unwrap()))
218}
219
/// Stores `value` at row `i`, column `j` of an optional 2D matrix.
///
/// Does nothing when the matrix is absent (`None`). This lets callers skip allocating a
/// matrix they do not need, such as the predecessor matrix, without branching at every
/// call site. The value is moved into place, so no `Clone` bound is required.
///
/// # Panics
/// Panics if `i` or `j` is out of bounds for a present matrix.
fn set_object<T>(matrix: &mut Option<Vec<Vec<T>>>, i: usize, j: usize, value: T) {
    if let Some(rows) = matrix {
        rows[i][j] = value;
    }
}
226
/// Returns `true` if the entry at row `i`, column `j` of an optional 2D matrix is strictly
/// greater than `value`.
///
/// Always returns `false` when the matrix is absent (`None`). The matrix is only read, so a
/// shared reference suffices; callers holding a `&mut Option<_>` coerce to it automatically.
///
/// # Panics
/// Panics if `i` or `j` is out of bounds for a present matrix.
fn is_greater<K: PartialOrd>(matrix: &Option<Vec<Vec<K>>>, i: usize, j: usize, value: K) -> bool {
    match matrix {
        Some(rows) => rows[i][j] > value,
        None => false,
    }
}
239
240/// Helper that implements the floyd warshall routine, but paths are optional for memory overhead.
241fn _floyd_warshall_path<G, F, K>(
242    graph: G,
243    mut edge_cost: F,
244    m_dist: &mut Option<Vec<Vec<K>>>,
245    m_prev: &mut Option<Vec<Vec<Option<usize>>>>,
246) -> Result<(), NegativeCycle>
247where
248    G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers + GraphProp,
249    G::NodeId: Eq + Hash,
250    F: FnMut(G::EdgeRef) -> K,
251    K: BoundedMeasure + Copy,
252{
253    let num_of_nodes = graph.node_count();
254
255    // Initialize distances and predecessors for edges
256    for edge in graph.edge_references() {
257        let source = graph.to_index(edge.source());
258        let target = graph.to_index(edge.target());
259        let cost = edge_cost(edge);
260        if is_greater(m_dist, source, target, cost) {
261            set_object(m_dist, source, target, cost);
262            set_object(m_prev, source, target, Some(source));
263
264            if !graph.is_directed() {
265                set_object(m_dist, target, source, cost);
266                set_object(m_prev, target, source, Some(target));
267            }
268        }
269    }
270
271    // Distance of each node to itself is the default value
272    for node in graph.node_identifiers() {
273        let index = graph.to_index(node);
274        set_object(m_dist, index, index, K::default());
275        set_object(m_prev, index, index, Some(index));
276    }
277
278    // Perform the Floyd-Warshall algorithm
279    for k in 0..num_of_nodes {
280        for i in 0..num_of_nodes {
281            for j in 0..num_of_nodes {
282                if let Some(dist) = m_dist {
283                    let (result, overflow) = dist[i][k].overflowing_add(dist[k][j]);
284                    if !overflow && dist[i][j] > result {
285                        dist[i][j] = result;
286                        if let Some(prev) = m_prev {
287                            prev[i][j] = prev[k][j];
288                        }
289                    }
290                }
291            }
292        }
293    }
294
295    // value less than 0(default value) indicates a negative cycle
296    for i in 0..num_of_nodes {
297        if let Some(dist) = m_dist {
298            if dist[i][i] < K::default() {
299                return Err(NegativeCycle(()));
300            }
301        }
302    }
303    Ok(())
304}