guppy/petgraph_support/
topo.rs1use petgraph::{
5 graph::IndexType,
6 prelude::*,
7 visit::{
8 GraphRef, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCompactIndexable, VisitMap,
9 Visitable, Walker,
10 },
11};
12use std::marker::PhantomData;
13
/// A precomputed mapping from the node indexes of a graph to positions in an ordering.
///
/// The mapping is built once by `TopoWithCycles::new` and afterwards only read, through
/// `topo_ix` and `sort_nodes`.
#[derive(Clone, Debug)]
pub struct TopoWithCycles<Ix> {
    // Indexed by a node's `index()`; each entry is the position assigned to that node.
    reverse_index: Box<[usize]>,
    // Zero-sized marker that records the index type `Ix`, so the methods can be typed against
    // `NodeIndex<Ix>` even though no `Ix` value is stored.
    _phantom: PhantomData<Ix>,
}
22
23impl<Ix: IndexType> TopoWithCycles<Ix> {
24 pub fn new<G>(graph: G) -> Self
25 where
26 G: GraphRef
27 + Visitable<NodeId = NodeIndex<Ix>>
28 + IntoNodeIdentifiers
29 + IntoNeighborsDirected<NodeId = NodeIndex<Ix>>
30 + NodeCompactIndexable,
31 G::Map: VisitMap<NodeIndex<Ix>>,
32 {
33 let mut dfs = DfsPostOrder::empty(graph);
35
36 let roots = graph
37 .node_identifiers()
38 .filter(move |&a| graph.neighbors_directed(a, Incoming).next().is_none());
39 dfs.stack.extend(roots);
40
41 let mut topo: Vec<NodeIndex<Ix>> = (&mut dfs).iter(graph).collect();
42 topo.reverse();
45
46 let mut reverse_index = vec![0; graph.node_count()];
50 topo.iter().enumerate().for_each(|(topo_ix, node_ix)| {
51 reverse_index[node_ix.index()] = topo_ix;
52 });
53
54 assert!(
56 topo.len() <= graph.node_count(),
57 "topo.len() <= graph.node_count() ({} is actually > {})",
58 topo.len(),
59 graph.node_count(),
60 );
61 if topo.len() < graph.node_count() {
62 let mut next = topo.len();
72 for n in 0..graph.node_count() {
73 let a = NodeIndex::new(n);
74 if !dfs.finished.is_visited(&a) {
75 reverse_index[a.index()] = next;
77 next += 1;
78 }
79 }
80 }
81
82 Self {
83 reverse_index: reverse_index.into_boxed_slice(),
84 _phantom: PhantomData,
85 }
86 }
87
88 #[inline]
90 pub fn sort_nodes(&self, nodes: &mut [NodeIndex<Ix>]) {
91 nodes.sort_unstable_by_key(|node_ix| self.topo_ix(*node_ix))
92 }
93
94 #[inline]
95 pub fn topo_ix(&self, node_ix: NodeIndex<Ix>) -> usize {
96 self.reverse_index[node_ix.index()]
97 }
98}
99
#[cfg(all(test, feature = "proptest1"))]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Every generated graph must get a consistent set of positions, and sorting all of
        /// its nodes must place each node at exactly the position `topo_ix` reports for it.
        #[test]
        fn graph_topo_sort(graph in possibly_cyclic_graph()) {
            let topo = TopoWithCycles::new(&graph);
            let mut nodes: Vec<_> = graph.node_indices().collect();

            check_consistency(&topo, graph.node_count());

            topo.sort_nodes(&mut nodes);
            for (topo_ix, &node_ix) in nodes.iter().enumerate() {
                assert_eq!(topo.topo_ix(node_ix), topo_ix);
            }
        }
    }

    /// Strategy producing directed graphs with 1 to 100 nodes.
    ///
    /// For each node it draws a list of fewer than `n` destination indexes in `0..n`, so
    /// self-loops and cycles can occur.
    fn possibly_cyclic_graph() -> impl Strategy<Value = Graph<(), ()>> {
        (1..=100usize)
            .prop_flat_map(|n| {
                // One destination list per source node.
                let adjacency = prop::collection::vec(prop::collection::vec(0..n, 0..n), n);
                (Just(n), adjacency)
            })
            .prop_map(|(n, adj)| graph_from_adjacency(n, adj))
    }

    /// Builds a graph with `n` nodes and an edge `src -> dst` for every `dst` in `adj[src]`.
    ///
    /// Edges are inserted with `update_edge`, so a repeated `(src, dst)` pair does not create
    /// a second edge.
    fn graph_from_adjacency(n: usize, adj: Vec<Vec<usize>>) -> Graph<(), ()> {
        let edge_capacity: usize = adj.iter().map(Vec::len).sum();
        let mut graph = Graph::<(), ()>::with_capacity(n, edge_capacity);

        (0..n).for_each(|_| {
            graph.add_node(());
        });

        for (src, dsts) in adj.into_iter().enumerate() {
            let src = NodeIndex::new(src);
            for dst in dsts {
                graph.update_edge(src, NodeIndex::new(dst), ());
            }
        }

        graph
    }

    /// Asserts that the positions of nodes `0..n` are exactly the values `0..n`, each used
    /// once.
    fn check_consistency(topo: &TopoWithCycles<u32>, n: usize) {
        let mut seen = vec![false; n];

        // No position may be handed out twice.
        for topo_ix in (0..n).map(|i| topo.topo_ix(NodeIndex::<u32>::new(i))) {
            let slot = &mut seen[topo_ix];
            assert!(
                !*slot,
                "topo_ix {} should be seen exactly once, but seen twice",
                topo_ix
            );
            *slot = true;
        }

        // Every position must have been handed out; report the lowest missing one.
        if let Some(missing) = seen.iter().position(|&was_seen| !was_seen) {
            panic!("topo_ix {} should be seen, but wasn't", missing);
        }
    }
}