petgraph/
csr.rs

1//! Compressed Sparse Row (CSR) is a sparse adjacency matrix graph.
2
3use alloc::{vec, vec::Vec};
4use core::{
5    cmp::{max, Ordering},
6    fmt,
7    iter::{Enumerate, Zip},
8    marker::PhantomData,
9    ops::{Index, IndexMut, Range},
10    slice::Windows,
11};
12
13use crate::visit::{
14    Data, EdgeCount, EdgeRef, GetAdjacencyMatrix, GraphBase, GraphProp, IntoEdgeReferences,
15    IntoEdges, IntoNeighbors, IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable,
16    NodeCount, NodeIndexable, Visitable,
17};
18
19use crate::util::zip;
20
21#[doc(no_inline)]
22pub use crate::graph::{DefaultIx, IndexType};
23
24use crate::{Directed, EdgeType, IntoWeightedEdge};
25
/// Csr node index type, a plain integer.
///
/// The nodes of a `Csr` with `n` nodes are exactly the indices `0..n`.
pub type NodeIndex<Ix = DefaultIx> = Ix;
/// Csr edge index type, a plain integer.
///
/// An edge index is the edge's position in the graph's flat edge storage;
/// inserting an edge shifts the indices of all edges stored after it.
pub type EdgeIndex = usize;

/// Rows with fewer entries than this are searched linearly when looking up an
/// edge; longer rows are binary searched.
const BINARY_SEARCH_CUTOFF: usize = 32;
32
/// The error type for fallible operations with `Csr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CsrError {
    /// At least one of the two node indices lies outside the graph.
    ///
    /// Holds the pair `(a, b)` exactly as it was passed to the failing call.
    IndicesOutBounds(usize, usize),
}

#[cfg(feature = "std")]
impl std::error::Error for CsrError {}

#[cfg(not(feature = "std"))]
impl core::error::Error for CsrError {}

impl fmt::Display for CsrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The error is raised when either index (not necessarily both) is
            // out of range, so the message must not claim both are invalid.
            CsrError::IndicesOutBounds(a, b) => {
                write!(
                    f,
                    "At least one of node indices {a} and {b} is out of Csr bounds"
                )
            }
        }
    }
}
55
/// Compressed Sparse Row ([`CSR`]) is a sparse adjacency matrix graph.
///
/// `CSR` is parameterized over:
///
/// - Associated data `N` for nodes and `E` for edges, called *weights*.
///   The associated data can be of arbitrary type.
/// - Edge type `Ty` that determines whether the graph edges are directed or undirected.
/// - Index type `Ix`, which determines the maximum size of the graph.
///
/// Uses **O(|E| + |V|)** space.
///
/// Self loops are allowed; parallel edges are not.
///
/// The outgoing edges of a vertex are stored contiguously, so iterating over
/// them is fast.
///
/// [`CSR`]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
#[derive(Debug)]
pub struct Csr<N = (), E = (), Ty = Directed, Ix = DefaultIx> {
    /// Target node of every stored edge, grouped by source node (row).
    /// Within a row the targets are kept in ascending order.
    column: Vec<NodeIndex<Ix>>,
    /// Weight of every stored edge; kept in lock step with `column`.
    edges: Vec<E>,
    /// Offset into `column`/`edges` at which each node's row starts.
    /// Always `node_count + 1` long; the last element always equals `column.len()`.
    row: Vec<usize>,
    /// Weight of each node, indexed by node index.
    node_weights: Vec<N>,
    /// Number of edges of an undirected graph. Such an edge between two
    /// distinct nodes occupies two slots in `column` (one per endpoint), so
    /// the count is tracked separately. Not consulted for directed graphs,
    /// where `column.len()` is the edge count.
    edge_count: usize,
    ty: PhantomData<Ty>,
}
86
87impl<N, E, Ty, Ix> Default for Csr<N, E, Ty, Ix>
88where
89    Ty: EdgeType,
90    Ix: IndexType,
91{
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl<N: Clone, E: Clone, Ty, Ix: Clone> Clone for Csr<N, E, Ty, Ix> {
98    fn clone(&self) -> Self {
99        Csr {
100            column: self.column.clone(),
101            edges: self.edges.clone(),
102            row: self.row.clone(),
103            node_weights: self.node_weights.clone(),
104            edge_count: self.edge_count,
105            ty: self.ty,
106        }
107    }
108}
109
110impl<N, E, Ty, Ix> Csr<N, E, Ty, Ix>
111where
112    Ty: EdgeType,
113    Ix: IndexType,
114{
115    /// Create an empty `Csr`.
116    pub fn new() -> Self {
117        Csr {
118            column: vec![],
119            edges: vec![],
120            row: vec![0; 1],
121            node_weights: vec![],
122            edge_count: 0,
123            ty: PhantomData,
124        }
125    }
126
127    /// Create a new `Csr` with `n` nodes. `N` must implement [`Default`] for the weight of each node.
128    ///
129    /// [`Default`]: https://doc.rust-lang.org/nightly/core/default/trait.Default.html
130    ///
131    /// # Example
132    /// ```rust
133    /// use petgraph::csr::Csr;
134    /// use petgraph::prelude::*;
135    ///
136    /// let graph = Csr::<u8,()>::with_nodes(5);
137    /// assert_eq!(graph.node_count(),5);
138    /// assert_eq!(graph.edge_count(),0);
139    ///
140    /// assert_eq!(graph[0],0);
141    /// assert_eq!(graph[4],0);
142    /// ```
143    pub fn with_nodes(n: usize) -> Self
144    where
145        N: Default,
146    {
147        Csr {
148            column: Vec::new(),
149            edges: Vec::new(),
150            row: vec![0; n + 1],
151            node_weights: (0..n).map(|_| N::default()).collect(),
152            edge_count: 0,
153            ty: PhantomData,
154        }
155    }
156}
157
/// Csr creation error: edges were not in sorted order.
///
/// Returned by `Csr::from_sorted_edges` when an edge is not strictly greater
/// than its predecessor in *(source, target)* order.
#[derive(Clone, Debug)]
pub struct EdgesNotSorted {
    /// The first offending edge, as `(source, target)` indices.
    // Only read through the derived `Debug` impl, which the dead-code lint
    // does not count as a use.
    #[allow(dead_code)]
    first_error: (usize, usize),
}
164
165impl<N, E, Ix> Csr<N, E, Directed, Ix>
166where
167    Ix: IndexType,
168{
169    /// Create a new `Csr` from a sorted sequence of edges
170    ///
171    /// Edges **must** be sorted and unique, where the sort order is the default
172    /// order for the pair *(u, v)* in Rust (*u* has priority).
173    ///
174    /// Computes in **O(|E| + |V|)** time.
175    /// # Example
176    /// ```rust
177    /// use petgraph::csr::Csr;
178    /// use petgraph::prelude::*;
179    ///
180    /// let graph = Csr::<(),()>::from_sorted_edges(&[
181    ///                     (0, 1), (0, 2),
182    ///                     (1, 0), (1, 2), (1, 3),
183    ///                     (2, 0),
184    ///                     (3, 1),
185    /// ]);
186    /// ```
187    pub fn from_sorted_edges<Edge>(edges: &[Edge]) -> Result<Self, EdgesNotSorted>
188    where
189        Edge: Clone + IntoWeightedEdge<E, NodeId = NodeIndex<Ix>>,
190        N: Default,
191    {
192        let max_node_id = match edges
193            .iter()
194            .map(|edge| {
195                let (x, y, _) = edge.clone().into_weighted_edge();
196                max(x.index(), y.index())
197            })
198            .max()
199        {
200            None => return Ok(Self::with_nodes(0)),
201            Some(x) => x,
202        };
203        let mut self_ = Self::with_nodes(max_node_id + 1);
204        let mut iter = edges.iter().cloned().peekable();
205        {
206            let mut rows = self_.row.iter_mut();
207
208            let mut rstart = 0;
209            let mut last_target;
210            'outer: for (node, r) in (&mut rows).enumerate() {
211                *r = rstart;
212                last_target = None;
213                'inner: loop {
214                    if let Some(edge) = iter.peek() {
215                        let (n, m, weight) = edge.clone().into_weighted_edge();
216                        // check that the edges are in increasing sequence
217                        if node > n.index() {
218                            return Err(EdgesNotSorted {
219                                first_error: (n.index(), m.index()),
220                            });
221                        }
222                        /*
223                        debug_assert!(node <= n.index(),
224                                      concat!("edges are not sorted, ",
225                                              "failed assertion source {:?} <= {:?} ",
226                                              "for edge {:?}"),
227                                      node, n, (n, m));
228                                      */
229                        if n.index() != node {
230                            break 'inner;
231                        }
232                        // check that the edges are in increasing sequence
233                        /*
234                        debug_assert!(last_target.map_or(true, |x| m > x),
235                                      "edges are not sorted, failed assertion {:?} < {:?}",
236                                      last_target, m);
237                                      */
238                        if !last_target.map_or(true, |x| m > x) {
239                            return Err(EdgesNotSorted {
240                                first_error: (n.index(), m.index()),
241                            });
242                        }
243                        last_target = Some(m);
244                        self_.column.push(m);
245                        self_.edges.push(weight);
246                        rstart += 1;
247                    } else {
248                        break 'outer;
249                    }
250                    iter.next();
251                }
252            }
253            for r in rows {
254                *r = rstart;
255            }
256        }
257
258        Ok(self_)
259    }
260}
261
262impl<N, E, Ty, Ix> Csr<N, E, Ty, Ix>
263where
264    Ty: EdgeType,
265    Ix: IndexType,
266{
267    pub fn node_count(&self) -> usize {
268        self.row.len() - 1
269    }
270
271    pub fn edge_count(&self) -> usize {
272        if self.is_directed() {
273            self.column.len()
274        } else {
275            self.edge_count
276        }
277    }
278
279    pub fn is_directed(&self) -> bool {
280        Ty::is_directed()
281    }
282
283    /// Remove all edges
284    pub fn clear_edges(&mut self) {
285        self.column.clear();
286        self.edges.clear();
287        for r in &mut self.row {
288            *r = 0;
289        }
290        if !self.is_directed() {
291            self.edge_count = 0;
292        }
293    }
294
295    /// Adds a new node with the given weight, returning the corresponding node index.
296    pub fn add_node(&mut self, weight: N) -> NodeIndex<Ix> {
297        let i = self.row.len() - 1;
298        self.row.insert(i, self.column.len());
299        self.node_weights.insert(i, weight);
300        Ix::new(i)
301    }
302
303    /// Add an edge from `a` to `b` to the `Csr`, with its associated
304    /// data weight.
305    ///
306    /// Return `true` if the edge was added
307    ///
308    /// If you add all edges in row-major order, the time complexity
309    /// is **O(|V|·|E|)** for the whole operation.
310    ///
311    /// **Panics** if `a` or `b` are out of bounds.
312    #[track_caller]
313    pub fn add_edge(&mut self, a: NodeIndex<Ix>, b: NodeIndex<Ix>, weight: E) -> bool
314    where
315        E: Clone,
316    {
317        self.try_add_edge(a, b, weight).unwrap()
318    }
319
320    /// Try to add an edge from `a` to `b` to the `Csr`, with its associated
321    /// data weight.
322    ///
323    /// Return `true` if the edge was added
324    ///
325    /// If you add all edges in row-major order, the time complexity
326    /// is **O(|V|·|E|)** for the whole operation.
327    ///
328    /// Possible errors:
329    /// - [`CsrError::IndicesOutBounds`] - when both idxs `a` & `b` is out of bounds.
330    pub fn try_add_edge(
331        &mut self,
332        a: NodeIndex<Ix>,
333        b: NodeIndex<Ix>,
334        weight: E,
335    ) -> Result<bool, CsrError>
336    where
337        E: Clone,
338    {
339        let ret = self.add_edge_(a, b, weight.clone())?;
340        if ret && !self.is_directed() {
341            self.edge_count += 1;
342        }
343        if ret && !self.is_directed() && a != b {
344            let _ret2 = self.add_edge_(b, a, weight)?;
345            debug_assert_eq!(ret, _ret2);
346        }
347        Ok(ret)
348    }
349
350    // Return false if the edge already exists
351    fn add_edge_(
352        &mut self,
353        a: NodeIndex<Ix>,
354        b: NodeIndex<Ix>,
355        weight: E,
356    ) -> Result<bool, CsrError> {
357        if !(a.index() < self.node_count() && b.index() < self.node_count()) {
358            return Err(CsrError::IndicesOutBounds(a.index(), b.index()));
359        }
360        // a x b is at (a, b) in the matrix
361
362        // find current range of edges from a
363        let pos = match self.find_edge_pos(a, b) {
364            Ok(_) => return Ok(false), /* already exists */
365            Err(i) => i,
366        };
367        self.column.insert(pos, b);
368        self.edges.insert(pos, weight);
369        // update row vector
370        for r in &mut self.row[a.index() + 1..] {
371            *r += 1;
372        }
373        Ok(true)
374    }
375
376    fn find_edge_pos(&self, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> Result<usize, usize> {
377        let (index, neighbors) = self.neighbors_of(a);
378        if neighbors.len() < BINARY_SEARCH_CUTOFF {
379            for (i, elt) in neighbors.iter().enumerate() {
380                match elt.cmp(&b) {
381                    Ordering::Equal => return Ok(i + index),
382                    Ordering::Greater => return Err(i + index),
383                    Ordering::Less => {}
384                }
385            }
386            Err(neighbors.len() + index)
387        } else {
388            match neighbors.binary_search(&b) {
389                Ok(i) => Ok(i + index),
390                Err(i) => Err(i + index),
391            }
392        }
393    }
394
395    /// Computes in **O(log |V|)** time.
396    ///
397    /// **Panics** if the node `a` does not exist.
398    #[track_caller]
399    pub fn contains_edge(&self, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> bool {
400        self.find_edge_pos(a, b).is_ok()
401    }
402
403    fn neighbors_range(&self, a: NodeIndex<Ix>) -> Range<usize> {
404        let index = self.row[a.index()];
405        let end = self
406            .row
407            .get(a.index() + 1)
408            .cloned()
409            .unwrap_or(self.column.len());
410        index..end
411    }
412
413    fn neighbors_of(&self, a: NodeIndex<Ix>) -> (usize, &[Ix]) {
414        let r = self.neighbors_range(a);
415        (r.start, &self.column[r])
416    }
417
418    /// Computes in **O(1)** time.
419    ///
420    /// **Panics** if the node `a` does not exist.
421    #[track_caller]
422    pub fn out_degree(&self, a: NodeIndex<Ix>) -> usize {
423        let r = self.neighbors_range(a);
424        r.end - r.start
425    }
426
427    /// Computes in **O(1)** time.
428    ///
429    /// **Panics** if the node `a` does not exist.
430    #[track_caller]
431    pub fn neighbors_slice(&self, a: NodeIndex<Ix>) -> &[NodeIndex<Ix>] {
432        self.neighbors_of(a).1
433    }
434
435    /// Computes in **O(1)** time.
436    ///
437    /// **Panics** if the node `a` does not exist.
438    #[track_caller]
439    pub fn edges_slice(&self, a: NodeIndex<Ix>) -> &[E] {
440        &self.edges[self.neighbors_range(a)]
441    }
442
443    /// Return an iterator of all edges of `a`.
444    ///
445    /// - `Directed`: Outgoing edges from `a`.
446    /// - `Undirected`: All edges connected to `a`.
447    ///
448    /// **Panics** if the node `a` does not exist.<br>
449    /// Iterator element type is `EdgeReference<E, Ty, Ix>`.
450    #[track_caller]
451    pub fn edges(&self, a: NodeIndex<Ix>) -> Edges<E, Ty, Ix> {
452        let r = self.neighbors_range(a);
453        Edges {
454            index: r.start,
455            source: a,
456            iter: zip(&self.column[r.clone()], &self.edges[r]),
457            ty: self.ty,
458        }
459    }
460}
461
/// Iterator over the edges stored in one node's row, created by `Csr::edges`.
///
/// For directed graphs these are the node's outgoing edges; for undirected
/// graphs, all edges connected to it. Iterator element type is
/// `EdgeReference<E, Ty, Ix>`.
#[derive(Clone, Debug)]
pub struct Edges<'a, E: 'a, Ty = Directed, Ix: 'a = DefaultIx> {
    // Edge index (position in the graph's edge storage) of the next item.
    index: usize,
    // The node whose row is being iterated.
    source: NodeIndex<Ix>,
    // Remaining (target, weight) pairs of the row.
    iter: Zip<SliceIter<'a, NodeIndex<Ix>>, SliceIter<'a, E>>,
    ty: PhantomData<Ty>,
}
469
/// Reference to a single edge of a `Csr`: its index, endpoints and weight.
///
/// Produced by the `Edges` and `EdgeReferences` iterators.
#[derive(Debug)]
pub struct EdgeReference<'a, E: 'a, Ty, Ix: 'a = DefaultIx> {
    // Position of the edge in the graph's edge storage; serves as the edge id.
    index: EdgeIndex,
    // Node whose row holds this edge.
    source: NodeIndex<Ix>,
    // Node the edge points to.
    target: NodeIndex<Ix>,
    // Borrowed edge weight.
    weight: &'a E,
    ty: PhantomData<Ty>,
}
478
479impl<E, Ty, Ix: Copy> Clone for EdgeReference<'_, E, Ty, Ix> {
480    fn clone(&self) -> Self {
481        *self
482    }
483}
484
485impl<E, Ty, Ix: Copy> Copy for EdgeReference<'_, E, Ty, Ix> {}
486
487impl<'a, Ty, E, Ix> EdgeReference<'a, E, Ty, Ix>
488where
489    Ty: EdgeType,
490{
491    /// Access the edge’s weight.
492    ///
493    /// **NOTE** that this method offers a longer lifetime
494    /// than the trait (unfortunately they don't match yet).
495    pub fn weight(&self) -> &'a E {
496        self.weight
497    }
498}
499
500impl<E, Ty, Ix> EdgeRef for EdgeReference<'_, E, Ty, Ix>
501where
502    Ty: EdgeType,
503    Ix: IndexType,
504{
505    type NodeId = NodeIndex<Ix>;
506    type EdgeId = EdgeIndex;
507    type Weight = E;
508
509    fn source(&self) -> Self::NodeId {
510        self.source
511    }
512    fn target(&self) -> Self::NodeId {
513        self.target
514    }
515    fn weight(&self) -> &E {
516        self.weight
517    }
518    fn id(&self) -> Self::EdgeId {
519        self.index
520    }
521}
522
523impl<'a, E, Ty, Ix> Iterator for Edges<'a, E, Ty, Ix>
524where
525    Ty: EdgeType,
526    Ix: IndexType,
527{
528    type Item = EdgeReference<'a, E, Ty, Ix>;
529    fn next(&mut self) -> Option<Self::Item> {
530        self.iter.next().map(move |(&j, w)| {
531            let index = self.index;
532            self.index += 1;
533            EdgeReference {
534                index,
535                source: self.source,
536                target: j,
537                weight: w,
538                ty: PhantomData,
539            }
540        })
541    }
542    fn size_hint(&self) -> (usize, Option<usize>) {
543        self.iter.size_hint()
544    }
545}
546
547impl<N, E, Ty, Ix> Data for Csr<N, E, Ty, Ix>
548where
549    Ty: EdgeType,
550    Ix: IndexType,
551{
552    type NodeWeight = N;
553    type EdgeWeight = E;
554}
555
556impl<'a, N, E, Ty, Ix> IntoEdgeReferences for &'a Csr<N, E, Ty, Ix>
557where
558    Ty: EdgeType,
559    Ix: IndexType,
560{
561    type EdgeRef = EdgeReference<'a, E, Ty, Ix>;
562    type EdgeReferences = EdgeReferences<'a, E, Ty, Ix>;
563    fn edge_references(self) -> Self::EdgeReferences {
564        EdgeReferences {
565            index: 0,
566            source_index: Ix::new(0),
567            edge_ranges: self.row.windows(2).enumerate(),
568            column: &self.column,
569            edges: &self.edges,
570            iter: zip(&[], &[]),
571            ty: self.ty,
572        }
573    }
574}
575
/// Iterator over every edge of a `Csr`, visiting the nodes' rows in index order.
///
/// Iterator element type is `EdgeReference<E, Ty, Ix>`. In an undirected
/// graph, an edge between distinct nodes is yielded twice: once from each
/// endpoint's row.
#[derive(Debug, Clone)]
pub struct EdgeReferences<'a, E: 'a, Ty, Ix: 'a> {
    // Node whose row `iter` is currently walking.
    source_index: NodeIndex<Ix>,
    // Edge index of the next edge to yield (= number of edges yielded so far).
    index: usize,
    // Remaining `(node, [start, end])` row bounds.
    edge_ranges: Enumerate<Windows<'a, usize>>,
    // The graph's full target storage.
    column: &'a [NodeIndex<Ix>],
    // The graph's full weight storage.
    edges: &'a [E],
    // Remaining (target, weight) pairs of the current row.
    iter: Zip<SliceIter<'a, NodeIndex<Ix>>, SliceIter<'a, E>>,
    ty: PhantomData<Ty>,
}
586
587impl<'a, E, Ty, Ix> Iterator for EdgeReferences<'a, E, Ty, Ix>
588where
589    Ty: EdgeType,
590    Ix: IndexType,
591{
592    type Item = EdgeReference<'a, E, Ty, Ix>;
593    fn next(&mut self) -> Option<Self::Item> {
594        loop {
595            if let Some((&j, w)) = self.iter.next() {
596                let index = self.index;
597                self.index += 1;
598                return Some(EdgeReference {
599                    index,
600                    source: self.source_index,
601                    target: j,
602                    weight: w,
603                    ty: PhantomData,
604                });
605            }
606            if let Some((i, w)) = self.edge_ranges.next() {
607                let a = w[0];
608                let b = w[1];
609                self.iter = zip(&self.column[a..b], &self.edges[a..b]);
610                self.source_index = Ix::new(i);
611            } else {
612                return None;
613            }
614        }
615    }
616}
617
618impl<'a, N, E, Ty, Ix> IntoEdges for &'a Csr<N, E, Ty, Ix>
619where
620    Ty: EdgeType,
621    Ix: IndexType,
622{
623    type Edges = Edges<'a, E, Ty, Ix>;
624    fn edges(self, a: Self::NodeId) -> Self::Edges {
625        self.edges(a)
626    }
627}
628
629impl<N, E, Ty, Ix> GraphBase for Csr<N, E, Ty, Ix>
630where
631    Ty: EdgeType,
632    Ix: IndexType,
633{
634    type NodeId = NodeIndex<Ix>;
635    type EdgeId = EdgeIndex; // index into edges vector
636}
637
638use fixedbitset::FixedBitSet;
639
640impl<N, E, Ty, Ix> Visitable for Csr<N, E, Ty, Ix>
641where
642    Ty: EdgeType,
643    Ix: IndexType,
644{
645    type Map = FixedBitSet;
646    fn visit_map(&self) -> FixedBitSet {
647        FixedBitSet::with_capacity(self.node_count())
648    }
649    fn reset_map(&self, map: &mut Self::Map) {
650        map.clear();
651        map.grow(self.node_count());
652    }
653}
654
655use core::slice::Iter as SliceIter;
656
/// Iterator over the targets stored in one node's row, created by
/// `IntoNeighbors::neighbors`.
///
/// Iterator element type is `NodeIndex<Ix>`.
#[derive(Clone, Debug)]
pub struct Neighbors<'a, Ix: 'a = DefaultIx> {
    // Remaining targets of the row.
    iter: SliceIter<'a, NodeIndex<Ix>>,
}
661
662impl<Ix> Iterator for Neighbors<'_, Ix>
663where
664    Ix: IndexType,
665{
666    type Item = NodeIndex<Ix>;
667
668    fn next(&mut self) -> Option<Self::Item> {
669        self.iter.next().cloned()
670    }
671
672    fn size_hint(&self) -> (usize, Option<usize>) {
673        self.iter.size_hint()
674    }
675}
676
677impl<'a, N, E, Ty, Ix> IntoNeighbors for &'a Csr<N, E, Ty, Ix>
678where
679    Ty: EdgeType,
680    Ix: IndexType,
681{
682    type Neighbors = Neighbors<'a, Ix>;
683
684    /// Return an iterator of all neighbors of `a`.
685    ///
686    /// - `Directed`: Targets of outgoing edges from `a`.
687    /// - `Undirected`: Opposing endpoints of all edges connected to `a`.
688    ///
689    /// **Panics** if the node `a` does not exist.<br>
690    /// Iterator element type is `NodeIndex<Ix>`.
691    #[track_caller]
692    fn neighbors(self, a: Self::NodeId) -> Self::Neighbors {
693        Neighbors {
694            iter: self.neighbors_slice(a).iter(),
695        }
696    }
697}
698
699impl<N, E, Ty, Ix> NodeIndexable for Csr<N, E, Ty, Ix>
700where
701    Ty: EdgeType,
702    Ix: IndexType,
703{
704    fn node_bound(&self) -> usize {
705        self.node_count()
706    }
707    fn to_index(&self, a: Self::NodeId) -> usize {
708        a.index()
709    }
710    fn from_index(&self, ix: usize) -> Self::NodeId {
711        Ix::new(ix)
712    }
713}
714
715impl<N, E, Ty, Ix> NodeCompactIndexable for Csr<N, E, Ty, Ix>
716where
717    Ty: EdgeType,
718    Ix: IndexType,
719{
720}
721
722impl<N, E, Ty, Ix> Index<NodeIndex<Ix>> for Csr<N, E, Ty, Ix>
723where
724    Ty: EdgeType,
725    Ix: IndexType,
726{
727    type Output = N;
728
729    fn index(&self, ix: NodeIndex<Ix>) -> &N {
730        &self.node_weights[ix.index()]
731    }
732}
733
734impl<N, E, Ty, Ix> IndexMut<NodeIndex<Ix>> for Csr<N, E, Ty, Ix>
735where
736    Ty: EdgeType,
737    Ix: IndexType,
738{
739    fn index_mut(&mut self, ix: NodeIndex<Ix>) -> &mut N {
740        &mut self.node_weights[ix.index()]
741    }
742}
743
/// Iterator over the node indices `0..node_count` of a `Csr`.
#[derive(Debug, Clone)]
pub struct NodeIdentifiers<Ix = DefaultIx> {
    // Remaining node indices.
    r: Range<usize>,
    // Marker for the index type that is produced.
    ty: PhantomData<Ix>,
}
749
750impl<Ix> Iterator for NodeIdentifiers<Ix>
751where
752    Ix: IndexType,
753{
754    type Item = NodeIndex<Ix>;
755
756    fn next(&mut self) -> Option<Self::Item> {
757        self.r.next().map(Ix::new)
758    }
759
760    fn size_hint(&self) -> (usize, Option<usize>) {
761        self.r.size_hint()
762    }
763}
764
765impl<N, E, Ty, Ix> IntoNodeIdentifiers for &Csr<N, E, Ty, Ix>
766where
767    Ty: EdgeType,
768    Ix: IndexType,
769{
770    type NodeIdentifiers = NodeIdentifiers<Ix>;
771    fn node_identifiers(self) -> Self::NodeIdentifiers {
772        NodeIdentifiers {
773            r: 0..self.node_count(),
774            ty: PhantomData,
775        }
776    }
777}
778
779impl<N, E, Ty, Ix> NodeCount for Csr<N, E, Ty, Ix>
780where
781    Ty: EdgeType,
782    Ix: IndexType,
783{
784    fn node_count(&self) -> usize {
785        (*self).node_count()
786    }
787}
788
789impl<N, E, Ty, Ix> EdgeCount for Csr<N, E, Ty, Ix>
790where
791    Ty: EdgeType,
792    Ix: IndexType,
793{
794    #[inline]
795    fn edge_count(&self) -> usize {
796        self.edge_count()
797    }
798}
799
800impl<N, E, Ty, Ix> GraphProp for Csr<N, E, Ty, Ix>
801where
802    Ty: EdgeType,
803    Ix: IndexType,
804{
805    type EdgeType = Ty;
806}
807
808impl<'a, N, E, Ty, Ix> IntoNodeReferences for &'a Csr<N, E, Ty, Ix>
809where
810    Ty: EdgeType,
811    Ix: IndexType,
812{
813    type NodeRef = (NodeIndex<Ix>, &'a N);
814    type NodeReferences = NodeReferences<'a, N, Ix>;
815    fn node_references(self) -> Self::NodeReferences {
816        NodeReferences {
817            iter: self.node_weights.iter().enumerate(),
818            ty: PhantomData,
819        }
820    }
821}
822
/// Iterator over all nodes of a graph, yielding each node's index together
/// with a reference to its weight.
#[derive(Debug, Clone)]
pub struct NodeReferences<'a, N: 'a, Ix: IndexType = DefaultIx> {
    // Node weights paired with their positions (= node indices).
    iter: Enumerate<SliceIter<'a, N>>,
    // Marker for the index type that is produced.
    ty: PhantomData<Ix>,
}
829
830impl<'a, N, Ix> Iterator for NodeReferences<'a, N, Ix>
831where
832    Ix: IndexType,
833{
834    type Item = (NodeIndex<Ix>, &'a N);
835
836    fn next(&mut self) -> Option<Self::Item> {
837        self.iter.next().map(|(i, weight)| (Ix::new(i), weight))
838    }
839
840    fn size_hint(&self) -> (usize, Option<usize>) {
841        self.iter.size_hint()
842    }
843}
844
845impl<N, Ix> DoubleEndedIterator for NodeReferences<'_, N, Ix>
846where
847    Ix: IndexType,
848{
849    fn next_back(&mut self) -> Option<Self::Item> {
850        self.iter
851            .next_back()
852            .map(|(i, weight)| (Ix::new(i), weight))
853    }
854}
855
856impl<N, Ix> ExactSizeIterator for NodeReferences<'_, N, Ix> where Ix: IndexType {}
857
858/// The adjacency matrix for **Csr** is a bitmap that's computed by
859/// `.adjacency_matrix()`.
860impl<N, E, Ty, Ix> GetAdjacencyMatrix for &Csr<N, E, Ty, Ix>
861where
862    Ix: IndexType,
863    Ty: EdgeType,
864{
865    type AdjMatrix = FixedBitSet;
866
867    fn adjacency_matrix(&self) -> FixedBitSet {
868        let n = self.node_count();
869        let mut matrix = FixedBitSet::with_capacity(n * n);
870        for edge in self.edge_references() {
871            let i = n * edge.source().index() + edge.target().index();
872            matrix.put(i);
873
874            if !self.is_directed() {
875                let j = edge.source().index() + n * edge.target().index();
876                matrix.put(j);
877            }
878        }
879        matrix
880    }
881
882    fn is_adjacent(&self, matrix: &FixedBitSet, a: NodeIndex<Ix>, b: NodeIndex<Ix>) -> bool {
883        let n = self.node_count();
884        let index = n * a.index() + b.index();
885        matrix.contains(index)
886    }
887}
888
889/*
890 *
891Example
892
893[ a 0 b
894  c d e
895  0 0 f ]
896
897Values: [a, b, c, d, e, f]
898Column: [0, 2, 0, 1, 2, 2]
Row   : [0, 2, 5, 6]   <- value index of row start; last entry = total number of values
900
901 * */
902
903#[cfg(test)]
904mod tests {
905    use alloc::vec::Vec;
906    use std::println;
907
908    use super::Csr;
909    use crate::algo::bellman_ford;
910    use crate::algo::find_negative_cycle;
911    use crate::algo::tarjan_scc;
912    use crate::visit::Dfs;
913    use crate::visit::VisitMap;
914    use crate::Undirected;
915
916    #[test]
917    fn csr1() {
918        let mut m: Csr = Csr::with_nodes(3);
919        m.add_edge(0, 0, ());
920        m.add_edge(1, 2, ());
921        m.add_edge(2, 2, ());
922        m.add_edge(0, 2, ());
923        m.add_edge(1, 0, ());
924        m.add_edge(1, 1, ());
925        println!("{:?}", m);
926        assert_eq!(&m.column, &[0, 2, 0, 1, 2, 2]);
927        assert_eq!(&m.row, &[0, 2, 5, 6]);
928
929        let added = m.add_edge(1, 2, ());
930        assert!(!added);
931        assert_eq!(&m.column, &[0, 2, 0, 1, 2, 2]);
932        assert_eq!(&m.row, &[0, 2, 5, 6]);
933
934        assert_eq!(m.neighbors_slice(1), &[0, 1, 2]);
935        assert_eq!(m.node_count(), 3);
936        assert_eq!(m.edge_count(), 6);
937    }
938
939    #[test]
940    fn csr_undirected() {
941        /*
942           [ 1 . 1
943             . . 1
944             1 1 1 ]
945        */
946
947        let mut m: Csr<(), (), Undirected> = Csr::with_nodes(3);
948        m.add_edge(0, 0, ());
949        m.add_edge(0, 2, ());
950        m.add_edge(1, 2, ());
951        m.add_edge(2, 2, ());
952        println!("{:?}", m);
953        assert_eq!(&m.column, &[0, 2, 2, 0, 1, 2]);
954        assert_eq!(&m.row, &[0, 2, 3, 6]);
955        assert_eq!(m.node_count(), 3);
956        assert_eq!(m.edge_count(), 4);
957    }
958
959    #[should_panic]
960    #[test]
961    fn csr_from_error_1() {
962        // not sorted in source
963        let m: Csr = Csr::from_sorted_edges(&[(0, 1), (1, 0), (0, 2)]).unwrap();
964        println!("{:?}", m);
965    }
966
967    #[should_panic]
968    #[test]
969    fn csr_from_error_2() {
970        // not sorted in target
971        let m: Csr = Csr::from_sorted_edges(&[(0, 1), (1, 0), (1, 2), (1, 1)]).unwrap();
972        println!("{:?}", m);
973    }
974
975    #[test]
976    fn csr_from() {
977        let m: Csr =
978            Csr::from_sorted_edges(&[(0, 1), (0, 2), (1, 0), (1, 1), (2, 2), (2, 4)]).unwrap();
979        println!("{:?}", m);
980        assert_eq!(m.neighbors_slice(0), &[1, 2]);
981        assert_eq!(m.neighbors_slice(1), &[0, 1]);
982        assert_eq!(m.neighbors_slice(2), &[2, 4]);
983        assert_eq!(m.node_count(), 5);
984        assert_eq!(m.edge_count(), 6);
985    }
986
987    #[test]
988    fn csr_dfs() {
989        let mut m: Csr = Csr::from_sorted_edges(&[
990            (0, 1),
991            (0, 2),
992            (1, 0),
993            (1, 1),
994            (1, 3),
995            (2, 2),
996            // disconnected subgraph
997            (4, 4),
998            (4, 5),
999        ])
1000        .unwrap();
1001        println!("{:?}", m);
1002        let mut dfs = Dfs::new(&m, 0);
1003        while dfs.next(&m).is_some() {}
1004        for i in 0..m.node_count() - 2 {
1005            assert!(dfs.discovered.is_visited(&i), "visited {}", i)
1006        }
1007        assert!(!dfs.discovered[4]);
1008        assert!(!dfs.discovered[5]);
1009
1010        m.add_edge(1, 4, ());
1011        println!("{:?}", m);
1012
1013        dfs.reset(&m);
1014        dfs.move_to(0);
1015        while dfs.next(&m).is_some() {}
1016
1017        for i in 0..m.node_count() {
1018            assert!(dfs.discovered[i], "visited {}", i)
1019        }
1020    }
1021
1022    #[test]
1023    fn csr_tarjan() {
1024        let m: Csr = Csr::from_sorted_edges(&[
1025            (0, 1),
1026            (0, 2),
1027            (1, 0),
1028            (1, 1),
1029            (1, 3),
1030            (2, 2),
1031            (2, 4),
1032            (4, 4),
1033            (4, 5),
1034            (5, 2),
1035        ])
1036        .unwrap();
1037        println!("{:?}", m);
1038        println!("{:?}", tarjan_scc(&m));
1039    }
1040
1041    #[test]
1042    fn test_bellman_ford() {
1043        let m: Csr<(), _> = Csr::from_sorted_edges(&[
1044            (0, 1, 0.5),
1045            (0, 2, 2.),
1046            (1, 0, 1.),
1047            (1, 1, 1.),
1048            (1, 2, 1.),
1049            (1, 3, 1.),
1050            (2, 3, 3.),
1051            (4, 5, 1.),
1052            (5, 7, 2.),
1053            (6, 7, 1.),
1054            (7, 8, 3.),
1055        ])
1056        .unwrap();
1057        println!("{:?}", m);
1058        let result = bellman_ford(&m, 0).unwrap();
1059        println!("{:?}", result);
1060        let answer = [0., 0.5, 1.5, 1.5];
1061        assert_eq!(&answer, &result.distances[..4]);
1062        assert!(result.distances[4..].iter().all(|&x| f64::is_infinite(x)));
1063    }
1064
1065    #[test]
1066    fn test_bellman_ford_neg_cycle() {
1067        let m: Csr<(), _> = Csr::from_sorted_edges(&[
1068            (0, 1, 0.5),
1069            (0, 2, 2.),
1070            (1, 0, 1.),
1071            (1, 1, -1.),
1072            (1, 2, 1.),
1073            (1, 3, 1.),
1074            (2, 3, 3.),
1075        ])
1076        .unwrap();
1077        let result = bellman_ford(&m, 0);
1078        assert!(result.is_err());
1079    }
1080
1081    #[test]
1082    fn test_find_neg_cycle1() {
1083        let m: Csr<(), _> = Csr::from_sorted_edges(&[
1084            (0, 1, 0.5),
1085            (0, 2, 2.),
1086            (1, 0, 1.),
1087            (1, 1, -1.),
1088            (1, 2, 1.),
1089            (1, 3, 1.),
1090            (2, 3, 3.),
1091        ])
1092        .unwrap();
1093        let result = find_negative_cycle(&m, 0);
1094        assert_eq!(result, Some([1].to_vec()));
1095    }
1096
1097    #[test]
1098    fn test_find_neg_cycle2() {
1099        let m: Csr<(), _> = Csr::from_sorted_edges(&[
1100            (0, 1, 0.5),
1101            (0, 2, 2.),
1102            (1, 0, 1.),
1103            (1, 2, 1.),
1104            (1, 3, 1.),
1105            (2, 3, 3.),
1106        ])
1107        .unwrap();
1108        let result = find_negative_cycle(&m, 0);
1109        assert_eq!(result, None);
1110    }
1111
1112    #[test]
1113    fn test_find_neg_cycle3() {
1114        let m: Csr<(), _> = Csr::from_sorted_edges(&[
1115            (0, 1, 1.),
1116            (0, 2, 1.),
1117            (0, 3, 1.),
1118            (1, 3, 1.),
1119            (2, 1, 1.),
1120            (3, 2, -3.),
1121        ])
1122        .unwrap();
1123        let result = find_negative_cycle(&m, 0);
1124        assert_eq!(result, Some([1, 3, 2].to_vec()));
1125    }
1126
1127    #[test]
1128    fn test_find_neg_cycle4() {
1129        let m: Csr<(), _> = Csr::from_sorted_edges(&[(0, 0, -1.)]).unwrap();
1130        let result = find_negative_cycle(&m, 0);
1131        assert_eq!(result, Some([0].to_vec()));
1132    }
1133
1134    #[test]
1135    fn test_edge_references() {
1136        use crate::visit::EdgeRef;
1137        use crate::visit::IntoEdgeReferences;
1138        let m: Csr<(), _> = Csr::from_sorted_edges(&[
1139            (0, 1, 0.5),
1140            (0, 2, 2.),
1141            (1, 0, 1.),
1142            (1, 1, 1.),
1143            (1, 2, 1.),
1144            (1, 3, 1.),
1145            (2, 3, 3.),
1146            (4, 5, 1.),
1147            (5, 7, 2.),
1148            (6, 7, 1.),
1149            (7, 8, 3.),
1150        ])
1151        .unwrap();
1152        let mut copy = Vec::new();
1153        for e in m.edge_references() {
1154            copy.push((e.source(), e.target(), *e.weight()));
1155            println!("{:?}", e);
1156        }
1157        let m2: Csr<(), _> = Csr::from_sorted_edges(&copy).unwrap();
1158        assert_eq!(&m.row, &m2.row);
1159        assert_eq!(&m.column, &m2.column);
1160        assert_eq!(&m.edges, &m2.edges);
1161    }
1162
1163    #[test]
1164    fn test_add_node() {
1165        let mut g: Csr = Csr::new();
1166        let a = g.add_node(());
1167        let b = g.add_node(());
1168        let c = g.add_node(());
1169
1170        assert!(g.add_edge(a, b, ()));
1171        assert!(g.add_edge(b, c, ()));
1172        assert!(g.add_edge(c, a, ()));
1173
1174        println!("{:?}", g);
1175
1176        assert_eq!(g.node_count(), 3);
1177
1178        assert_eq!(g.neighbors_slice(a), &[b]);
1179        assert_eq!(g.neighbors_slice(b), &[c]);
1180        assert_eq!(g.neighbors_slice(c), &[a]);
1181
1182        assert_eq!(g.edge_count(), 3);
1183    }
1184
1185    #[test]
1186    fn test_add_node_with_existing_edges() {
1187        let mut g: Csr = Csr::new();
1188        let a = g.add_node(());
1189        let b = g.add_node(());
1190
1191        assert!(g.add_edge(a, b, ()));
1192
1193        let c = g.add_node(());
1194
1195        println!("{:?}", g);
1196
1197        assert_eq!(g.node_count(), 3);
1198
1199        assert_eq!(g.neighbors_slice(a), &[b]);
1200        assert_eq!(g.neighbors_slice(b), &[]);
1201        assert_eq!(g.neighbors_slice(c), &[]);
1202
1203        assert_eq!(g.edge_count(), 1);
1204    }
1205
1206    #[test]
1207    fn test_node_references() {
1208        use crate::visit::IntoNodeReferences;
1209        let mut g: Csr<u32> = Csr::new();
1210        g.add_node(42);
1211        g.add_node(3);
1212        g.add_node(44);
1213
1214        let mut refs = g.node_references();
1215        assert_eq!(refs.next(), Some((0, &42)));
1216        assert_eq!(refs.next(), Some((1, &3)));
1217        assert_eq!(refs.next(), Some((2, &44)));
1218        assert_eq!(refs.next(), None);
1219    }
1220}