1use alloc::collections::BinaryHeap;
4
5use hashbrown::{HashMap, HashSet};
6
7use crate::data::Element;
8use crate::prelude::*;
9use crate::scored::MinScored;
10use crate::unionfind::UnionFind;
11use crate::visit::{Data, IntoEdges, IntoNodeReferences, NodeRef};
12use crate::visit::{IntoEdgeReferences, NodeIndexable};
13
14pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
31where
32 G::NodeWeight: Clone,
33 G::EdgeWeight: Clone + PartialOrd,
34 G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
35{
36 let subgraphs = UnionFind::new(g.node_bound());
39
40 let edges = g.edge_references();
41 let mut sort_edges = BinaryHeap::with_capacity(edges.size_hint().0);
42 for edge in edges {
43 sort_edges.push(MinScored(
44 edge.weight().clone(),
45 (edge.source(), edge.target()),
46 ));
47 }
48
49 MinSpanningTree {
50 graph: g,
51 node_ids: Some(g.node_references()),
52 subgraphs,
53 sort_edges,
54 node_map: HashMap::new(),
55 node_count: 0,
56 }
57}
58
59#[derive(Debug, Clone)]
63pub struct MinSpanningTree<G>
64where
65 G: Data + IntoNodeReferences,
66{
67 graph: G,
68 node_ids: Option<G::NodeReferences>,
69 subgraphs: UnionFind<usize>,
70 #[allow(clippy::type_complexity)]
71 sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
72 node_map: HashMap<usize, usize>,
73 node_count: usize,
74}
75
76impl<G> Iterator for MinSpanningTree<G>
77where
78 G: IntoNodeReferences + NodeIndexable,
79 G::NodeWeight: Clone,
80 G::EdgeWeight: PartialOrd,
81{
82 type Item = Element<G::NodeWeight, G::EdgeWeight>;
83
84 fn next(&mut self) -> Option<Self::Item> {
85 let g = self.graph;
86 if let Some(ref mut iter) = self.node_ids {
87 if let Some(node) = iter.next() {
88 self.node_map.insert(g.to_index(node.id()), self.node_count);
89 self.node_count += 1;
90 return Some(Element::Node {
91 weight: node.weight().clone(),
92 });
93 }
94 }
95 self.node_ids = None;
96
97 while let Some(MinScored(score, (a, b))) = self.sort_edges.pop() {
107 let (a_index, b_index) = (g.to_index(a), g.to_index(b));
109 if self.subgraphs.union(a_index, b_index) {
110 let (&a_order, &b_order) =
111 match (self.node_map.get(&a_index), self.node_map.get(&b_index)) {
112 (Some(a_id), Some(b_id)) => (a_id, b_id),
113 _ => panic!("Edge references unknown node"),
114 };
115 return Some(Element::Edge {
116 source: a_order,
117 target: b_order,
118 weight: score,
119 });
120 }
121 }
122 None
123 }
124}
125
126pub fn min_spanning_tree_prim<G>(g: G) -> MinSpanningTreePrim<G>
145where
146 G::EdgeWeight: PartialOrd,
147 G: IntoNodeReferences + IntoEdgeReferences,
148{
149 let sort_edges = BinaryHeap::with_capacity(g.edge_references().size_hint().0);
150 let nodes_taken = HashSet::with_capacity(g.node_references().size_hint().0);
151 let initial_node = g.node_references().next();
152
153 MinSpanningTreePrim {
154 graph: g,
155 node_ids: Some(g.node_references()),
156 node_map: HashMap::new(),
157 node_count: 0,
158 sort_edges,
159 nodes_taken,
160 initial_node,
161 }
162}
163
164#[derive(Debug, Clone)]
168pub struct MinSpanningTreePrim<G>
169where
170 G: IntoNodeReferences,
171{
172 graph: G,
173 node_ids: Option<G::NodeReferences>,
174 node_map: HashMap<usize, usize>,
175 node_count: usize,
176 #[allow(clippy::type_complexity)]
177 sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
178 nodes_taken: HashSet<usize>,
179 initial_node: Option<G::NodeRef>,
180}
181
182impl<G> Iterator for MinSpanningTreePrim<G>
183where
184 G: IntoNodeReferences + IntoEdges + NodeIndexable,
185 G::NodeWeight: Clone,
186 G::EdgeWeight: Clone + PartialOrd,
187{
188 type Item = Element<G::NodeWeight, G::EdgeWeight>;
189
190 fn next(&mut self) -> Option<Self::Item> {
191 let g = self.graph;
193 if let Some(ref mut iter) = self.node_ids {
194 if let Some(node) = iter.next() {
195 self.node_map.insert(g.to_index(node.id()), self.node_count);
196 self.node_count += 1;
197 return Some(Element::Node {
198 weight: node.weight().clone(),
199 });
200 }
201 }
202 self.node_ids = None;
203
204 if let Some(initial_node) = self.initial_node {
207 let initial_node_index = g.to_index(initial_node.id());
208 self.nodes_taken.insert(initial_node_index);
209
210 let initial_edges = g.edges(initial_node.id());
211 for edge in initial_edges {
212 self.sort_edges.push(MinScored(
213 edge.weight().clone(),
214 (edge.source(), edge.target()),
215 ));
216 }
217 };
218 self.initial_node = None;
219
220 if self.nodes_taken.len() == self.node_count {
222 self.sort_edges.clear();
223 };
224
225 while let Some(MinScored(score, (source, target))) = self.sort_edges.pop() {
228 let (source_index, target_index) = (g.to_index(source), g.to_index(target));
229
230 if self.nodes_taken.contains(&target_index) {
231 continue;
232 }
233
234 self.nodes_taken.insert(target_index);
235 for edge in g.edges(target) {
236 self.sort_edges.push(MinScored(
237 edge.weight().clone(),
238 (edge.source(), edge.target()),
239 ));
240 }
241
242 let (&source_order, &target_order) = match (
243 self.node_map.get(&source_index),
244 self.node_map.get(&target_index),
245 ) {
246 (Some(source_order), Some(target_order)) => (source_order, target_order),
247 _ => panic!("Edge references unknown node"),
248 };
249
250 return Some(Element::Edge {
251 source: source_order,
252 target: target_order,
253 weight: score,
254 });
255 }
256
257 None
258 }
259}