nextest_runner/
partition.rs

1// Copyright (c) The nextest Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Support for partitioning test runs across several machines.
5//!
6//! At the moment this only supports simple hash-based and count-based sharding. In the future it
7//! could potentially be made smarter: e.g. using data to pick different sets of binaries and tests
8//! to run, with an aim to minimize total build and test times.
9
10use crate::errors::PartitionerBuilderParseError;
11use std::{fmt, str::FromStr};
12use xxhash_rust::xxh64::xxh64;
13
/// Describes how to select a subset ("shard") of tests so that a test run can be spread across
/// several machines, and builds the [`Partitioner`] that performs the selection.
///
/// `PartitionerBuilder` relates to [`Partitioner`] much like `std`'s `BuildHasher` relates to
/// `Hasher`: the builder is a small, cloneable description, and every call to `build` produces a
/// fresh partitioner.
///
/// Builders are usually obtained by parsing a `hash:M/N` or `count:M/N` string via `FromStr`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PartitionerBuilder {
    /// Assign tests to shards round-robin, in the order they are checked.
    ///
    /// The resulting partitioner is stateful: whether a test matches depends on how many tests
    /// were checked before it.
    Count {
        /// The 1-based index of the shard to select.
        shard: u64,

        /// How many shards the tests are divided into.
        total_shards: u64,
    },

    /// Assign tests to shards based on a hash of the test name.
    ///
    /// The resulting partitioner is stateless: a given test name always lands in the same shard.
    Hash {
        /// The 1-based index of the shard to select.
        shard: u64,

        /// How many shards the tests are divided into.
        total_shards: u64,
    },
}
39
/// An individual partitioner, typically scoped to a single test binary.
///
/// `test_matches` takes `&mut self` so that implementations may keep state between calls; for
/// such implementations the order in which tests are checked can affect the result.
pub trait Partitioner
where
    Self: fmt::Debug,
{
    /// Returns `true` if the test named `test_name` belongs to this partition.
    fn test_matches(&mut self, test_name: &str) -> bool;
}
45
46impl PartitionerBuilder {
47    /// Creates a new `Partitioner` from this `PartitionerBuilder`.
48    pub fn build(&self) -> Box<dyn Partitioner> {
49        // Note we don't use test_binary at the moment but might in the future.
50        match self {
51            PartitionerBuilder::Count {
52                shard,
53                total_shards,
54            } => Box::new(CountPartitioner::new(*shard, *total_shards)),
55            PartitionerBuilder::Hash {
56                shard,
57                total_shards,
58            } => Box::new(HashPartitioner::new(*shard, *total_shards)),
59        }
60    }
61}
62
63impl FromStr for PartitionerBuilder {
64    type Err = PartitionerBuilderParseError;
65
66    fn from_str(s: &str) -> Result<Self, Self::Err> {
67        // Parse the string: it looks like "hash:<shard>/<total_shards>".
68        if let Some(input) = s.strip_prefix("hash:") {
69            let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
70
71            Ok(PartitionerBuilder::Hash {
72                shard,
73                total_shards,
74            })
75        } else if let Some(input) = s.strip_prefix("count:") {
76            let (shard, total_shards) = parse_shards(input, "count:M/N")?;
77
78            Ok(PartitionerBuilder::Count {
79                shard,
80                total_shards,
81            })
82        } else {
83            Err(PartitionerBuilderParseError::new(
84                None,
85                format!("partition input '{s}' must begin with \"hash:\" or \"count:\""),
86            ))
87        }
88    }
89}
90
91fn parse_shards(
92    input: &str,
93    expected_format: &'static str,
94) -> Result<(u64, u64), PartitionerBuilderParseError> {
95    let mut split = input.splitn(2, '/');
96    // First "next" always returns a value.
97    let shard_str = split.next().expect("split should have at least 1 element");
98    // Second "next" may or may not return a value.
99    let total_shards_str = split.next().ok_or_else(|| {
100        PartitionerBuilderParseError::new(
101            Some(expected_format),
102            format!("expected input '{input}' to be in the format M/N"),
103        )
104    })?;
105
106    let shard: u64 = shard_str.parse().map_err(|err| {
107        PartitionerBuilderParseError::new(
108            Some(expected_format),
109            format!("failed to parse shard '{shard_str}' as u64: {err}"),
110        )
111    })?;
112
113    let total_shards: u64 = total_shards_str.parse().map_err(|err| {
114        PartitionerBuilderParseError::new(
115            Some(expected_format),
116            format!("failed to parse total_shards '{total_shards_str}' as u64: {err}"),
117        )
118    })?;
119
120    // Check that shard > 0 and <= total_shards.
121    if !(1..=total_shards).contains(&shard) {
122        return Err(PartitionerBuilderParseError::new(
123            Some(expected_format),
124            format!(
125                "shard {shard} must be a number between 1 and total shards {total_shards}, inclusive"
126            ),
127        ));
128    }
129
130    Ok((shard, total_shards))
131}
132
133#[derive(Clone, Debug)]
134struct CountPartitioner {
135    shard_minus_one: u64,
136    total_shards: u64,
137    curr: u64,
138}
139
140impl CountPartitioner {
141    fn new(shard: u64, total_shards: u64) -> Self {
142        let shard_minus_one = shard - 1;
143        Self {
144            shard_minus_one,
145            total_shards,
146            curr: 0,
147        }
148    }
149}
150
151impl Partitioner for CountPartitioner {
152    fn test_matches(&mut self, _test_name: &str) -> bool {
153        let matches = self.curr == self.shard_minus_one;
154        self.curr = (self.curr + 1) % self.total_shards;
155        matches
156    }
157}
158
159#[derive(Clone, Debug)]
160struct HashPartitioner {
161    shard_minus_one: u64,
162    total_shards: u64,
163}
164
165impl HashPartitioner {
166    fn new(shard: u64, total_shards: u64) -> Self {
167        let shard_minus_one = shard - 1;
168        Self {
169            shard_minus_one,
170            total_shards,
171        }
172    }
173}
174
175impl Partitioner for HashPartitioner {
176    fn test_matches(&mut self, test_name: &str) -> bool {
177        // NOTE: this is fixed to be xxhash64 for the entire cargo-nextest 0.9 series.
178        xxh64(test_name.as_bytes(), 0) % self.total_shards == self.shard_minus_one
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn partitioner_builder_from_str() {
        let successes = vec![
            (
                "hash:1/2",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "hash:1/1",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "hash:99/200",
                PartitionerBuilder::Hash {
                    shard: 99,
                    total_shards: 200,
                },
            ),
            (
                "count:1/2",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "count:2/2",
                PartitionerBuilder::Count {
                    shard: 2,
                    total_shards: 2,
                },
            ),
            (
                "count:1/1",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "count:99/200",
                PartitionerBuilder::Count {
                    shard: 99,
                    total_shards: 200,
                },
            ),
        ];

        let failures = vec![
            "foo",
            "HASH:1/2",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:1/0",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            "hash: 1/2",
            "count",
            "count:",
            "count:1",
            "count:1/",
            "count:0/2",
            "count:3/2",
            "count:1/0",
            "count:m/2",
            "count:1/n",
            "count:1/2/3",
        ];

        for (input, output) in successes {
            assert_eq!(
                PartitionerBuilder::from_str(input).unwrap_or_else(|err| panic!(
                    "expected input '{input}' to succeed, failed with: {err}"
                )),
                output,
                "success case '{input}' matches",
            );
        }

        for input in failures {
            if let Ok(builder) = PartitionerBuilder::from_str(input) {
                panic!("expected input '{input}' to fail, but it parsed as {builder:?}");
            }
        }
    }

    #[test]
    fn count_partitioner_is_round_robin() {
        let mut partitioner = PartitionerBuilder::Count {
            shard: 2,
            total_shards: 3,
        }
        .build();
        let matched: Vec<usize> = (0..9)
            .filter(|i| partitioner.test_matches(&format!("test_{i}")))
            .collect();
        assert_eq!(matched, vec![1, 4, 7]);
    }

    #[test]
    fn every_test_lands_in_exactly_one_shard() {
        let test_names: Vec<String> = (0..50).map(|i| format!("module::test_{i}")).collect();
        for kind in ["count", "hash"] {
            for total_shards in 1..=5u64 {
                let mut partitioners: Vec<Box<dyn Partitioner>> = (1..=total_shards)
                    .map(|shard| {
                        PartitionerBuilder::from_str(&format!("{kind}:{shard}/{total_shards}"))
                            .expect("partition spec should be valid")
                            .build()
                    })
                    .collect();
                for name in &test_names {
                    let hits = partitioners
                        .iter_mut()
                        .map(|partitioner| partitioner.test_matches(name))
                        .filter(|&matched| matched)
                        .count();
                    assert_eq!(
                        hits, 1,
                        "{kind}:?/{total_shards}: test '{name}' matched {hits} shards"
                    );
                }
            }
        }
    }
}