nextest_runner/
partition.rs

1// Copyright (c) The nextest Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Support for partitioning test runs across several machines.
5//!
6//! Three kinds of partitioning are currently supported:
7//! - **Counted** (`count:M/N`): round-robin partitioning within each binary.
8//! - **Hashed** (`hash:M/N`): deterministic hash-based partitioning within each binary.
9//! - **Sliced** (`slice:M/N`): round-robin partitioning across all binaries (cross-binary).
10//!
11//! In the future, partitioning could potentially be made smarter: e.g. using data to pick different
12//! sets of binaries and tests to run, with an aim to minimize total build and test times.
13
14use crate::errors::PartitionerBuilderParseError;
15use std::{fmt, str::FromStr};
16use xxhash_rust::xxh64::xxh64;
17
/// Describes how tests are split across shards, and builds [`Partitioner`]s that apply it.
///
/// This mirrors the split between `std`'s `BuildHasher` and `Hasher`: a `PartitionerBuilder` is
/// the reusable description, and every call to `build` produces a fresh [`Partitioner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartitionerBuilder {
    /// Round-robin by test position, restarted within each test binary (`count:M/N`).
    Count {
        /// This shard's number, starting at 1.
        shard: u64,

        /// How many shards there are in total.
        total_shards: u64,
    },

    /// Pick a shard by hashing each test name (`hash:M/N`). Individual partitions are stateless.
    Hash {
        /// This shard's number, starting at 1.
        shard: u64,

        /// How many shards there are in total.
        total_shards: u64,
    },

    /// Round-robin over the combined tests of every binary (`slice:M/N`).
    ///
    /// Where `Count` starts its rotation afresh in every binary, `Slice` treats the tests of all
    /// binaries as one sequence, so shard sizes stay even no matter how tests are spread across
    /// binaries.
    Slice {
        /// This shard's number, starting at 1.
        shard: u64,

        /// How many shards there are in total.
        total_shards: u64,
    },
}
55
/// Whether partitioning is applied to each test binary separately or to the whole run at once.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionerScope {
    /// Every test binary is partitioned on its own, independently of the others.
    PerBinary,

    /// The tests of all binaries are partitioned together, as one combined set.
    CrossBinary,
}
65
/// A single partitioning pass that decides, test by test, whether each belongs to this shard.
///
/// Instances are produced by [`PartitionerBuilder::build`]. They may keep state between calls,
/// which is why matching takes `&mut self`. An instance covers either one test binary or, for
/// cross-binary partitioning, the tests of every binary.
pub trait Partitioner: fmt::Debug {
    /// Returns `true` if `test_name` belongs to this partition.
    fn test_matches(&mut self, test_name: &str) -> bool;
}
71
72impl PartitionerBuilder {
73    /// Returns the scope at which this partitioner operates.
74    pub fn scope(&self) -> PartitionerScope {
75        match self {
76            PartitionerBuilder::Count { .. } => {
77                // Count is stateful (round-robin), so it must be per-binary
78                // to preserve existing shard assignment behavior.
79                PartitionerScope::PerBinary
80            }
81            PartitionerBuilder::Hash { .. } => {
82                // Hash is stateless: scope doesn't affect results. Per-binary
83                // is chosen arbitrarily.
84                PartitionerScope::PerBinary
85            }
86            PartitionerBuilder::Slice { .. } => PartitionerScope::CrossBinary,
87        }
88    }
89
90    /// Creates a new `Partitioner` from this `PartitionerBuilder`.
91    pub fn build(&self) -> Box<dyn Partitioner> {
92        match self {
93            PartitionerBuilder::Count {
94                shard,
95                total_shards,
96            }
97            | PartitionerBuilder::Slice {
98                shard,
99                total_shards,
100            } => Box::new(CountPartitioner::new(*shard, *total_shards)),
101            PartitionerBuilder::Hash {
102                shard,
103                total_shards,
104            } => Box::new(HashPartitioner::new(*shard, *total_shards)),
105        }
106    }
107}
108
109impl FromStr for PartitionerBuilder {
110    type Err = PartitionerBuilderParseError;
111
112    fn from_str(s: &str) -> Result<Self, Self::Err> {
113        if let Some(input) = s.strip_prefix("hash:") {
114            let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
115
116            Ok(PartitionerBuilder::Hash {
117                shard,
118                total_shards,
119            })
120        } else if let Some(input) = s.strip_prefix("count:") {
121            let (shard, total_shards) = parse_shards(input, "count:M/N")?;
122
123            Ok(PartitionerBuilder::Count {
124                shard,
125                total_shards,
126            })
127        } else if let Some(input) = s.strip_prefix("slice:") {
128            let (shard, total_shards) = parse_shards(input, "slice:M/N")?;
129
130            Ok(PartitionerBuilder::Slice {
131                shard,
132                total_shards,
133            })
134        } else {
135            Err(PartitionerBuilderParseError::new(
136                None,
137                format!(
138                    "partition input '{s}' must begin with \"hash:\", \"count:\", or \"slice:\""
139                ),
140            ))
141        }
142    }
143}
144
145fn parse_shards(
146    input: &str,
147    expected_format: &'static str,
148) -> Result<(u64, u64), PartitionerBuilderParseError> {
149    let mut split = input.splitn(2, '/');
150    // First "next" always returns a value.
151    let shard_str = split.next().expect("split should have at least 1 element");
152    // Second "next" may or may not return a value.
153    let total_shards_str = split.next().ok_or_else(|| {
154        PartitionerBuilderParseError::new(
155            Some(expected_format),
156            format!("expected input '{input}' to be in the format M/N"),
157        )
158    })?;
159
160    let shard: u64 = shard_str.parse().map_err(|err| {
161        PartitionerBuilderParseError::new(
162            Some(expected_format),
163            format!("failed to parse shard '{shard_str}' as u64: {err}"),
164        )
165    })?;
166
167    let total_shards: u64 = total_shards_str.parse().map_err(|err| {
168        PartitionerBuilderParseError::new(
169            Some(expected_format),
170            format!("failed to parse total_shards '{total_shards_str}' as u64: {err}"),
171        )
172    })?;
173
174    // Check that shard > 0 and <= total_shards.
175    if !(1..=total_shards).contains(&shard) {
176        return Err(PartitionerBuilderParseError::new(
177            Some(expected_format),
178            format!(
179                "shard {shard} must be a number between 1 and total shards {total_shards}, inclusive"
180            ),
181        ));
182    }
183
184    Ok((shard, total_shards))
185}
186
/// Round-robin partitioner.
///
/// Within every run of `total_shards` consecutive `test_matches` calls, exactly one call (the
/// `shard`-th, counting from 1) is selected. The test name is ignored: each call advances an
/// internal counter, so the result depends only on how many calls came before it.
#[derive(Clone, Debug)]
struct CountPartitioner {
    /// The selected shard as a 0-based index.
    shard_minus_one: u64,
    /// Number of shards; `curr` cycles through `0..total_shards`.
    total_shards: u64,
    /// 0-based position of the next call within the current round.
    curr: u64,
}

impl CountPartitioner {
    /// Creates a partitioner for the 1-based `shard` out of `total_shards`.
    ///
    /// # Panics
    ///
    /// Panics if `shard` is not in `1..=total_shards`, which includes `total_shards == 0`.
    /// Parsed builders always satisfy this. However, `PartitionerBuilder` has public fields, and
    /// an invalid value would otherwise cause one of these problems:
    /// - `shard - 1` underflows.
    /// - The partitioner silently selects nothing.
    /// - It divides by zero on the first match.
    fn new(shard: u64, total_shards: u64) -> Self {
        assert!(
            (1..=total_shards).contains(&shard),
            "shard {shard} must be between 1 and total shards {total_shards}, inclusive"
        );
        Self {
            shard_minus_one: shard - 1,
            total_shards,
            curr: 0,
        }
    }
}
204
205impl Partitioner for CountPartitioner {
206    fn test_matches(&mut self, _test_name: &str) -> bool {
207        let matches = self.curr == self.shard_minus_one;
208        self.curr = (self.curr + 1) % self.total_shards;
209        matches
210    }
211}
212
/// Stateless partitioner that assigns each test to a shard by hashing its name.
///
/// A test is selected when the xxHash64 (seed 0) of its name, reduced modulo `total_shards`,
/// equals this shard's 0-based index. The result therefore depends only on the name, not on
/// call order.
#[derive(Clone, Debug)]
struct HashPartitioner {
    /// The selected shard as a 0-based index.
    shard_minus_one: u64,
    /// Number of shards; hashes are reduced modulo this value.
    total_shards: u64,
}

impl HashPartitioner {
    /// Creates a partitioner for the 1-based `shard` out of `total_shards`.
    ///
    /// # Panics
    ///
    /// Panics if `shard` is not in `1..=total_shards`, which includes `total_shards == 0`.
    /// Parsed builders always satisfy this. However, `PartitionerBuilder` has public fields, and
    /// an invalid value would otherwise cause one of these problems:
    /// - `shard - 1` underflows.
    /// - The partitioner silently selects nothing.
    /// - It divides by zero on the first match.
    fn new(shard: u64, total_shards: u64) -> Self {
        assert!(
            (1..=total_shards).contains(&shard),
            "shard {shard} must be between 1 and total shards {total_shards}, inclusive"
        );
        Self {
            shard_minus_one: shard - 1,
            total_shards,
        }
    }
}
228
229impl Partitioner for HashPartitioner {
230    fn test_matches(&mut self, test_name: &str) -> bool {
231        // NOTE: this is fixed to be xxhash64 for the entire cargo-nextest 0.9 series.
232        xxh64(test_name.as_bytes(), 0) % self.total_shards == self.shard_minus_one
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn partitioner_builder_scope() {
        assert_eq!(
            PartitionerBuilder::Count {
                shard: 1,
                total_shards: 2,
            }
            .scope(),
            PartitionerScope::PerBinary,
        );
        assert_eq!(
            PartitionerBuilder::Hash {
                shard: 1,
                total_shards: 2,
            }
            .scope(),
            PartitionerScope::PerBinary,
        );
        assert_eq!(
            PartitionerBuilder::Slice {
                shard: 1,
                total_shards: 3,
            }
            .scope(),
            PartitionerScope::CrossBinary,
        );
    }

    #[test]
    fn partitioner_builder_from_str() {
        let successes = vec![
            (
                "hash:1/2",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "hash:1/1",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "hash:99/200",
                PartitionerBuilder::Hash {
                    shard: 99,
                    total_shards: 200,
                },
            ),
            (
                "count:1/2",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "count:2/2",
                PartitionerBuilder::Count {
                    shard: 2,
                    total_shards: 2,
                },
            ),
            (
                "count:1/1",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "slice:1/3",
                PartitionerBuilder::Slice {
                    shard: 1,
                    total_shards: 3,
                },
            ),
            (
                "slice:3/3",
                PartitionerBuilder::Slice {
                    shard: 3,
                    total_shards: 3,
                },
            ),
            (
                "slice:1/1",
                PartitionerBuilder::Slice {
                    shard: 1,
                    total_shards: 1,
                },
            ),
        ];

        let failures = vec![
            "foo",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            "count",
            "count:",
            "count:1",
            "count:0/2",
            "count:3/2",
            "count:1/0",
            "slice:",
            "slice:0/2",
            "slice:4/3",
        ];

        for (input, output) in successes {
            assert_eq!(
                PartitionerBuilder::from_str(input).unwrap_or_else(|err| panic!(
                    "expected input '{input}' to succeed, failed with: {err}"
                )),
                output,
                "success case '{input}' matches",
            );
        }

        for input in failures {
            PartitionerBuilder::from_str(input)
                .expect_err(&format!("expected input '{input}' to fail"));
        }
    }

    #[test]
    fn count_and_slice_select_every_nth_test() {
        // Shard 2 of 3 selects the 2nd, 5th, 8th, ... test it is asked about.
        for builder in [
            PartitionerBuilder::Count {
                shard: 2,
                total_shards: 3,
            },
            PartitionerBuilder::Slice {
                shard: 2,
                total_shards: 3,
            },
        ] {
            let mut partitioner = builder.build();
            let selected: Vec<bool> = (0..7)
                .map(|i| partitioner.test_matches(&format!("test_{i}")))
                .collect();
            assert_eq!(
                selected,
                [false, true, false, false, true, false, false],
                "selection pattern for {builder:?}",
            );
        }
    }

    #[test]
    fn each_test_is_assigned_to_exactly_one_shard() {
        const TOTAL_SHARDS: u64 = 4;

        let names: Vec<String> = (0..50)
            .map(|i| format!("module_{}::test_{i}", i % 7))
            .collect();
        let makers: [fn(u64, u64) -> PartitionerBuilder; 3] = [
            |shard, total_shards| PartitionerBuilder::Count {
                shard,
                total_shards,
            },
            |shard, total_shards| PartitionerBuilder::Hash {
                shard,
                total_shards,
            },
            |shard, total_shards| PartitionerBuilder::Slice {
                shard,
                total_shards,
            },
        ];

        for make in makers {
            let description = make(1, TOTAL_SHARDS);
            // One partitioner per shard; each sees every test in the same order, as it would
            // on its own machine.
            let mut partitioners: Vec<Box<dyn Partitioner>> = (1..=TOTAL_SHARDS)
                .map(|shard| make(shard, TOTAL_SHARDS).build())
                .collect();
            for name in &names {
                let hits = partitioners
                    .iter_mut()
                    .map(|partitioner| partitioner.test_matches(name))
                    .filter(|&hit| hit)
                    .count();
                assert_eq!(
                    hits, 1,
                    "test '{name}' must match exactly one shard for {description:?}",
                );
            }
        }
    }
}