1use crate::errors::PartitionerBuilderParseError;
11use std::{fmt, str::FromStr};
12use xxhash_rust::xxh64::xxh64;
13
/// A description of how tests should be split into shards.
///
/// A value of this type only records the chosen strategy and shard numbers; call
/// [`PartitionerBuilder::build`] to obtain a [`Partitioner`] that makes the actual per-test
/// decisions.
///
/// When produced through the `FromStr` implementation (inputs of the form `hash:M/N` or
/// `count:M/N`), `shard` is guaranteed to lie between 1 and `total_shards`, inclusive.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PartitionerBuilder {
    /// Partition by counting: tests are assigned to shards in a repeating cycle, in the order
    /// in which they are presented to the partitioner.
    Count {
        /// The shard this run is responsible for (1-based).
        shard: u64,

        /// The total number of shards.
        total_shards: u64,
    },

    /// Partition by hashing: a test belongs to the shard selected by the hash of its name.
    Hash {
        /// The shard this run is responsible for (1-based).
        shard: u64,

        /// The total number of shards.
        total_shards: u64,
    },
}
39
/// Decides, one test at a time, whether a test belongs to the current shard.
///
/// Implementors must also implement [`fmt::Debug`].
pub trait Partitioner
where
    Self: fmt::Debug,
{
    /// Returns `true` if the test named `test_name` is part of this shard.
    ///
    /// Takes `&mut self` because an implementation may update internal state on every call.
    fn test_matches(&mut self, test_name: &str) -> bool;
}
45
46impl PartitionerBuilder {
47 pub fn build(&self) -> Box<dyn Partitioner> {
49 match self {
51 PartitionerBuilder::Count {
52 shard,
53 total_shards,
54 } => Box::new(CountPartitioner::new(*shard, *total_shards)),
55 PartitionerBuilder::Hash {
56 shard,
57 total_shards,
58 } => Box::new(HashPartitioner::new(*shard, *total_shards)),
59 }
60 }
61}
62
63impl FromStr for PartitionerBuilder {
64 type Err = PartitionerBuilderParseError;
65
66 fn from_str(s: &str) -> Result<Self, Self::Err> {
67 if let Some(input) = s.strip_prefix("hash:") {
69 let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
70
71 Ok(PartitionerBuilder::Hash {
72 shard,
73 total_shards,
74 })
75 } else if let Some(input) = s.strip_prefix("count:") {
76 let (shard, total_shards) = parse_shards(input, "count:M/N")?;
77
78 Ok(PartitionerBuilder::Count {
79 shard,
80 total_shards,
81 })
82 } else {
83 Err(PartitionerBuilderParseError::new(
84 None,
85 format!("partition input '{s}' must begin with \"hash:\" or \"count:\""),
86 ))
87 }
88 }
89}
90
91fn parse_shards(
92 input: &str,
93 expected_format: &'static str,
94) -> Result<(u64, u64), PartitionerBuilderParseError> {
95 let mut split = input.splitn(2, '/');
96 let shard_str = split.next().expect("split should have at least 1 element");
98 let total_shards_str = split.next().ok_or_else(|| {
100 PartitionerBuilderParseError::new(
101 Some(expected_format),
102 format!("expected input '{input}' to be in the format M/N"),
103 )
104 })?;
105
106 let shard: u64 = shard_str.parse().map_err(|err| {
107 PartitionerBuilderParseError::new(
108 Some(expected_format),
109 format!("failed to parse shard '{shard_str}' as u64: {err}"),
110 )
111 })?;
112
113 let total_shards: u64 = total_shards_str.parse().map_err(|err| {
114 PartitionerBuilderParseError::new(
115 Some(expected_format),
116 format!("failed to parse total_shards '{total_shards_str}' as u64: {err}"),
117 )
118 })?;
119
120 if !(1..=total_shards).contains(&shard) {
122 return Err(PartitionerBuilderParseError::new(
123 Some(expected_format),
124 format!(
125 "shard {shard} must be a number between 1 and total shards {total_shards}, inclusive"
126 ),
127 ));
128 }
129
130 Ok((shard, total_shards))
131}
132
/// Partitioner state for the counting strategy.
#[derive(Debug, Clone)]
struct CountPartitioner {
    /// The selected shard, converted from 1-based to 0-based.
    shard_minus_one: u64,
    /// The total number of shards.
    total_shards: u64,
    /// Running counter; starts at 0.
    curr: u64,
}

impl CountPartitioner {
    /// Creates a counting partitioner for the 1-based `shard` out of `total_shards`, with
    /// the counter starting at zero.
    ///
    /// `shard` is expected to be at least 1, since 1 is subtracted from it here.
    fn new(shard: u64, total_shards: u64) -> Self {
        Self {
            shard_minus_one: shard - 1,
            total_shards,
            curr: 0,
        }
    }
}
150
151impl Partitioner for CountPartitioner {
152 fn test_matches(&mut self, _test_name: &str) -> bool {
153 let matches = self.curr == self.shard_minus_one;
154 self.curr = (self.curr + 1) % self.total_shards;
155 matches
156 }
157}
158
/// Partitioner state for the hashing strategy.
#[derive(Debug, Clone)]
struct HashPartitioner {
    /// The selected shard, converted from 1-based to 0-based.
    shard_minus_one: u64,
    /// The total number of shards.
    total_shards: u64,
}

impl HashPartitioner {
    /// Creates a hashing partitioner for the 1-based `shard` out of `total_shards`.
    ///
    /// `shard` is expected to be at least 1, since 1 is subtracted from it here.
    fn new(shard: u64, total_shards: u64) -> Self {
        Self {
            shard_minus_one: shard - 1,
            total_shards,
        }
    }
}
174
175impl Partitioner for HashPartitioner {
176 fn test_matches(&mut self, test_name: &str) -> bool {
177 xxh64(test_name.as_bytes(), 0) % self.total_shards == self.shard_minus_one
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `PartitionerBuilder::from_str` accepts well-formed `hash:M/N` and
    /// `count:M/N` inputs and rejects malformed ones, for both supported prefixes.
    #[test]
    fn partitioner_builder_from_str() {
        let successes = vec![
            (
                "hash:1/2",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "hash:1/1",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "hash:99/200",
                PartitionerBuilder::Hash {
                    shard: 99,
                    total_shards: 200,
                },
            ),
            // The "count:" prefix goes through a separate branch of `from_str`, so it gets
            // the same coverage as "hash:".
            (
                "count:1/2",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "count:1/1",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "count:99/200",
                PartitionerBuilder::Count {
                    shard: 99,
                    total_shards: 200,
                },
            ),
        ];

        let failures = vec![
            "foo",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            "count",
            "count:",
            "count:1",
            "count:1/",
            "count:0/2",
            "count:3/2",
            "count:m/2",
            "count:1/n",
            "count:1/2/3",
        ];

        for (input, output) in successes {
            assert_eq!(
                PartitionerBuilder::from_str(input).unwrap_or_else(|err| panic!(
                    "expected input '{input}' to succeed, failed with: {err}"
                )),
                output,
                "success case '{input}' matches",
            );
        }

        for input in failures {
            // Build the failure message only when the case actually fails, rather than
            // formatting it eagerly on every iteration.
            if let Ok(builder) = PartitionerBuilder::from_str(input) {
                panic!("expected input '{input}' to fail, but it parsed as {builder:?}");
            }
        }
    }
}