1use alloc::vec::Vec;
2use core::{fmt::Debug, hash::Hash};
3
4use hashbrown::{HashMap, HashSet};
5
6use crate::algo::floyd_warshall::floyd_warshall_path;
7use crate::algo::{dijkstra, min_spanning_tree, BoundedMeasure, Measure};
8use crate::data::FromElements;
9use crate::graph::{IndexType, NodeIndex, UnGraph};
10use crate::visit::{
11 Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoNeighbors,
12 IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable, NodeIndexable, Visitable,
13};
14use crate::Undirected;
15
16#[cfg(feature = "stable_graph")]
17use crate::stable_graph::StableGraph;
18
19type Edge<G> = (<G as GraphBase>::NodeId, <G as GraphBase>::NodeId);
20type Subgraph<G> = HashSet<<G as GraphBase>::NodeId>;
21
22fn compute_shortest_path_length<G>(graph: G, source: G::NodeId, target: G::NodeId) -> G::EdgeWeight
23where
24 G: Visitable + IntoEdges,
25 G::NodeId: Eq + Hash,
26 G::EdgeWeight: Measure + Copy,
27{
28 let output = dijkstra(graph, source, Some(target), |e| *e.weight());
29 output[&target]
30}
31
32fn compute_metric_closure<G>(
33 graph: G,
34 terminals: &[G::NodeId],
35) -> HashMap<(usize, usize), G::EdgeWeight>
36where
37 G: Data + IntoNodeReferences + NodeIndexable + Visitable + IntoEdges,
38 G::EdgeWeight: Copy + Measure,
39 G::NodeId: PartialOrd + Eq + Hash,
40{
41 let mut closure = HashMap::new();
42 for (i, node_id_1) in terminals.iter().enumerate() {
43 for node_id_2 in terminals.iter().skip(i + 1) {
44 closure.insert(
45 (graph.to_index(*node_id_1), graph.to_index(*node_id_2)),
46 compute_shortest_path_length(graph, *node_id_1, *node_id_2),
47 );
48 }
49 }
50 closure
51}
52
53fn subgraph_edges_from_metric_closure<G>(
54 graph: G,
55 minimum_spanning_closure: G,
56) -> (Vec<Edge<G>>, Subgraph<G>)
57where
58 G: GraphBase
59 + NodeCompactIndexable
60 + IntoEdgeReferences
61 + IntoNodeIdentifiers
62 + GraphProp
63 + IntoNodeReferences,
64 G::EdgeWeight: BoundedMeasure + Copy,
65 G::NodeId: Eq + Hash + Ord + Debug,
66{
67 let mut retained_nodes = HashSet::new();
68 let mut retained_edges = Vec::new();
69 let (_, prev) = floyd_warshall_path(graph, |e| *e.weight()).unwrap();
70
71 for edge in minimum_spanning_closure.edge_references() {
72 let target = graph.to_index(edge.target());
73 let source = graph.to_index(edge.source());
74
75 let mut current = target;
76 while current != source {
77 if let Some(prev_node) = prev[source][current] {
78 retained_nodes.insert(graph.from_index(prev_node));
79 retained_nodes.insert(graph.from_index(current));
80 retained_edges.push((graph.from_index(prev_node), graph.from_index(current)));
81 current = prev_node;
82 }
83 }
84 }
85
86 (retained_edges, retained_nodes)
87}
88
89fn non_terminal_leaves<G>(graph: G, terminals: &[G::NodeId]) -> HashSet<G::NodeId>
90where
91 G: GraphBase + IntoNodeReferences + IntoNodeIdentifiers + IntoNeighbors,
92 G::NodeId: Hash + Eq + Debug,
93 G::NodeRef: Eq + Hash,
94{
95 let mut removed_leaves = HashSet::new();
96
97 let mut remaining_leaves = graph
98 .node_identifiers()
99 .filter(|node_id| {
100 graph.neighbors(*node_id).collect::<HashSet<_>>().len() == 1
101 && !terminals.contains(node_id)
102 })
103 .collect::<HashSet<_>>();
104
105 while !remaining_leaves.is_empty() {
106 remaining_leaves = graph
107 .node_identifiers()
108 .filter(|node_id| {
109 !terminals.contains(node_id)
110 && !removed_leaves.contains(node_id)
111 && (graph
112 .neighbors(*node_id)
113 .collect::<HashSet<_>>()
114 .difference(&removed_leaves))
115 .collect::<Vec<_>>()
116 .len()
117 == 1
118 })
119 .collect::<HashSet<_>>();
120
121 removed_leaves = removed_leaves
122 .union(&remaining_leaves)
123 .cloned()
124 .collect::<HashSet<_>>();
125 }
126
127 removed_leaves
128}
129
#[cfg(feature = "stable_graph")]
/// Computes an approximate minimum Steiner tree of `graph` connecting all `terminals`.
///
/// The approximation is built in four steps:
///
/// 1. Compute the metric closure over the terminals. This is a complete graph whose
///    edge weights are the shortest-path distances between terminals in `graph`.
/// 2. Take a minimum spanning tree of that closure.
/// 3. Replace every spanning-tree edge by the shortest path in `graph` it stands for.
///    Where parallel edges join two consecutive path nodes, only one cheapest edge
///    is kept.
/// 4. Repeatedly remove non-terminal nodes that have become leaves.
///
/// The result is a [`StableGraph`] built from a clone of `graph`. Retained nodes
/// and edges keep their original indices. A single distinct terminal yields a graph
/// containing just that node.
///
/// Step 3 runs Floyd–Warshall over the whole of `graph`. The running time is
/// therefore cubic in the number of nodes, and memory is quadratic.
///
/// # Panics
///
/// Panics if the terminals do not all lie in the same connected component of
/// `graph`. Distances are computed with Dijkstra and Floyd–Warshall, so edge
/// weights are expected to be non-negative.
pub fn steiner_tree<N, E, Ix>(
    graph: &UnGraph<N, E, Ix>,
    terminals: &[NodeIndex<Ix>],
) -> StableGraph<N, E, Undirected, Ix>
where
    N: Default + Clone + Eq + Hash + Debug,
    E: Copy + Eq + Ord + Measure + BoundedMeasure,
    Ix: IndexType,
{
    // Step 1: metric closure over the terminals.
    let metric_closure = compute_metric_closure(&graph, terminals);
    let metric_closure_graph: UnGraph<N, E, _> = UnGraph::from_edges(
        metric_closure
            .iter()
            .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
    );

    // Step 2: minimum spanning tree of the closure.
    let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));

    // Step 3: expand each spanning-tree edge into its shortest path in `graph`.
    let (subgraph_edges, mut subgraph_nodes) =
        subgraph_edges_from_metric_closure(graph, &minimum_spanning);
    // A lone terminal produces no closure edges but still belongs to the result.
    subgraph_nodes.extend(terminals.iter().copied());

    // Keep exactly one cheapest edge per consecutive node pair on the paths.
    // Keeping every parallel edge would add weight and turn the result into a
    // multigraph. `min_by_key` returns the first minimum, so ties resolve by
    // edge order.
    let kept_edges: HashSet<_> = subgraph_edges
        .iter()
        .filter_map(|&(a, b)| {
            graph
                .edges_connecting(a, b)
                .min_by_key(|edge| *edge.weight())
                .map(|edge| edge.id())
        })
        .collect();

    // `StableGraph::from` preserves the edge indices collected above.
    let mut tree = StableGraph::from(graph.clone());
    tree.retain_edges(|_, edge| kept_edges.contains(&edge));
    tree.retain_nodes(|_, node| subgraph_nodes.contains(&node));

    // Step 4: drop non-terminal branches that do not lead to any terminal.
    let pruned = non_terminal_leaves(&tree, terminals);
    tree.retain_nodes(|_, node| !pruned.contains(&node));

    tree
}
212
#[cfg(test)]
mod test {
    use alloc::vec;

    use hashbrown::{HashMap, HashSet};

    use super::{compute_metric_closure, non_terminal_leaves, subgraph_edges_from_metric_closure};
    use crate::graph::NodeIndex;
    use crate::{
        algo::{min_spanning_tree, EdgeRef, UnGraph},
        data::FromElements,
        Graph, Undirected,
    };

    /// Checks that the metric closure over terminals `a, c, e, f` holds the expected
    /// shortest-path distance for every terminal pair. The check is made both in the
    /// returned map and in a graph rebuilt from that map.
    #[test]
    fn test_compute_metric_closure() {
        let mut graph = Graph::<(), i32, Undirected>::new_undirected();

        // Nodes receive indices 0..=5 in insertion order (a = 0, ..., f = 5).
        // The index-based keys in `ref_weights` below rely on this order.
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        graph.extend_with_edges([
            (a, b, 7),
            (a, f, 6),
            (b, c, 1),
            (b, f, 5),
            (c, d, 1),
            (c, e, 3),
            (d, e, 1),
            (d, f, 4),
            (e, f, 10),
        ]);

        let terminals = vec![a, c, e, f];
        let metric_closure = compute_metric_closure(&graph, &terminals);

        // Rebuild the closure as a graph so the same weights can be checked via edge lookups.
        let metric_closure_graph: UnGraph<&str, _, _> = UnGraph::from_edges(
            metric_closure
                .iter()
                .map(|((node1, node2), &weight)| (*node1, *node2, weight)),
        );

        // Expected distances keyed by (earlier terminal index, later terminal index).
        let ref_weights = HashMap::<_, _>::from([
            ((0, 2), 8),
            ((0, 4), 10),
            ((0, 5), 6),
            ((2, 4), 2),
            ((2, 5), 5),
            ((4, 5), 5),
        ]);
        for ((node1, node2), ref_weight) in ref_weights {
            assert_eq!(metric_closure[&(node1, node2)], ref_weight);
            assert_eq!(
                *metric_closure_graph
                    .edge_weight(
                        metric_closure_graph
                            .find_edge(NodeIndex::new(node1), NodeIndex::new(node2))
                            .unwrap()
                    )
                    .unwrap(),
                ref_weight
            );
        }
    }

    /// Checks that expanding the minimum spanning tree of the metric closure into
    /// shortest paths yields the edges c–d, d–e, d–f and a–f, with their original
    /// weights and stored orientation.
    #[test]
    fn test_subgraph_from_metric_closure() {
        let mut graph = Graph::<(), i32, _>::new_undirected();

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        graph.extend_with_edges([
            (a, b, 7),
            (a, f, 6),
            (b, c, 1),
            (b, f, 5),
            (c, d, 1),
            (c, e, 3),
            (d, e, 1),
            (d, f, 4),
            (e, f, 10),
        ]);

        let terminals = vec![a, c, e, f];
        let metric_closure = compute_metric_closure(&graph, &terminals);

        let metric_closure_graph: UnGraph<(), _, _> = UnGraph::from_edges(
            metric_closure
                .iter()
                .map(|((node1, node2), &weight)| (*node1 as u32, *node2 as u32, weight)),
        );

        let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));

        let (subgraph_edges, _subgraph_nodes) =
            subgraph_edges_from_metric_closure(&graph, &minimum_spanning);

        // NOTE(review): only the stored (source, target) orientation is matched here.
        // This relies on the helper emitting each pair in that same orientation.
        graph.retain_edges(|graph, e| {
            let edge = graph.edge_endpoints(e).unwrap();
            subgraph_edges.contains(&(edge.0, edge.1))
        });

        // Reference graph with the same index layout. The unnamed node stands in for `b`
        // so that the remaining indices line up with `graph`.
        let mut ref_graph = UnGraph::<(), _>::new_undirected();
        let ref_a = ref_graph.add_node(());
        let _ = ref_graph.add_node(());
        let ref_c = ref_graph.add_node(());
        let ref_d = ref_graph.add_node(());
        let ref_e = ref_graph.add_node(());
        let ref_f = ref_graph.add_node(());

        ref_graph.extend_with_edges([
            (ref_c, ref_d, 1),
            (ref_d, ref_e, 1),
            (ref_d, ref_f, 4),
            (ref_a, ref_f, 6),
        ]);

        // Every reference edge must survive in `graph` with the same weight and orientation.
        for ref_edge in ref_graph.edge_references() {
            let (edge_index, _) = graph
                .find_edge_undirected(ref_edge.source(), ref_edge.target())
                .unwrap();
            let edge_endpoints = graph.edge_endpoints(edge_index).unwrap();
            assert_eq!(graph.edge_weight(edge_index).unwrap(), ref_edge.weight());
            assert_eq!(edge_endpoints.0, ref_edge.source());
            assert_eq!(edge_endpoints.1, ref_edge.target());
        }
    }

    /// On the path a–b–c–d–e–f with terminals `a` and `c`, the dangling chain d, e, f
    /// is pruned. `b` lies between the two terminals and is kept.
    #[test]
    fn test_remove_non_terminal_nodes() {
        let mut graph = Graph::<(), i32, _>::new_undirected();

        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());
        let e = graph.add_node(());
        let f = graph.add_node(());
        graph.extend_with_edges([(a, b, 7), (b, c, 6), (c, d, 1), (d, e, 5), (e, f, 1)]);

        let terminals = vec![a, c];
        let non_terminal_nodes = non_terminal_leaves(&graph, &terminals);
        let non_terminal_refs = HashSet::from([d, e, f]);
        assert_eq!(non_terminal_refs, non_terminal_nodes);
    }
}