petgraph/visit/
undirected_adaptor.rs

1use crate::visit::{
2    Data, EdgeRef, GraphBase, GraphProp, GraphRef, IntoEdgeReferences, IntoEdges,
3    IntoEdgesDirected, IntoNeighbors, IntoNeighborsDirected, IntoNodeIdentifiers,
4    IntoNodeReferences, NodeCompactIndexable, NodeCount, NodeIndexable, Visitable,
5};
6use crate::Direction;
7
/// A graph adaptor that discards edge direction.
///
/// Wraps a graph `G` (stored in the public field `0`) so that it can be
/// traversed as if its edges were undirected: neighbor and edge queries on
/// the adaptor combine the incoming and outgoing directions of the wrapped
/// graph.
#[derive(Clone, Copy, Debug)]
pub struct UndirectedAdaptor<G>(pub G);
11
12impl<G: GraphRef> GraphRef for UndirectedAdaptor<G> {}
13
14impl<G> IntoNeighbors for UndirectedAdaptor<G>
15where
16    G: IntoNeighborsDirected,
17{
18    type Neighbors = core::iter::Chain<G::NeighborsDirected, G::NeighborsDirected>;
19    fn neighbors(self, n: G::NodeId) -> Self::Neighbors {
20        self.0
21            .neighbors_directed(n, Direction::Incoming)
22            .chain(self.0.neighbors_directed(n, Direction::Outgoing))
23    }
24}
25
26impl<G> IntoEdges for UndirectedAdaptor<G>
27where
28    G: IntoEdgesDirected,
29{
30    type Edges = core::iter::Chain<
31        MaybeReversedEdges<G::EdgesDirected>,
32        MaybeReversedEdges<G::EdgesDirected>,
33    >;
34    fn edges(self, a: Self::NodeId) -> Self::Edges {
35        let incoming = MaybeReversedEdges {
36            iter: self.0.edges_directed(a, Direction::Incoming),
37            reversed: true,
38        };
39        let outgoing = MaybeReversedEdges {
40            iter: self.0.edges_directed(a, Direction::Outgoing),
41            reversed: false,
42        };
43        incoming.chain(outgoing)
44    }
45}
46
47impl<G> GraphProp for UndirectedAdaptor<G>
48where
49    G: GraphBase,
50{
51    type EdgeType = crate::Undirected;
52
53    fn is_directed(&self) -> bool {
54        false
55    }
56}
57
/// Iterator adaptor over edge references that optionally flips their
/// orientation.
///
/// Every item of the wrapped iterator `iter` is yielded as a
/// [`MaybeReversedEdgeReference`] carrying this iterator's `reversed` flag.
#[derive(Clone, Debug)]
pub struct MaybeReversedEdges<I> {
    iter: I,
    reversed: bool,
}
64
65impl<I> Iterator for MaybeReversedEdges<I>
66where
67    I: Iterator,
68    I::Item: EdgeRef,
69{
70    type Item = MaybeReversedEdgeReference<I::Item>;
71    fn next(&mut self) -> Option<Self::Item> {
72        self.iter.next().map(|x| MaybeReversedEdgeReference {
73            inner: x,
74            reversed: self.reversed,
75        })
76    }
77    fn size_hint(&self) -> (usize, Option<usize>) {
78        self.iter.size_hint()
79    }
80}
81
/// An edge reference whose endpoints may be swapped.
///
/// When `reversed` is set, the source and target reported through
/// [`EdgeRef`] are exchanged relative to the wrapped reference `inner`;
/// weight and id are always those of `inner`.
#[derive(Clone, Copy, Debug)]
pub struct MaybeReversedEdgeReference<R> {
    inner: R,
    reversed: bool,
}
88
89impl<R> EdgeRef for MaybeReversedEdgeReference<R>
90where
91    R: EdgeRef,
92{
93    type NodeId = R::NodeId;
94    type EdgeId = R::EdgeId;
95    type Weight = R::Weight;
96    fn source(&self) -> Self::NodeId {
97        if self.reversed {
98            self.inner.target()
99        } else {
100            self.inner.source()
101        }
102    }
103    fn target(&self) -> Self::NodeId {
104        if self.reversed {
105            self.inner.source()
106        } else {
107            self.inner.target()
108        }
109    }
110    fn weight(&self) -> &Self::Weight {
111        self.inner.weight()
112    }
113    fn id(&self) -> Self::EdgeId {
114        self.inner.id()
115    }
116}
117
/// Iterator over edge references that wraps each item in a
/// [`MaybeReversedEdgeReference`].
///
/// Unlike [`MaybeReversedEdges`], this iterator carries no orientation flag:
/// its `Iterator` implementation always yields items with `reversed` set to
/// `false`.
#[derive(Clone, Debug)]
pub struct MaybeReversedEdgeReferences<I> {
    iter: I,
}
123
124impl<I> Iterator for MaybeReversedEdgeReferences<I>
125where
126    I: Iterator,
127    I::Item: EdgeRef,
128{
129    type Item = MaybeReversedEdgeReference<I::Item>;
130    fn next(&mut self) -> Option<Self::Item> {
131        self.iter.next().map(|x| MaybeReversedEdgeReference {
132            inner: x,
133            reversed: false,
134        })
135    }
136    fn size_hint(&self) -> (usize, Option<usize>) {
137        self.iter.size_hint()
138    }
139}
140
141impl<G> IntoEdgeReferences for UndirectedAdaptor<G>
142where
143    G: IntoEdgeReferences,
144{
145    type EdgeRef = MaybeReversedEdgeReference<G::EdgeRef>;
146    type EdgeReferences = MaybeReversedEdgeReferences<G::EdgeReferences>;
147
148    fn edge_references(self) -> Self::EdgeReferences {
149        MaybeReversedEdgeReferences {
150            iter: self.0.edge_references(),
151        }
152    }
153}
154
// Expands to field `0` of the given expression. Passed by name to the
// `delegate_impl` invocations in this file, presumably so they can reach the
// wrapped graph inside an `UndirectedAdaptor` (the delegate macros are
// defined elsewhere — confirm there).
macro_rules! access0 {
    ($wrapper:expr) => {
        $wrapper.0
    };
}
160
161GraphBase! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
162Data! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
163Visitable! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
164NodeIndexable! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
165NodeCompactIndexable! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
166IntoNodeIdentifiers! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
167IntoNodeReferences! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
168NodeCount! {delegate_impl [[G], G, UndirectedAdaptor<G>, access0]}
169
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use std::collections::HashSet;

    use super::*;
    use crate::algo::astar::*;
    use crate::graph::{DiGraph, Graph};
    use crate::visit::Dfs;

    /// Edge list of the directed path `0 -> 1 -> 2 -> 3 -> 4 -> 5`.
    static LINEAR_EDGES: [(u32, u32); 5] = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)];

    /// A DFS started from a middle node of a linear digraph must reach every
    /// node once edge direction is ignored.
    #[test]
    pub fn test_is_reachable() {
        use crate::visit::Walker;

        let digraph = DiGraph::<(), ()>::from_edges(LINEAR_EDGES);

        let mut all_nodes: Vec<_> = digraph.node_identifiers().collect();
        all_nodes.sort();

        let undirected = UndirectedAdaptor(&digraph);

        let mut reached: Vec<_> = Dfs::new(&undirected, all_nodes[2])
            .iter(&undirected)
            .collect();
        reached.sort();
        assert_eq!(reached, all_nodes);
    }

    /// A* must be able to walk the linear digraph against its edge
    /// direction, from the last node back to the first.
    #[test]
    pub fn test_undirected_adaptor_can_traverse() {
        let digraph = DiGraph::<(), ()>::from_edges(LINEAR_EDGES);
        let mut all_nodes: Vec<_> = digraph.node_identifiers().collect();
        all_nodes.sort();
        let undirected = UndirectedAdaptor(&digraph);

        let found = astar(
            &undirected,
            all_nodes[5],
            |n| n == all_nodes[0],
            |_| 1,
            |_| 0,
        );

        let expected_path: Vec<_> = (0..=5).rev().map(|i| all_nodes[i]).collect();

        let (cost, path) = found.unwrap();
        assert_eq!(cost, 5);
        assert_eq!(path, expected_path);
    }

    /// `edges(node)` must report, as targets, the nodes on both sides of
    /// `node` in the linear graph.
    #[test]
    pub fn test_undirected_edge_refs_point_both_ways() {
        let digraph = DiGraph::<(), ()>::from_edges(LINEAR_EDGES);
        let mut all_nodes: Vec<_> = digraph.node_identifiers().collect();
        all_nodes.sort();
        let undirected = UndirectedAdaptor(&digraph);

        let expected_edge_targets = [
            &[all_nodes[1]][..],
            &[all_nodes[0], all_nodes[2]],
            &[all_nodes[1], all_nodes[3]],
            &[all_nodes[2], all_nodes[4]],
            &[all_nodes[3], all_nodes[5]],
            &[all_nodes[4]],
        ];

        for (i, &node) in all_nodes.iter().enumerate() {
            let targets: HashSet<_> = undirected.edges(node).map(|e| e.target()).collect();
            let expected: HashSet<_> = expected_edge_targets[i].iter().copied().collect();
            assert_eq!(targets, expected);
        }
    }

    /// An interior node of the linear graph has exactly two neighbors when
    /// direction is ignored.
    #[test]
    pub fn test_neighbors_count() {
        {
            let base = Graph::<(), ()>::from_edges(LINEAR_EDGES);
            let undirected = UndirectedAdaptor(&base);

            let mut all_nodes: Vec<_> = undirected.node_identifiers().collect();
            all_nodes.sort();
            assert_eq!(undirected.neighbors(all_nodes[1]).count(), 2);
        }

        // NOTE(review): this scope repeats the one above exactly; it was
        // presumably meant to exercise a different graph type — confirm.
        {
            let base = Graph::<(), ()>::from_edges(LINEAR_EDGES);
            let undirected = UndirectedAdaptor(&base);

            let mut all_nodes: Vec<_> = undirected.node_identifiers().collect();
            all_nodes.sort();
            assert_eq!(undirected.neighbors(all_nodes[1]).count(), 2);
        }
    }
}