petgraph/algo/
steiner_tree.rs

1use alloc::vec::Vec;
2use core::{fmt::Debug, hash::Hash};
3
4use hashbrown::{HashMap, HashSet};
5
6use crate::algo::floyd_warshall::floyd_warshall_path;
7use crate::algo::{dijkstra, min_spanning_tree, BoundedMeasure, Measure};
8use crate::data::FromElements;
9use crate::graph::{IndexType, NodeIndex, UnGraph};
10use crate::visit::{
11    Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoNeighbors,
12    IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable, NodeIndexable, Visitable,
13};
14use crate::Undirected;
15
16#[cfg(feature = "stable_graph")]
17use crate::stable_graph::StableGraph;
18
19type Edge<G> = (<G as GraphBase>::NodeId, <G as GraphBase>::NodeId);
20type Subgraph<G> = HashSet<<G as GraphBase>::NodeId>;
21
22fn compute_shortest_path_length<G>(graph: G, source: G::NodeId, target: G::NodeId) -> G::EdgeWeight
23where
24    G: Visitable + IntoEdges,
25    G::NodeId: Eq + Hash,
26    G::EdgeWeight: Measure + Copy,
27{
28    let output = dijkstra(graph, source, Some(target), |e| *e.weight());
29    output[&target]
30}
31
32fn compute_metric_closure<G>(
33    graph: G,
34    terminals: &[G::NodeId],
35) -> HashMap<(usize, usize), G::EdgeWeight>
36where
37    G: Data + IntoNodeReferences + NodeIndexable + Visitable + IntoEdges,
38    G::EdgeWeight: Copy + Measure,
39    G::NodeId: PartialOrd + Eq + Hash,
40{
41    let mut closure = HashMap::new();
42    for (i, node_id_1) in terminals.iter().enumerate() {
43        for node_id_2 in terminals.iter().skip(i + 1) {
44            closure.insert(
45                (graph.to_index(*node_id_1), graph.to_index(*node_id_2)),
46                compute_shortest_path_length(graph, *node_id_1, *node_id_2),
47            );
48        }
49    }
50    closure
51}
52
53fn subgraph_edges_from_metric_closure<G>(
54    graph: G,
55    minimum_spanning_closure: G,
56) -> (Vec<Edge<G>>, Subgraph<G>)
57where
58    G: GraphBase
59        + NodeCompactIndexable
60        + IntoEdgeReferences
61        + IntoNodeIdentifiers
62        + GraphProp
63        + IntoNodeReferences,
64    G::EdgeWeight: BoundedMeasure + Copy,
65    G::NodeId: Eq + Hash + Ord + Debug,
66{
67    let mut retained_nodes = HashSet::new();
68    let mut retained_edges = Vec::new();
69    let (_, prev) = floyd_warshall_path(graph, |e| *e.weight()).unwrap();
70
71    for edge in minimum_spanning_closure.edge_references() {
72        let target = graph.to_index(edge.target());
73        let source = graph.to_index(edge.source());
74
75        let mut current = target;
76        while current != source {
77            if let Some(prev_node) = prev[source][current] {
78                retained_nodes.insert(graph.from_index(prev_node));
79                retained_nodes.insert(graph.from_index(current));
80                retained_edges.push((graph.from_index(prev_node), graph.from_index(current)));
81                current = prev_node;
82            }
83        }
84    }
85
86    (retained_edges, retained_nodes)
87}
88
89fn non_terminal_leaves<G>(graph: G, terminals: &[G::NodeId]) -> HashSet<G::NodeId>
90where
91    G: GraphBase + IntoNodeReferences + IntoNodeIdentifiers + IntoNeighbors,
92    G::NodeId: Hash + Eq + Debug,
93    G::NodeRef: Eq + Hash,
94{
95    let mut removed_leaves = HashSet::new();
96
97    let mut remaining_leaves = graph
98        .node_identifiers()
99        .filter(|node_id| {
100            graph.neighbors(*node_id).collect::<HashSet<_>>().len() == 1
101                && !terminals.contains(node_id)
102        })
103        .collect::<HashSet<_>>();
104
105    while !remaining_leaves.is_empty() {
106        remaining_leaves = graph
107            .node_identifiers()
108            .filter(|node_id| {
109                !terminals.contains(node_id)
110                    && !removed_leaves.contains(node_id)
111                    && (graph
112                        .neighbors(*node_id)
113                        .collect::<HashSet<_>>()
114                        .difference(&removed_leaves))
115                    .collect::<Vec<_>>()
116                    .len()
117                        == 1
118            })
119            .collect::<HashSet<_>>();
120
121        removed_leaves = removed_leaves
122            .union(&remaining_leaves)
123            .cloned()
124            .collect::<HashSet<_>>();
125    }
126
127    removed_leaves
128}
129
130/// \[Generic\] Steiner Tree algorithm.
131///
132/// Computes the Steiner tree of an undirected graph given a set of terminal nodes via [Kou's algorithm][pr]. Implementation details mirrors NetworkX implementation.
133///
134/// Returns a `Graph` representing the Steiner tree of the input graph.
135///
136///
137/// # Complexity
138/// Time complexity is **O(|S| |V|²)**.
139/// where **|V|** the number of vertices (i.e nodes) and **|E|** the number of edges.
140///
141/// [pr]: https://networkx.org/documentation/stable/_modules/networkx/algorithms/approximation/steinertree.html#steiner_tree
142///
143/// # Example
144/// ```rust
145/// use petgraph::Graph;
146/// use petgraph::algo::steiner_tree::steiner_tree;
147/// use petgraph::graph::UnGraph;
148/// let mut graph = UnGraph::<(), i32>::default();
149/// let a = graph.add_node(());
150/// let b = graph.add_node(());
151/// let c = graph.add_node(());
152/// let d = graph.add_node(());
153/// let e = graph.add_node(());
154/// let f = graph.add_node(());
155/// graph.extend_with_edges([
156///     (a, b, 7),
157///     (a, f, 6),
158///     (b, c, 1),
159///     (b, f, 5),
160///     (c, d, 1),
161///     (c, e, 3),
162///     (d, e, 1),
163///     (d, f, 4),
164///     (e, f, 10),
165/// ]);
166/// let terminals = vec![a, c, e, f];
167/// let tree = steiner_tree(&graph, &terminals);
168/// assert_eq!(tree.edge_weights().sum::<i32>(), 12);
169///
170#[cfg(feature = "stable_graph")]
171pub fn steiner_tree<N, E, Ix>(
172    graph: &UnGraph<N, E, Ix>,
173    terminals: &[NodeIndex<Ix>],
174) -> StableGraph<N, E, Undirected, Ix>
175where
176    N: Default + Clone + Eq + Hash + Debug,
177    E: Copy + Eq + Ord + Measure + BoundedMeasure,
178    Ix: IndexType,
179{
180    let metric_closure = compute_metric_closure(&graph, terminals);
181    let metric_closure_graph: UnGraph<N, E, _> = UnGraph::from_edges(
182        metric_closure
183            .iter()
184            .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
185    );
186
187    let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));
188
189    let (subgraph_edges, subgraph_nodes) =
190        subgraph_edges_from_metric_closure(graph, &minimum_spanning);
191
192    let mut graph = StableGraph::from(graph.clone());
193    graph.retain_edges(|graph, e| {
194        let edge = graph.edge_endpoints(e).unwrap();
195        subgraph_edges.contains(&(edge.0, edge.1)) || subgraph_edges.contains(&(edge.1, edge.0))
196    });
197    graph.retain_nodes(|_, n| subgraph_nodes.contains(&n));
198
199    let non_terminal_nodes = non_terminal_leaves(&graph, terminals);
200    graph.retain_nodes(|_, n| !non_terminal_nodes.contains(&n));
201
202    graph
203}
204
205#[cfg(test)]
206mod test {
207    use alloc::vec;
208
209    use hashbrown::{HashMap, HashSet};
210
211    use super::{compute_metric_closure, non_terminal_leaves, subgraph_edges_from_metric_closure};
212    use crate::graph::NodeIndex;
213    use crate::{
214        algo::{min_spanning_tree, EdgeRef, UnGraph},
215        data::FromElements,
216        Graph, Undirected,
217    };
218
219    #[test]
220    fn test_compute_metric_closure() {
221        let mut graph = Graph::<(), i32, Undirected>::new_undirected();
222
223        let a = graph.add_node(());
224        let b = graph.add_node(());
225        let c = graph.add_node(());
226        let d = graph.add_node(());
227        let e = graph.add_node(());
228        let f = graph.add_node(());
229        graph.extend_with_edges([
230            (a, b, 7),
231            (a, f, 6),
232            (b, c, 1),
233            (b, f, 5),
234            (c, d, 1),
235            (c, e, 3),
236            (d, e, 1),
237            (d, f, 4),
238            (e, f, 10),
239        ]);
240
241        let terminals = vec![a, c, e, f];
242        let metric_closure = compute_metric_closure(&graph, &terminals);
243
244        let metric_closure_graph: UnGraph<&str, _, _> = UnGraph::from_edges(
245            metric_closure
246                .iter()
247                .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
248        );
249
250        let ref_weights = HashMap::<_, _>::from([
251            ((0, 2), 8),
252            ((0, 4), 10),
253            ((0, 5), 6),
254            ((2, 4), 2),
255            ((2, 5), 5),
256            ((4, 5), 5),
257        ]);
258        for ((node1, node2), ref_weight) in ref_weights {
259            assert_eq!(metric_closure[&(node1, node2)], ref_weight);
260            assert_eq!(
261                *metric_closure_graph
262                    .edge_weight(
263                        metric_closure_graph
264                            .find_edge(NodeIndex::new(node1), NodeIndex::new(node2))
265                            .unwrap()
266                    )
267                    .unwrap(),
268                ref_weight
269            );
270        }
271    }
272
273    #[test]
274    fn test_subgraph_from_metric_closure() {
275        let mut graph = Graph::<(), i32, _>::new_undirected();
276
277        let a = graph.add_node(());
278        let b = graph.add_node(());
279        let c = graph.add_node(());
280        let d = graph.add_node(());
281        let e = graph.add_node(());
282        let f = graph.add_node(());
283        graph.extend_with_edges([
284            (a, b, 7),
285            (a, f, 6),
286            (b, c, 1),
287            (b, f, 5),
288            (c, d, 1),
289            (c, e, 3),
290            (d, e, 1),
291            (d, f, 4),
292            (e, f, 10),
293        ]);
294
295        let terminals = vec![a, c, e, f];
296        let metric_closure = compute_metric_closure(&graph, &terminals);
297
298        let metric_closure_graph: UnGraph<(), _, _> = UnGraph::from_edges(
299            metric_closure
300                .iter()
301                .map(|((node1, node2), &weight)| (*node1 as u32, *node2 as u32, weight)),
302        );
303
304        let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));
305
306        let (subgraph_edges, _subgraph_nodes) =
307            subgraph_edges_from_metric_closure(&graph, &minimum_spanning);
308
309        graph.retain_edges(|graph, e| {
310            let edge = graph.edge_endpoints(e).unwrap();
311            subgraph_edges.contains(&(edge.0, edge.1))
312        });
313
314        let mut ref_graph = UnGraph::<(), _>::new_undirected();
315        let ref_a = ref_graph.add_node(());
316        let _ = ref_graph.add_node(());
317        let ref_c = ref_graph.add_node(());
318        let ref_d = ref_graph.add_node(());
319        let ref_e = ref_graph.add_node(());
320        let ref_f = ref_graph.add_node(());
321
322        ref_graph.extend_with_edges([
323            (ref_c, ref_d, 1),
324            (ref_d, ref_e, 1),
325            (ref_d, ref_f, 4),
326            (ref_a, ref_f, 6),
327        ]);
328
329        for ref_edge in ref_graph.edge_references() {
330            let (edge_index, _) = graph
331                .find_edge_undirected(ref_edge.source(), ref_edge.target())
332                .unwrap();
333            let edge_endpoints = graph.edge_endpoints(edge_index).unwrap();
334            assert_eq!(graph.edge_weight(edge_index).unwrap(), ref_edge.weight());
335            assert_eq!(edge_endpoints.0, ref_edge.source());
336            assert_eq!(edge_endpoints.1, ref_edge.target());
337        }
338    }
339
340    #[test]
341    fn test_remove_non_terminal_nodes() {
342        let mut graph = Graph::<(), i32, _>::new_undirected();
343
344        let a = graph.add_node(());
345        let b = graph.add_node(());
346        let c = graph.add_node(());
347        let d = graph.add_node(());
348        let e = graph.add_node(());
349        let f = graph.add_node(());
350        graph.extend_with_edges([(a, b, 7), (b, c, 6), (c, d, 1), (d, e, 5), (e, f, 1)]);
351
352        let terminals = vec![a, c];
353        let non_terminal_nodes = non_terminal_leaves(&graph, &terminals);
354        let non_terminal_refs = HashSet::from([d, e, f]);
355        assert_eq!(non_terminal_refs, non_terminal_nodes);
356    }
357}