// petgraph/algo/steiner_tree.rs

1use alloc::vec::Vec;
2use core::{fmt::Debug, hash::Hash};
3
4use hashbrown::{HashMap, HashSet};
5
6use crate::algo::floyd_warshall::floyd_warshall_path;
7use crate::algo::{dijkstra, min_spanning_tree, BoundedMeasure, Measure};
8use crate::data::FromElements;
9use crate::graph::{IndexType, NodeIndex, UnGraph};
10use crate::visit::{
11    Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoNeighbors,
12    IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable, NodeIndexable, Visitable,
13};
14use crate::Undirected;
15
16#[cfg(feature = "stable_graph")]
17use crate::stable_graph::StableGraph;
18
19type Edge<G> = (<G as GraphBase>::NodeId, <G as GraphBase>::NodeId);
20type Subgraph<G> = HashSet<<G as GraphBase>::NodeId>;
21
22fn compute_shortest_path_length<G>(graph: G, source: G::NodeId, target: G::NodeId) -> G::EdgeWeight
23where
24    G: Visitable + IntoEdges,
25    G::NodeId: Eq + Hash,
26    G::EdgeWeight: Measure + Copy,
27{
28    let output = dijkstra(graph, source, Some(target), |e| *e.weight());
29    output[&target]
30}
31
32fn compute_metric_closure<G>(
33    graph: G,
34    terminals: &[G::NodeId],
35) -> HashMap<(usize, usize), G::EdgeWeight>
36where
37    G: Data + IntoNodeReferences + NodeIndexable + Visitable + IntoEdges,
38    G::EdgeWeight: Copy + Measure,
39    G::NodeId: PartialOrd + Eq + Hash,
40{
41    let mut closure = HashMap::new();
42    for (i, node_id_1) in terminals.iter().enumerate() {
43        for node_id_2 in terminals.iter().skip(i + 1) {
44            closure.insert(
45                (graph.to_index(*node_id_1), graph.to_index(*node_id_2)),
46                compute_shortest_path_length(graph, *node_id_1, *node_id_2),
47            );
48        }
49    }
50    closure
51}
52
53fn subgraph_edges_from_metric_closure<G>(
54    graph: G,
55    minimum_spanning_closure: G,
56) -> (Vec<Edge<G>>, Subgraph<G>)
57where
58    G: GraphBase
59        + NodeCompactIndexable
60        + IntoEdgeReferences
61        + IntoNodeIdentifiers
62        + GraphProp
63        + IntoNodeReferences,
64    G::EdgeWeight: BoundedMeasure + Copy,
65    G::NodeId: Eq + Hash + Ord + Debug,
66{
67    let mut retained_nodes = HashSet::new();
68    let mut retained_edges = Vec::new();
69    let (_, prev) = floyd_warshall_path(graph, |e| *e.weight()).unwrap();
70
71    for edge in minimum_spanning_closure.edge_references() {
72        let target = graph.to_index(edge.target());
73        let source = graph.to_index(edge.source());
74
75        let mut current = target;
76        while current != source {
77            if let Some(prev_node) = prev[source][current] {
78                retained_nodes.insert(graph.from_index(prev_node));
79                retained_nodes.insert(graph.from_index(current));
80                retained_edges.push((graph.from_index(prev_node), graph.from_index(current)));
81                current = prev_node;
82            }
83        }
84    }
85
86    (retained_edges, retained_nodes)
87}
88
89fn non_terminal_leaves<G>(graph: G, terminals: &[G::NodeId]) -> HashSet<G::NodeId>
90where
91    G: GraphBase + IntoNodeReferences + IntoNodeIdentifiers + IntoNeighbors,
92    G::NodeId: Hash + Eq + Debug,
93    G::NodeRef: Eq + Hash,
94{
95    let mut removed_leaves = HashSet::new();
96
97    let mut remaining_leaves = graph
98        .node_identifiers()
99        .filter(|node_id| {
100            graph.neighbors(*node_id).collect::<HashSet<_>>().len() == 1
101                && !terminals.contains(node_id)
102        })
103        .collect::<HashSet<_>>();
104
105    while !remaining_leaves.is_empty() {
106        remaining_leaves = graph
107            .node_identifiers()
108            .filter(|node_id| {
109                !terminals.contains(node_id)
110                    && !removed_leaves.contains(node_id)
111                    && (graph
112                        .neighbors(*node_id)
113                        .collect::<HashSet<_>>()
114                        .difference(&removed_leaves))
115                    .collect::<Vec<_>>()
116                    .len()
117                        == 1
118            })
119            .collect::<HashSet<_>>();
120
121        removed_leaves = removed_leaves
122            .union(&remaining_leaves)
123            .cloned()
124            .collect::<HashSet<_>>();
125    }
126
127    removed_leaves
128}
129
/// \[Generic\] [Steiner Tree][1] algorithm.
///
/// Computes an approximate minimum Steiner tree of an undirected connected graph given a set of
/// terminal nodes via [Kou's algorithm][2], a 2-approximation. Implementation details are the
/// same as in the [NetworkX implementation][3]. In particular, when several parallel edges join
/// the same pair of nodes, only the lightest one is kept.
///
/// ## Arguments
/// * `graph`: The undirected graph in which to find the Steiner tree. Edge weights are expected
///   to be non-negative.
/// * `terminals`: A slice of node indices representing the terminals for which the Steiner tree is computed.
///
/// ## Returns
/// A `StableGraph` containing the nodes and edges of the Steiner tree. Node and edge indices are
/// the same as in `graph`.
///
/// ## Complexity
/// Time complexity: **O(|V|³ + |S|² (|V| + |E|) log |V|)**. This covers the all-pairs shortest
/// paths (Floyd–Warshall) plus one Dijkstra search per pair of terminals.
/// Auxiliary space: **O(|V|²)**.
/// Here **|V|** is the number of vertices (i.e nodes), **|E|** the number of edges and **|S|**
/// the number of provided terminals.
///
/// ## Panics
/// Panics if the terminals do not all lie in the same connected component of `graph`.
///
/// [1]: https://en.wikipedia.org/wiki/Steiner_tree_problem
/// [2]: https://doi.org/10.1007/BF00288961
/// [3]: https://networkx.org/documentation/stable/_modules/networkx/algorithms/approximation/steinertree.html#steiner_tree
///
/// # Example
///
/// ```
/// use petgraph::Graph;
/// use petgraph::algo::steiner_tree::steiner_tree;
/// use petgraph::graph::UnGraph;
/// let mut graph = UnGraph::<(), i32>::default();
/// let a = graph.add_node(());
/// let b = graph.add_node(());
/// let c = graph.add_node(());
/// let d = graph.add_node(());
/// let e = graph.add_node(());
/// let f = graph.add_node(());
/// graph.extend_with_edges([
///     (a, b, 7),
///     (a, f, 6),
///     (b, c, 1),
///     (b, f, 5),
///     (c, d, 1),
///     (c, e, 3),
///     (d, e, 1),
///     (d, f, 4),
///     (e, f, 10),
/// ]);
/// let terminals = vec![a, c, e, f];
/// let tree = steiner_tree(&graph, &terminals);
/// assert_eq!(tree.edge_weights().sum::<i32>(), 12);
/// ```
#[cfg(feature = "stable_graph")]
pub fn steiner_tree<N, E, Ix>(
    graph: &UnGraph<N, E, Ix>,
    terminals: &[NodeIndex<Ix>],
) -> StableGraph<N, E, Undirected, Ix>
where
    N: Default + Clone + Eq + Hash + Debug,
    E: Copy + Eq + Ord + Measure + BoundedMeasure,
    Ix: IndexType,
{
    // Orders a node pair so that lookups ignore edge orientation.
    let unordered = |a: NodeIndex<Ix>, b: NodeIndex<Ix>| if a <= b { (a, b) } else { (b, a) };

    // 1. Metric closure of the terminals, as a graph weighted by shortest-path distances.
    let metric_closure = compute_metric_closure(graph, terminals);
    let metric_closure_graph: UnGraph<N, E, _> = UnGraph::from_edges(
        metric_closure
            .iter()
            .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
    );

    // 2. Minimum spanning tree of the metric closure.
    let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));

    // 3. Replace every MST edge with the shortest path of `graph` it stands for.
    let (subgraph_edges, subgraph_nodes) =
        subgraph_edges_from_metric_closure(graph, &minimum_spanning);
    let wanted_pairs: HashSet<_> = subgraph_edges
        .into_iter()
        .map(|(a, b)| unordered(a, b))
        .collect();

    // Among parallel edges joining a wanted pair keep only the lightest (first one on ties),
    // matching the shortest-path weights used above.
    let mut lightest: HashMap<_, (crate::graph::EdgeIndex<Ix>, E)> = HashMap::new();
    for edge in graph.edge_references() {
        let key = unordered(edge.source(), edge.target());
        if !wanted_pairs.contains(&key) {
            continue;
        }
        let weight = *edge.weight();
        let keep_current = matches!(lightest.get(&key), Some(&(_, best)) if best <= weight);
        if !keep_current {
            lightest.insert(key, (edge.id(), weight));
        }
    }
    let kept_edges: HashSet<_> = lightest.values().map(|&(id, _)| id).collect();

    // `StableGraph::from` preserves indices, so `kept_edges` and `subgraph_nodes` stay valid.
    let mut tree = StableGraph::from(graph.clone());
    tree.retain_edges(|_, e| kept_edges.contains(&e));
    tree.retain_nodes(|_, n| subgraph_nodes.contains(&n));

    // 4. Prune non-terminal leaves.
    let non_terminal_nodes = non_terminal_leaves(&tree, terminals);
    tree.retain_nodes(|_, n| !non_terminal_nodes.contains(&n));

    tree
}
212
#[cfg(test)]
mod test {
    use alloc::vec;

    use hashbrown::HashSet;

    use super::{compute_metric_closure, non_terminal_leaves, subgraph_edges_from_metric_closure};
    use crate::graph::NodeIndex;
    use crate::{
        algo::{min_spanning_tree, EdgeRef, UnGraph},
        data::FromElements,
        Graph, Undirected,
    };

    /// Builds the six-node weighted example graph shared by the metric-closure and subgraph
    /// tests, returning it together with its nodes in insertion order (`a` through `f`).
    fn kou_example_graph() -> (Graph<(), i32, Undirected>, [NodeIndex; 6]) {
        let mut graph = Graph::<(), i32, Undirected>::new_undirected();
        let [a, b, c, d, e, f] = [(); 6].map(|()| graph.add_node(()));
        graph.extend_with_edges([
            (a, b, 7),
            (a, f, 6),
            (b, c, 1),
            (b, f, 5),
            (c, d, 1),
            (c, e, 3),
            (d, e, 1),
            (d, f, 4),
            (e, f, 10),
        ]);
        (graph, [a, b, c, d, e, f])
    }

    #[test]
    fn test_compute_metric_closure() {
        let (graph, [a, _, c, _, e, f]) = kou_example_graph();
        let terminals = vec![a, c, e, f];
        let metric_closure = compute_metric_closure(&graph, &terminals);

        let metric_closure_graph: UnGraph<&str, _, _> = UnGraph::from_edges(
            metric_closure
                .iter()
                .map(|(&(u, v), &weight)| (u, v, weight)),
        );

        let expected = [
            ((0, 2), 8),
            ((0, 4), 10),
            ((0, 5), 6),
            ((2, 4), 2),
            ((2, 5), 5),
            ((4, 5), 5),
        ];
        for ((u, v), weight) in expected {
            assert_eq!(metric_closure[&(u, v)], weight);
            let edge = metric_closure_graph
                .find_edge(NodeIndex::new(u), NodeIndex::new(v))
                .unwrap();
            assert_eq!(metric_closure_graph[edge], weight);
        }
    }

    #[test]
    fn test_subgraph_from_metric_closure() {
        let (mut graph, [a, _, c, _, e, f]) = kou_example_graph();
        let terminals = vec![a, c, e, f];
        let metric_closure = compute_metric_closure(&graph, &terminals);

        let metric_closure_graph: UnGraph<(), _, _> = UnGraph::from_edges(
            metric_closure
                .iter()
                .map(|(&(u, v), &weight)| (u as u32, v as u32, weight)),
        );
        let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));

        let (subgraph_edges, _) = subgraph_edges_from_metric_closure(&graph, &minimum_spanning);

        graph.retain_edges(|frozen, edge| {
            let (source, target) = frozen.edge_endpoints(edge).unwrap();
            subgraph_edges.contains(&(source, target))
        });

        let mut ref_graph = UnGraph::<(), _>::new_undirected();
        let [ref_a, _, ref_c, ref_d, ref_e, ref_f] = [(); 6].map(|()| ref_graph.add_node(()));
        ref_graph.extend_with_edges([
            (ref_c, ref_d, 1),
            (ref_d, ref_e, 1),
            (ref_d, ref_f, 4),
            (ref_a, ref_f, 6),
        ]);

        for ref_edge in ref_graph.edge_references() {
            let (edge_index, _) = graph
                .find_edge_undirected(ref_edge.source(), ref_edge.target())
                .unwrap();
            let (source, target) = graph.edge_endpoints(edge_index).unwrap();
            assert_eq!(graph.edge_weight(edge_index).unwrap(), ref_edge.weight());
            assert_eq!(source, ref_edge.source());
            assert_eq!(target, ref_edge.target());
        }
    }

    #[test]
    fn test_remove_non_terminal_nodes() {
        let mut path = Graph::<(), i32, _>::new_undirected();
        let [a, b, c, d, e, f] = [(); 6].map(|()| path.add_node(()));
        path.extend_with_edges([(a, b, 7), (b, c, 6), (c, d, 1), (d, e, 5), (e, f, 1)]);

        let terminals = vec![a, c];
        let pruned = non_terminal_leaves(&path, &terminals);
        assert_eq!(HashSet::from([d, e, f]), pruned);
    }
}