petgraph/algo/maximum_flow/
dinics.rs

1use alloc::{collections::VecDeque, vec, vec::Vec};
2use core::ops::Sub;
3
4use crate::{
5    algo::{EdgeRef, PositiveMeasure},
6    prelude::Direction,
7    visit::{
8        Data, EdgeCount, EdgeIndexable, IntoEdgeReferences, IntoEdges, IntoEdgesDirected,
9        NodeCount, NodeIndexable, VisitMap, Visitable,
10    },
11};
12
13/// Compute the maximum flow from `source` to `destination` in a directed graph.
14/// Implements [Dinic's (or Dinitz's) algorithm][dinics], which builds successive
15/// level graphs using breadth-first search and finds blocking flows within
16/// them through depth-first searches.
17///
18/// For simplicity, the algorithm requires `N::EdgeWeight` to implement
19/// only [PartialOrd] trait, and not [Ord], but will panic if it tries to
20/// compare two elements that aren't comparable (i.e., given two edge weights `a`
21/// and `b`, where neither `a >= b` nor `a < b`).
22///
23/// See also [`maximum_flow`][max flow mod] module for other maximum flow algorithms.
24///
25/// # Arguments
26/// * `network` — A directed graph with positive edge weights, namely "flow capacities".
27/// * `source` — The source node where flow originates.
28/// * `destination` — The destination node where flow terminates.
29///
30/// # Returns
31/// Returns a tuple of two values:
32/// * `N::EdgeWeight`: computed maximum flow;
33/// * `Vec<N::EdgeWeight>`: the flow of each edge. The vector is indexed by the graph's edge indices.
34///
35/// # Complexity
36/// * Time complexity:
37///   * In general: **O(|V|²|E|)**
38///   * In networks with only unit capacities: **O(min{|V|²ᐟ³, |E|¹ᐟ²} |E|)**
39/// * Auxiliary space: **O(|V| + |E|)**.
40///
41/// where **|V|** is the number of nodes and **|E|** is the number of edges.
42///
43/// [dinics]: https://en.wikipedia.org/wiki/Dinic%27s_algorithm
44/// [max flow mod]: index.html
45///
46/// # Example
47/// ```rust
48/// use petgraph::Graph;
49/// use petgraph::algo::dinics;
50/// // Example from CLRS book
51/// let mut graph = Graph::<u8, u8>::new();
52/// let source = graph.add_node(0);
53/// let _ = graph.add_node(1);
54/// let _ = graph.add_node(2);
55/// let _ = graph.add_node(3);
56/// let _ = graph.add_node(4);
57/// let destination = graph.add_node(5);
58/// graph.extend_with_edges(&[
59///    (0, 1, 16),
60///    (0, 2, 13),
61///    (1, 2, 10),
62///    (1, 3, 12),
63///    (2, 1, 4),
64///    (2, 4, 14),
65///    (3, 2, 9),
66///    (3, 5, 20),
67///    (4, 3, 7),
68///    (4, 5, 4),
69/// ]);
70/// let (max_flow, _) = dinics(&graph, source, destination);
71/// assert_eq!(23, max_flow);
72/// ```
73pub fn dinics<G>(
74    network: G,
75    source: G::NodeId,
76    destination: G::NodeId,
77) -> (G::EdgeWeight, Vec<G::EdgeWeight>)
78where
79    G: NodeCount + EdgeCount + IntoEdgesDirected + EdgeIndexable + NodeIndexable + Visitable,
80    G::EdgeWeight: Sub<Output = G::EdgeWeight> + PositiveMeasure,
81{
82    let mut max_flow = G::EdgeWeight::zero();
83    let mut flows = vec![G::EdgeWeight::zero(); network.edge_count()];
84    let mut visited = network.visit_map();
85    let mut level_edges = vec![Default::default(); network.node_bound()];
86
87    let dest_index = NodeIndexable::to_index(&network, destination);
88    while build_level_graph(&network, source, destination, &flows, &mut level_edges)[dest_index] > 0
89    {
90        let flow_increase = find_blocking_flow(
91            network,
92            source,
93            destination,
94            &mut flows,
95            &mut level_edges,
96            &mut visited,
97        );
98        max_flow = max_flow + flow_increase;
99    }
100    (max_flow, flows)
101}
102
103/// Makes a BFS that labels network vertices with levels representing
104/// their distance to the source vertex, considering only edges with
105/// positive residual capacity.
106///
107/// The source vertex is labeled as 1, and vertices not reachable are
108/// labeled as 0.
109///
110/// Aggregates in `level_edges` the edges that connects each
111/// vertex to its neighbours in the next level.
112///
113/// Returns the computed level graph.
114fn build_level_graph<G>(
115    network: G,
116    source: G::NodeId,
117    destination: G::NodeId,
118    flows: &[G::EdgeWeight],
119    level_edges: &mut [Vec<G::EdgeRef>],
120) -> Vec<usize>
121where
122    G: NodeCount + IntoEdgesDirected + NodeIndexable + EdgeIndexable,
123    G::EdgeWeight: Sub<Output = G::EdgeWeight> + PositiveMeasure,
124{
125    let mut level_graph = vec![0; network.node_bound()];
126    let mut bfs_queue = VecDeque::with_capacity(network.node_count());
127    bfs_queue.push_back(source);
128
129    level_graph[NodeIndexable::to_index(&network, source)] = 1;
130    while let Some(vertex) = bfs_queue.pop_front() {
131        let vertex_index = NodeIndexable::to_index(&network, vertex);
132        let out_edges = network.edges_directed(vertex, Direction::Outgoing);
133        let in_edges = network.edges_directed(vertex, Direction::Incoming);
134        level_edges[vertex_index].clear();
135        for edge in out_edges.chain(in_edges) {
136            let next_vertex = other_endpoint(&network, edge, vertex);
137            let edge_index = EdgeIndexable::to_index(&network, edge.id());
138            let residual_cap = residual_capacity(&network, edge, next_vertex, flows[edge_index]);
139            if residual_cap == G::EdgeWeight::zero() {
140                continue;
141            }
142            let next_vertex_index = NodeIndexable::to_index(&network, next_vertex);
143            if level_graph[next_vertex_index] == 0 {
144                level_graph[next_vertex_index] = level_graph[vertex_index] + 1;
145                level_edges[vertex_index].push(edge);
146                if next_vertex != destination {
147                    bfs_queue.push_back(next_vertex);
148                }
149            } else if level_graph[next_vertex_index] == level_graph[vertex_index] + 1 {
150                level_edges[vertex_index].push(edge);
151            }
152        }
153    }
154
155    level_graph
156}
157
158/// Find blocking flow for current level graph by repeatingly finding
159/// augmenting paths in it.
160///
161/// Attach computed flows to `flows` and returns the total flow increase from
162/// edges available in `level_edges` at this iteration.
163fn find_blocking_flow<G>(
164    network: G,
165    source: G::NodeId,
166    destination: G::NodeId,
167    flows: &mut [G::EdgeWeight],
168    level_edges: &mut [Vec<G::EdgeRef>],
169    visited: &mut G::Map,
170) -> G::EdgeWeight
171where
172    G: NodeCount + IntoEdges + NodeIndexable + EdgeIndexable + Visitable,
173    G::EdgeWeight: Sub<Output = G::EdgeWeight> + PositiveMeasure,
174{
175    let mut flow_increase = G::EdgeWeight::zero();
176    let mut edge_to = vec![None; network.node_bound()];
177    while find_augmenting_path(
178        &network,
179        source,
180        destination,
181        flows,
182        level_edges,
183        visited,
184        &mut edge_to,
185    ) {
186        let mut path_flow = G::EdgeWeight::max();
187
188        // Find the bottleneck capacity of the path
189        let mut vertex = destination;
190        while let Some(edge) = edge_to[NodeIndexable::to_index(&network, vertex)] {
191            let edge_index = EdgeIndexable::to_index(&network, edge.id());
192            let residual_capacity = residual_capacity(&network, edge, vertex, flows[edge_index]);
193            path_flow = min::<G>(path_flow, residual_capacity);
194            vertex = other_endpoint(&network, edge, vertex);
195        }
196
197        // Update the flow of each edge along the discovered path
198        let mut vertex = destination;
199        while let Some(edge) = edge_to[NodeIndexable::to_index(&network, vertex)] {
200            let edge_index = EdgeIndexable::to_index(&network, edge.id());
201            flows[edge_index] =
202                adjusted_residual_flow(&network, edge, vertex, flows[edge_index], path_flow);
203            vertex = other_endpoint(&network, edge, vertex);
204        }
205        flow_increase = flow_increase + path_flow;
206    }
207    flow_increase
208}
209
210/// Makes a DFS to find an augmenting path from source to destination vertex
211/// using previously computed `edge_levels` from level graph.
212///
213/// Returns a boolean indicating if an augmenting path to destination was found.
214fn find_augmenting_path<G>(
215    network: G,
216    source: G::NodeId,
217    destination: G::NodeId,
218    flows: &[G::EdgeWeight],
219    level_edges: &mut [Vec<G::EdgeRef>],
220    visited: &mut G::Map,
221    edge_to: &mut [Option<G::EdgeRef>],
222) -> bool
223where
224    G: IntoEdges + NodeIndexable + EdgeIndexable + Visitable,
225    G::EdgeWeight: Sub<Output = G::EdgeWeight> + PositiveMeasure,
226{
227    network.reset_map(visited);
228    let mut level_edges_i = vec![0; level_edges.len()];
229
230    let mut dfs_stack = Vec::new();
231    dfs_stack.push(source);
232    visited.visit(source);
233    while let Some(&vertex) = dfs_stack.last() {
234        let vertex_index = NodeIndexable::to_index(&network, vertex);
235
236        let mut found_next = false;
237        while level_edges_i[vertex_index] < level_edges[vertex_index].len() {
238            let curr_level_edges_i = level_edges_i[vertex_index];
239            let edge = level_edges[vertex_index][curr_level_edges_i];
240            let next_vertex = other_endpoint(&network, edge, vertex);
241
242            let edge_index: usize = EdgeIndexable::to_index(&network, edge.id());
243            let residual_cap = residual_capacity(&network, edge, next_vertex, flows[edge_index]);
244            if residual_cap == G::EdgeWeight::zero() {
245                level_edges[vertex_index].swap_remove(curr_level_edges_i);
246                continue;
247            }
248
249            if !visited.is_visited(&next_vertex) {
250                let next_vertex_index = NodeIndexable::to_index(&network, next_vertex);
251                edge_to[next_vertex_index] = Some(edge);
252                if destination == next_vertex {
253                    return true;
254                }
255                dfs_stack.push(next_vertex);
256                visited.visit(next_vertex);
257                found_next = true;
258                break;
259            }
260            level_edges_i[vertex_index] += 1;
261        }
262        if !found_next {
263            dfs_stack.pop();
264        }
265    }
266    false
267}
268
269/// Returns the adjusted residual flow for given edge and flow increase.
270fn adjusted_residual_flow<G>(
271    network: G,
272    edge: G::EdgeRef,
273    target_vertex: G::NodeId,
274    flow: G::EdgeWeight,
275    flow_increase: G::EdgeWeight,
276) -> G::EdgeWeight
277where
278    G: NodeIndexable + IntoEdges,
279    G::EdgeWeight: Sub<Output = G::EdgeWeight> + PositiveMeasure,
280{
281    if target_vertex == edge.source() {
282        // backward edge
283        flow - flow_increase
284    } else if target_vertex == edge.target() {
285        // forward edge
286        flow + flow_increase
287    } else {
288        let end_point = NodeIndexable::to_index(&network, target_vertex);
289        panic!("Illegal endpoint {}", end_point);
290    }
291}
292
293/// Returns the residual capacity of given edge.
294fn residual_capacity<G>(
295    network: G,
296    edge: G::EdgeRef,
297    target_vertex: G::NodeId,
298    flow: G::EdgeWeight,
299) -> G::EdgeWeight
300where
301    G: NodeIndexable + IntoEdges,
302    G::EdgeWeight: Sub<Output = G::EdgeWeight> + PositiveMeasure,
303{
304    if target_vertex == edge.source() {
305        // backward edge
306        flow
307    } else if target_vertex == edge.target() {
308        // forward edge
309        *edge.weight() - flow
310    } else {
311        let end_point = NodeIndexable::to_index(&network, target_vertex);
312        panic!("Illegal endpoint {}", end_point);
313    }
314}
315
316/// Returns the minimum value between given `a` and `b`.
317/// Will panic if it tries to compare two elements that aren't comparable
318/// (i.e., given two elements `a` and `b`, neither `a >= b` nor `a < b`).
319fn min<G>(a: G::EdgeWeight, b: G::EdgeWeight) -> G::EdgeWeight
320where
321    G: Data,
322    G::EdgeWeight: PartialOrd,
323{
324    if a < b {
325        a
326    } else if a >= b {
327        b
328    } else {
329        panic!("Invalid edge weights. Impossible to get min value.");
330    }
331}
332
333/// Gets the other endpoint of graph edge, if any, otherwise panics.
334fn other_endpoint<G>(network: G, edge: G::EdgeRef, vertex: G::NodeId) -> G::NodeId
335where
336    G: NodeIndexable + IntoEdgeReferences,
337{
338    if vertex == edge.source() {
339        edge.target()
340    } else if vertex == edge.target() {
341        edge.source()
342    } else {
343        let end_point = NodeIndexable::to_index(&network, vertex);
344        panic!("Illegal endpoint {}", end_point);
345    }
346}