1use crate::errors::PartitionerBuilderParseError;
15use std::{fmt, str::FromStr};
16use xxhash_rust::xxh64::xxh64;
17
/// A description of how tests should be partitioned into shards.
///
/// A value can be parsed from a string of the form `hash:M/N`, `count:M/N` or
/// `slice:M/N` through the [`FromStr`] implementation, and is turned into a
/// working [`Partitioner`] by [`PartitionerBuilder::build`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PartitionerBuilder {
    /// Partition by counting tests in the order they are presented.
    ///
    /// Reported by [`PartitionerBuilder::scope`] as [`PartitionerScope::PerBinary`].
    Count {
        /// The selected shard, counted from 1 (the parser rejects 0).
        shard: u64,

        /// The total number of shards; the parser requires `shard <= total_shards`.
        total_shards: u64,
    },

    /// Partition by hashing each test name.
    ///
    /// Reported by [`PartitionerBuilder::scope`] as [`PartitionerScope::PerBinary`].
    Hash {
        /// The selected shard, counted from 1 (the parser rejects 0).
        shard: u64,

        /// The total number of shards; the parser requires `shard <= total_shards`.
        total_shards: u64,
    },

    /// Partition by counting tests, like [`PartitionerBuilder::Count`], but reported by
    /// [`PartitionerBuilder::scope`] as [`PartitionerScope::CrossBinary`].
    ///
    /// `build` produces the same counting partitioner as for `Count`.
    Slice {
        /// The selected shard, counted from 1 (the parser rejects 0).
        shard: u64,

        /// The total number of shards; the parser requires `shard <= total_shards`.
        total_shards: u64,
    },
}
55
/// The scope a partitioning strategy operates at, as reported by
/// [`PartitionerBuilder::scope`].
///
/// NOTE(review): the code that consumes this value is not part of this file; the
/// variant descriptions below are inferred from the names — confirm against the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PartitionerScope {
    /// Returned for the count and hash strategies. Presumably the partitioner is
    /// applied to each test binary independently.
    PerBinary,

    /// Returned for the slice strategy. Presumably a single partitioner is applied
    /// across the tests of all binaries.
    CrossBinary,
}
65
/// Decides, one test at a time, whether a test belongs to the selected shard.
///
/// Instances are created by [`PartitionerBuilder::build`].
pub trait Partitioner: fmt::Debug {
    /// Returns `true` if the test named `test_name` belongs to the selected shard.
    ///
    /// Takes `&mut self` because implementations may be stateful: the count-based
    /// partitioner in this file advances an internal counter on every call, so the
    /// result can depend on call order rather than on the name.
    fn test_matches(&mut self, test_name: &str) -> bool;
}
71
72impl PartitionerBuilder {
73 pub fn scope(&self) -> PartitionerScope {
75 match self {
76 PartitionerBuilder::Count { .. } => {
77 PartitionerScope::PerBinary
80 }
81 PartitionerBuilder::Hash { .. } => {
82 PartitionerScope::PerBinary
85 }
86 PartitionerBuilder::Slice { .. } => PartitionerScope::CrossBinary,
87 }
88 }
89
90 pub fn build(&self) -> Box<dyn Partitioner> {
92 match self {
93 PartitionerBuilder::Count {
94 shard,
95 total_shards,
96 }
97 | PartitionerBuilder::Slice {
98 shard,
99 total_shards,
100 } => Box::new(CountPartitioner::new(*shard, *total_shards)),
101 PartitionerBuilder::Hash {
102 shard,
103 total_shards,
104 } => Box::new(HashPartitioner::new(*shard, *total_shards)),
105 }
106 }
107}
108
109impl FromStr for PartitionerBuilder {
110 type Err = PartitionerBuilderParseError;
111
112 fn from_str(s: &str) -> Result<Self, Self::Err> {
113 if let Some(input) = s.strip_prefix("hash:") {
114 let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
115
116 Ok(PartitionerBuilder::Hash {
117 shard,
118 total_shards,
119 })
120 } else if let Some(input) = s.strip_prefix("count:") {
121 let (shard, total_shards) = parse_shards(input, "count:M/N")?;
122
123 Ok(PartitionerBuilder::Count {
124 shard,
125 total_shards,
126 })
127 } else if let Some(input) = s.strip_prefix("slice:") {
128 let (shard, total_shards) = parse_shards(input, "slice:M/N")?;
129
130 Ok(PartitionerBuilder::Slice {
131 shard,
132 total_shards,
133 })
134 } else {
135 Err(PartitionerBuilderParseError::new(
136 None,
137 format!(
138 "partition input '{s}' must begin with \"hash:\", \"count:\", or \"slice:\""
139 ),
140 ))
141 }
142 }
143}
144
145fn parse_shards(
146 input: &str,
147 expected_format: &'static str,
148) -> Result<(u64, u64), PartitionerBuilderParseError> {
149 let mut split = input.splitn(2, '/');
150 let shard_str = split.next().expect("split should have at least 1 element");
152 let total_shards_str = split.next().ok_or_else(|| {
154 PartitionerBuilderParseError::new(
155 Some(expected_format),
156 format!("expected input '{input}' to be in the format M/N"),
157 )
158 })?;
159
160 let shard: u64 = shard_str.parse().map_err(|err| {
161 PartitionerBuilderParseError::new(
162 Some(expected_format),
163 format!("failed to parse shard '{shard_str}' as u64: {err}"),
164 )
165 })?;
166
167 let total_shards: u64 = total_shards_str.parse().map_err(|err| {
168 PartitionerBuilderParseError::new(
169 Some(expected_format),
170 format!("failed to parse total_shards '{total_shards_str}' as u64: {err}"),
171 )
172 })?;
173
174 if !(1..=total_shards).contains(&shard) {
176 return Err(PartitionerBuilderParseError::new(
177 Some(expected_format),
178 format!(
179 "shard {shard} must be a number between 1 and total shards {total_shards}, inclusive"
180 ),
181 ));
182 }
183
184 Ok((shard, total_shards))
185}
186
/// Partitioner that selects tests by their position in the sequence of
/// `test_matches` calls, ignoring test names.
#[derive(Clone, Debug)]
struct CountPartitioner {
    /// Zero-based index of the selected shard (the 1-based shard number minus one).
    shard_minus_one: u64,
    /// Total number of shards; the counter wraps around at this value.
    total_shards: u64,
    /// Position counter: starts at 0 and is advanced modulo `total_shards` on every
    /// `test_matches` call.
    curr: u64,
}
193
194impl CountPartitioner {
195 fn new(shard: u64, total_shards: u64) -> Self {
196 let shard_minus_one = shard - 1;
197 Self {
198 shard_minus_one,
199 total_shards,
200 curr: 0,
201 }
202 }
203}
204
205impl Partitioner for CountPartitioner {
206 fn test_matches(&mut self, _test_name: &str) -> bool {
207 let matches = self.curr == self.shard_minus_one;
208 self.curr = (self.curr + 1) % self.total_shards;
209 matches
210 }
211}
212
/// Partitioner that selects tests by hashing their names, so the outcome for a
/// given name does not depend on call order.
#[derive(Clone, Debug)]
struct HashPartitioner {
    /// Zero-based index of the selected shard (the 1-based shard number minus one).
    shard_minus_one: u64,
    /// Total number of shards; name hashes are reduced modulo this value.
    total_shards: u64,
}
218
219impl HashPartitioner {
220 fn new(shard: u64, total_shards: u64) -> Self {
221 let shard_minus_one = shard - 1;
222 Self {
223 shard_minus_one,
224 total_shards,
225 }
226 }
227}
228
229impl Partitioner for HashPartitioner {
230 fn test_matches(&mut self, test_name: &str) -> bool {
231 xxh64(test_name.as_bytes(), 0) % self.total_shards == self.shard_minus_one
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Count and hash strategies are per-binary; the slice strategy is cross-binary.
    #[test]
    fn partitioner_builder_scope() {
        assert_eq!(
            PartitionerBuilder::Count {
                shard: 1,
                total_shards: 2,
            }
            .scope(),
            PartitionerScope::PerBinary,
        );
        assert_eq!(
            PartitionerBuilder::Hash {
                shard: 1,
                total_shards: 2,
            }
            .scope(),
            PartitionerScope::PerBinary,
        );
        assert_eq!(
            PartitionerBuilder::Slice {
                shard: 1,
                total_shards: 3,
            }
            .scope(),
            PartitionerScope::CrossBinary,
        );
    }

    /// Every supported prefix parses valid `M/N` input, and malformed or
    /// out-of-range input is rejected.
    #[test]
    fn partitioner_builder_from_str() {
        let successes = [
            (
                "hash:1/2",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "hash:1/1",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "hash:99/200",
                PartitionerBuilder::Hash {
                    shard: 99,
                    total_shards: 200,
                },
            ),
            (
                "count:1/2",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "count:2/2",
                PartitionerBuilder::Count {
                    shard: 2,
                    total_shards: 2,
                },
            ),
            (
                "count:1/1",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "slice:1/3",
                PartitionerBuilder::Slice {
                    shard: 1,
                    total_shards: 3,
                },
            ),
            (
                "slice:3/3",
                PartitionerBuilder::Slice {
                    shard: 3,
                    total_shards: 3,
                },
            ),
            (
                "slice:1/1",
                PartitionerBuilder::Slice {
                    shard: 1,
                    total_shards: 1,
                },
            ),
        ];

        let failures = [
            "foo",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:1/0",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            "hash:-1/2",
            "hash: 1/2",
            "HASH:1/2",
            "count",
            "count:",
            "count:1",
            "count:0/2",
            "count:3/2",
            "slice:",
            "slice:1",
            "slice:0/2",
            "slice:4/3",
        ];

        for (input, output) in successes {
            assert_eq!(
                PartitionerBuilder::from_str(input).unwrap_or_else(|err| panic!(
                    "expected input '{input}' to succeed, failed with: {err}"
                )),
                output,
                "success case '{input}' matches",
            );
        }

        for input in failures {
            // Only build the failure message when the case actually misbehaves.
            if let Ok(builder) = PartitionerBuilder::from_str(input) {
                panic!("expected input '{input}' to fail, but it parsed as {builder:?}");
            }
        }
    }

    /// The counting partitioner selects every `total_shards`-th call, starting at
    /// the position of the chosen shard, regardless of the test name.
    #[test]
    fn count_partitioner_cycles_through_shards() {
        let mut partitioner = PartitionerBuilder::Count {
            shard: 2,
            total_shards: 3,
        }
        .build();

        let matches: Vec<bool> = (0..6).map(|_| partitioner.test_matches("test")).collect();
        assert_eq!(matches, [false, true, false, false, true, false]);
    }

    /// The hashing partitioner assigns each test name to exactly one shard, and
    /// gives the same answer when asked again.
    #[test]
    fn hash_partitioner_assigns_each_test_to_exactly_one_shard() {
        let total_shards = 4;

        for test_name in ["", "foo", "bar::baz", "a::very::long::test::name"] {
            let matching_shards = (1..=total_shards)
                .filter(|&shard| {
                    let mut partitioner = PartitionerBuilder::Hash {
                        shard,
                        total_shards,
                    }
                    .build();
                    let first = partitioner.test_matches(test_name);
                    let second = partitioner.test_matches(test_name);
                    assert_eq!(first, second, "hash result for '{test_name}' is stable");
                    first
                })
                .count();

            assert_eq!(
                matching_shards, 1,
                "test '{test_name}' belongs to exactly one shard",
            );
        }
    }
}