1use alloc::vec::Vec;
2use core::{fmt::Debug, hash::Hash};
3
4use hashbrown::{HashMap, HashSet};
5
6use crate::algo::floyd_warshall::floyd_warshall_path;
7use crate::algo::{dijkstra, min_spanning_tree, BoundedMeasure, Measure};
8use crate::data::FromElements;
9use crate::graph::{IndexType, NodeIndex, UnGraph};
10use crate::visit::{
11 Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoNeighbors,
12 IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable, NodeIndexable, Visitable,
13};
14use crate::Undirected;
15
16#[cfg(feature = "stable_graph")]
17use crate::stable_graph::StableGraph;
18
/// An edge written as a `(source, target)` pair of node identifiers of `G`.
type Edge<G> = (<G as GraphBase>::NodeId, <G as GraphBase>::NodeId);
/// A set of node identifiers of `G`, describing which nodes a subgraph keeps.
type Subgraph<G> = HashSet<<G as GraphBase>::NodeId>;
21
22fn compute_shortest_path_length<G>(graph: G, source: G::NodeId, target: G::NodeId) -> G::EdgeWeight
23where
24 G: Visitable + IntoEdges,
25 G::NodeId: Eq + Hash,
26 G::EdgeWeight: Measure + Copy,
27{
28 let output = dijkstra(graph, source, Some(target), |e| *e.weight());
29 output[&target]
30}
31
32fn compute_metric_closure<G>(
33 graph: G,
34 terminals: &[G::NodeId],
35) -> HashMap<(usize, usize), G::EdgeWeight>
36where
37 G: Data + IntoNodeReferences + NodeIndexable + Visitable + IntoEdges,
38 G::EdgeWeight: Copy + Measure,
39 G::NodeId: PartialOrd + Eq + Hash,
40{
41 let mut closure = HashMap::new();
42 for (i, node_id_1) in terminals.iter().enumerate() {
43 for node_id_2 in terminals.iter().skip(i + 1) {
44 closure.insert(
45 (graph.to_index(*node_id_1), graph.to_index(*node_id_2)),
46 compute_shortest_path_length(graph, *node_id_1, *node_id_2),
47 );
48 }
49 }
50 closure
51}
52
53fn subgraph_edges_from_metric_closure<G>(
54 graph: G,
55 minimum_spanning_closure: G,
56) -> (Vec<Edge<G>>, Subgraph<G>)
57where
58 G: GraphBase
59 + NodeCompactIndexable
60 + IntoEdgeReferences
61 + IntoNodeIdentifiers
62 + GraphProp
63 + IntoNodeReferences,
64 G::EdgeWeight: BoundedMeasure + Copy,
65 G::NodeId: Eq + Hash + Ord + Debug,
66{
67 let mut retained_nodes = HashSet::new();
68 let mut retained_edges = Vec::new();
69 let (_, prev) = floyd_warshall_path(graph, |e| *e.weight()).unwrap();
70
71 for edge in minimum_spanning_closure.edge_references() {
72 let target = graph.to_index(edge.target());
73 let source = graph.to_index(edge.source());
74
75 let mut current = target;
76 while current != source {
77 if let Some(prev_node) = prev[source][current] {
78 retained_nodes.insert(graph.from_index(prev_node));
79 retained_nodes.insert(graph.from_index(current));
80 retained_edges.push((graph.from_index(prev_node), graph.from_index(current)));
81 current = prev_node;
82 }
83 }
84 }
85
86 (retained_edges, retained_nodes)
87}
88
89fn non_terminal_leaves<G>(graph: G, terminals: &[G::NodeId]) -> HashSet<G::NodeId>
90where
91 G: GraphBase + IntoNodeReferences + IntoNodeIdentifiers + IntoNeighbors,
92 G::NodeId: Hash + Eq + Debug,
93 G::NodeRef: Eq + Hash,
94{
95 let mut removed_leaves = HashSet::new();
96
97 let mut remaining_leaves = graph
98 .node_identifiers()
99 .filter(|node_id| {
100 graph.neighbors(*node_id).collect::<HashSet<_>>().len() == 1
101 && !terminals.contains(node_id)
102 })
103 .collect::<HashSet<_>>();
104
105 while !remaining_leaves.is_empty() {
106 remaining_leaves = graph
107 .node_identifiers()
108 .filter(|node_id| {
109 !terminals.contains(node_id)
110 && !removed_leaves.contains(node_id)
111 && (graph
112 .neighbors(*node_id)
113 .collect::<HashSet<_>>()
114 .difference(&removed_leaves))
115 .collect::<Vec<_>>()
116 .len()
117 == 1
118 })
119 .collect::<HashSet<_>>();
120
121 removed_leaves = removed_leaves
122 .union(&remaining_leaves)
123 .cloned()
124 .collect::<HashSet<_>>();
125 }
126
127 removed_leaves
128}
129
#[cfg(feature = "stable_graph")]
/// Builds a tree inside `graph` that connects all `terminals`.
///
/// The construction proceeds as follows:
/// 1. compute the shortest-path length between every pair of terminals
///    (the metric closure restricted to `terminals`);
/// 2. take a minimum spanning tree of that closure;
/// 3. expand each spanning-tree edge into a shortest path in `graph`, keeping only
///    the edges and nodes on those paths (terminals are always kept, so a single
///    terminal yields a one-node tree);
/// 4. repeatedly prune leaves that are not terminals.
///
/// The result is a [`StableGraph`] copy of `graph` from which every other node and
/// edge has been removed.
///
/// # Panics
///
/// Presumably panics when some terminal is unreachable from another terminal,
/// because the shortest-path lookup for that pair has no entry.
pub fn steiner_tree<N, E, Ix>(
    graph: &UnGraph<N, E, Ix>,
    terminals: &[NodeIndex<Ix>],
) -> StableGraph<N, E, Undirected, Ix>
where
    N: Default + Clone + Eq + Hash + Debug,
    E: Copy + Eq + Ord + Measure + BoundedMeasure,
    Ix: IndexType,
{
    // Shortest-path lengths between terminal pairs, as a graph whose node indices
    // are the node indices of `graph`.
    let metric_closure = compute_metric_closure(&graph, terminals);
    let metric_closure_graph: UnGraph<N, E, _> = UnGraph::from_edges(
        metric_closure
            .iter()
            .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
    );

    let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));

    let (subgraph_edges, subgraph_nodes) =
        subgraph_edges_from_metric_closure(graph, &minimum_spanning);
    // Hash the path edges once so that filtering every edge of `graph` is O(1) each.
    let subgraph_edges: HashSet<_> = subgraph_edges.into_iter().collect();
    let terminal_set: HashSet<NodeIndex<Ix>> = terminals.iter().copied().collect();

    let mut tree = StableGraph::from(graph.clone());
    tree.retain_edges(|frozen, e| {
        let (a, b) = frozen.edge_endpoints(e).unwrap();
        // Paths are recorded in walk direction; the graph is undirected.
        subgraph_edges.contains(&(a, b)) || subgraph_edges.contains(&(b, a))
    });
    // Terminals are kept explicitly: with fewer than two distinct terminals no
    // path is walked and `subgraph_nodes` is empty.
    tree.retain_nodes(|_, n| subgraph_nodes.contains(&n) || terminal_set.contains(&n));

    let non_terminal_nodes = non_terminal_leaves(&tree, terminals);
    tree.retain_nodes(|_, n| !non_terminal_nodes.contains(&n));

    tree
}
204
#[cfg(test)]
mod test {
    use alloc::vec;

    use hashbrown::{HashMap, HashSet};

    use super::{compute_metric_closure, non_terminal_leaves, subgraph_edges_from_metric_closure};
    use crate::graph::NodeIndex;
    use crate::{
        algo::{min_spanning_tree, EdgeRef, UnGraph},
        data::FromElements,
        Graph, Undirected,
    };

    /// The metric closure over terminals `{a, c, e, f}` holds the expected
    /// shortest-path length for every terminal pair, both in the returned map and
    /// in the graph built from it.
    #[test]
    fn test_compute_metric_closure() {
        let mut graph = Graph::<(), i32, Undirected>::new_undirected();

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        graph.extend_with_edges([
            (a, b, 7),
            (a, f, 6),
            (b, c, 1),
            (b, f, 5),
            (c, d, 1),
            (c, e, 3),
            (d, e, 1),
            (d, f, 4),
            (e, f, 10),
        ]);

        let terminals = vec![a, c, e, f];
        let metric_closure = compute_metric_closure(&graph, &terminals);

        let metric_closure_graph: UnGraph<&str, _, _> = UnGraph::from_edges(
            metric_closure
                .iter()
                .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
        );

        // Keys are node indices: a = 0, c = 2, e = 4, f = 5.
        let ref_weights = HashMap::<_, _>::from([
            ((0, 2), 8),
            ((0, 4), 10),
            ((0, 5), 6),
            ((2, 4), 2),
            ((2, 5), 5),
            ((4, 5), 5),
        ]);
        for ((node1, node2), ref_weight) in ref_weights {
            assert_eq!(metric_closure[&(node1, node2)], ref_weight);
            assert_eq!(
                *metric_closure_graph
                    .edge_weight(
                        metric_closure_graph
                            .find_edge(NodeIndex::new(node1), NodeIndex::new(node2))
                            .unwrap()
                    )
                    .unwrap(),
                ref_weight
            );
        }
    }

    /// Expanding the minimum spanning tree of the metric closure yields exactly the
    /// reference shortest-path edges of the original graph.
    #[test]
    fn test_subgraph_from_metric_closure() {
        let mut graph = Graph::<(), i32, _>::new_undirected();

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        graph.extend_with_edges([
            (a, b, 7),
            (a, f, 6),
            (b, c, 1),
            (b, f, 5),
            (c, d, 1),
            (c, e, 3),
            (d, e, 1),
            (d, f, 4),
            (e, f, 10),
        ]);

        let terminals = vec![a, c, e, f];
        let metric_closure = compute_metric_closure(&graph, &terminals);

        // The default index type is `u32`, so the closure's `usize` keys are narrowed.
        let metric_closure_graph: UnGraph<(), _, _> = UnGraph::from_edges(
            metric_closure
                .iter()
                .map(|((node1, node2), &weight)| (*node1 as u32, *node2 as u32, weight)),
        );

        let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));

        let (subgraph_edges, _subgraph_nodes) =
            subgraph_edges_from_metric_closure(&graph, &minimum_spanning);

        // Paths are recorded as (predecessor, node) pairs, which need not match the
        // orientation in which an undirected edge is stored, so accept both.
        graph.retain_edges(|graph, e| {
            let edge = graph.edge_endpoints(e).unwrap();
            subgraph_edges.contains(&(edge.0, edge.1))
                || subgraph_edges.contains(&(edge.1, edge.0))
        });

        // Reference tree over the same node indices (b is unused).
        let mut ref_graph = UnGraph::<(), _>::new_undirected();
        let ref_a = ref_graph.add_node(());
        let _ = ref_graph.add_node(());
        let ref_c = ref_graph.add_node(());
        let ref_d = ref_graph.add_node(());
        let ref_e = ref_graph.add_node(());
        let ref_f = ref_graph.add_node(());

        ref_graph.extend_with_edges([
            (ref_c, ref_d, 1),
            (ref_d, ref_e, 1),
            (ref_d, ref_f, 4),
            (ref_a, ref_f, 6),
        ]);

        for ref_edge in ref_graph.edge_references() {
            let (edge_index, _) = graph
                .find_edge_undirected(ref_edge.source(), ref_edge.target())
                .unwrap();
            let edge_endpoints = graph.edge_endpoints(edge_index).unwrap();
            assert_eq!(graph.edge_weight(edge_index).unwrap(), ref_edge.weight());
            assert_eq!(edge_endpoints.0, ref_edge.source());
            assert_eq!(edge_endpoints.1, ref_edge.target());
        }
    }

    /// On the path a-b-c-d-e-f with terminals {a, c}, the dangling tail d-e-f is
    /// pruned while b (between two terminals) is kept.
    #[test]
    fn test_remove_non_terminal_nodes() {
        let mut graph = Graph::<(), i32, _>::new_undirected();

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        graph.extend_with_edges([(a, b, 7), (b, c, 6), (c, d, 1), (d, e, 5), (e, f, 1)]);

        let terminals = vec![a, c];
        let non_terminal_nodes = non_terminal_leaves(&graph, &terminals);
        let non_terminal_refs = HashSet::from([d, e, f]);
        assert_eq!(non_terminal_refs, non_terminal_nodes);
    }
}