petgraph/algo/
min_spanning_tree.rs

1//! Minimum Spanning Tree algorithms.
2
3use alloc::collections::BinaryHeap;
4
5use hashbrown::{HashMap, HashSet};
6
7use crate::data::Element;
8use crate::prelude::*;
9use crate::scored::MinScored;
10use crate::unionfind::UnionFind;
11use crate::visit::{Data, IntoEdges, IntoNodeReferences, NodeRef};
12use crate::visit::{IntoEdgeReferences, NodeIndexable};
13
14/// \[Generic\] Compute a *minimum spanning tree* of a graph.
15///
16/// The input graph is treated as if undirected.
17///
18/// Using Kruskal's algorithm with runtime **O(|E| log |E|)**. We actually
19/// return a minimum spanning forest, i.e. a minimum spanning tree for each connected
20/// component of the graph.
21///
22/// The resulting graph has all the vertices of the input graph (with identical node indices),
23/// and **|V| - c** edges, where **c** is the number of connected components in `g`.
24///
25/// Use `from_elements` to create a graph from the resulting iterator.
26///
27/// See also: [`.min_spanning_tree_prim(g)`][1] for implementation using Prim's algorithm.
28///
29/// [1]: fn.min_spanning_tree_prim.html
30pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
31where
32    G::NodeWeight: Clone,
33    G::EdgeWeight: Clone + PartialOrd,
34    G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
35{
36    // Initially each vertex is its own disjoint subgraph, track the connectedness
37    // of the pre-MST with a union & find datastructure.
38    let subgraphs = UnionFind::new(g.node_bound());
39
40    let edges = g.edge_references();
41    let mut sort_edges = BinaryHeap::with_capacity(edges.size_hint().0);
42    for edge in edges {
43        sort_edges.push(MinScored(
44            edge.weight().clone(),
45            (edge.source(), edge.target()),
46        ));
47    }
48
49    MinSpanningTree {
50        graph: g,
51        node_ids: Some(g.node_references()),
52        subgraphs,
53        sort_edges,
54        node_map: HashMap::new(),
55        node_count: 0,
56    }
57}
58
59/// An iterator producing a minimum spanning forest of a graph.
60/// It will first iterate all Node elements from original graph,
61/// then iterate Edge elements from computed minimum spanning forest.
62#[derive(Debug, Clone)]
63pub struct MinSpanningTree<G>
64where
65    G: Data + IntoNodeReferences,
66{
67    graph: G,
68    node_ids: Option<G::NodeReferences>,
69    subgraphs: UnionFind<usize>,
70    #[allow(clippy::type_complexity)]
71    sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
72    node_map: HashMap<usize, usize>,
73    node_count: usize,
74}
75
76impl<G> Iterator for MinSpanningTree<G>
77where
78    G: IntoNodeReferences + NodeIndexable,
79    G::NodeWeight: Clone,
80    G::EdgeWeight: PartialOrd,
81{
82    type Item = Element<G::NodeWeight, G::EdgeWeight>;
83
84    fn next(&mut self) -> Option<Self::Item> {
85        let g = self.graph;
86        if let Some(ref mut iter) = self.node_ids {
87            if let Some(node) = iter.next() {
88                self.node_map.insert(g.to_index(node.id()), self.node_count);
89                self.node_count += 1;
90                return Some(Element::Node {
91                    weight: node.weight().clone(),
92                });
93            }
94        }
95        self.node_ids = None;
96
97        // Kruskal's algorithm.
98        // Algorithm is this:
99        //
100        // 1. Create a pre-MST with all the vertices and no edges.
101        // 2. Repeat:
102        //
103        //  a. Remove the shortest edge from the original graph.
104        //  b. If the edge connects two disjoint trees in the pre-MST,
105        //     add the edge.
106        while let Some(MinScored(score, (a, b))) = self.sort_edges.pop() {
107            // check if the edge would connect two disjoint parts
108            let (a_index, b_index) = (g.to_index(a), g.to_index(b));
109            if self.subgraphs.union(a_index, b_index) {
110                let (&a_order, &b_order) =
111                    match (self.node_map.get(&a_index), self.node_map.get(&b_index)) {
112                        (Some(a_id), Some(b_id)) => (a_id, b_id),
113                        _ => panic!("Edge references unknown node"),
114                    };
115                return Some(Element::Edge {
116                    source: a_order,
117                    target: b_order,
118                    weight: score,
119                });
120            }
121        }
122        None
123    }
124}
125
126/// \[Generic\] Compute a *minimum spanning tree* of a graph using Prim's algorithm.
127///
128/// Graph is treated as if undirected. The computed minimum spanning tree can be wrong
129/// if this is not true.
130///
131/// Graph is treated as if connected (has only 1 component). Otherwise, the resulting
132/// graph will only contain edges for an arbitrary minimum spanning tree for a single component.
133///
134/// Using Prim's algorithm with runtime **O(|E| log |V|)**.
135///
136/// The resulting graph has all the vertices of the input graph (with identical node indices),
137/// and **|V| - 1** edges if input graph is connected, and |W| edges if disconnected, where |W| < |V| - 1.
138///
139/// Use `from_elements` to create a graph from the resulting iterator.
140///
141/// See also: [`.min_spanning_tree(g)`][1] for implementation using Kruskal's algorithm and support for minimum spanning forest.
142///
143/// [1]: fn.min_spanning_tree.html
144pub fn min_spanning_tree_prim<G>(g: G) -> MinSpanningTreePrim<G>
145where
146    G::EdgeWeight: PartialOrd,
147    G: IntoNodeReferences + IntoEdgeReferences,
148{
149    let sort_edges = BinaryHeap::with_capacity(g.edge_references().size_hint().0);
150    let nodes_taken = HashSet::with_capacity(g.node_references().size_hint().0);
151    let initial_node = g.node_references().next();
152
153    MinSpanningTreePrim {
154        graph: g,
155        node_ids: Some(g.node_references()),
156        node_map: HashMap::new(),
157        node_count: 0,
158        sort_edges,
159        nodes_taken,
160        initial_node,
161    }
162}
163
164/// An iterator producing a minimum spanning tree of a graph.
165/// It will first iterate all Node elements from original graph,
166/// then iterate Edge elements from computed minimum spanning tree.
167#[derive(Debug, Clone)]
168pub struct MinSpanningTreePrim<G>
169where
170    G: IntoNodeReferences,
171{
172    graph: G,
173    node_ids: Option<G::NodeReferences>,
174    node_map: HashMap<usize, usize>,
175    node_count: usize,
176    #[allow(clippy::type_complexity)]
177    sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
178    nodes_taken: HashSet<usize>,
179    initial_node: Option<G::NodeRef>,
180}
181
182impl<G> Iterator for MinSpanningTreePrim<G>
183where
184    G: IntoNodeReferences + IntoEdges + NodeIndexable,
185    G::NodeWeight: Clone,
186    G::EdgeWeight: Clone + PartialOrd,
187{
188    type Item = Element<G::NodeWeight, G::EdgeWeight>;
189
190    fn next(&mut self) -> Option<Self::Item> {
191        // Iterate through Node elements
192        let g = self.graph;
193        if let Some(ref mut iter) = self.node_ids {
194            if let Some(node) = iter.next() {
195                self.node_map.insert(g.to_index(node.id()), self.node_count);
196                self.node_count += 1;
197                return Some(Element::Node {
198                    weight: node.weight().clone(),
199                });
200            }
201        }
202        self.node_ids = None;
203
204        // Bootstrap Prim's algorithm to find MST Edge elements.
205        // Mark initial node as taken and add its edges to priority queue.
206        if let Some(initial_node) = self.initial_node {
207            let initial_node_index = g.to_index(initial_node.id());
208            self.nodes_taken.insert(initial_node_index);
209
210            let initial_edges = g.edges(initial_node.id());
211            for edge in initial_edges {
212                self.sort_edges.push(MinScored(
213                    edge.weight().clone(),
214                    (edge.source(), edge.target()),
215                ));
216            }
217        };
218        self.initial_node = None;
219
220        // Clear edges queue if all nodes were already included in MST.
221        if self.nodes_taken.len() == self.node_count {
222            self.sort_edges.clear();
223        };
224
225        // Prim's algorithm:
226        // Iterate through Edge elements, adding an edge to the MST iff some of it's nodes are not part of MST yet.
227        while let Some(MinScored(score, (source, target))) = self.sort_edges.pop() {
228            let (source_index, target_index) = (g.to_index(source), g.to_index(target));
229
230            if self.nodes_taken.contains(&target_index) {
231                continue;
232            }
233
234            self.nodes_taken.insert(target_index);
235            for edge in g.edges(target) {
236                self.sort_edges.push(MinScored(
237                    edge.weight().clone(),
238                    (edge.source(), edge.target()),
239                ));
240            }
241
242            let (&source_order, &target_order) = match (
243                self.node_map.get(&source_index),
244                self.node_map.get(&target_index),
245            ) {
246                (Some(source_order), Some(target_order)) => (source_order, target_order),
247                _ => panic!("Edge references unknown node"),
248            };
249
250            return Some(Element::Edge {
251                source: source_order,
252                target: target_order,
253                weight: score,
254            });
255        }
256
257        None
258    }
259}