use crate::errors::PartitionerBuilderParseError;
use std::{fmt, str::FromStr};
use xxhash_rust::xxh64::xxh64;
/// Describes how the test set is split into shards, as parsed from a
/// `hash:M/N` or `count:M/N` string.
///
/// Call [`PartitionerBuilder::build`] to obtain a [`Partitioner`] that decides
/// which tests belong to shard `M`.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartitionerBuilder {
    /// Assign tests to shards in round-robin order of appearance (`count:M/N`).
    Count {
        /// The 1-based shard to select (`M`).
        shard: u64,
        /// The total number of shards (`N`).
        total_shards: u64,
    },
    /// Assign tests to shards by hashing their names (`hash:M/N`).
    Hash {
        /// The 1-based shard to select (`M`).
        shard: u64,
        /// The total number of shards (`N`).
        total_shards: u64,
    },
}
/// Decides, one test at a time, whether a test belongs to the selected shard.
///
/// `test_matches` takes `&mut self`, so implementations may be stateful: the
/// count-based partitioner in this module advances a counter on every call,
/// so its answer depends on how many tests it has already been asked about.
pub trait Partitioner: fmt::Debug {
    /// Returns `true` if the test named `test_name` belongs to this
    /// partitioner's shard.
    fn test_matches(&mut self, test_name: &str) -> bool;
}
impl PartitionerBuilder {
    /// Creates a new [`Partitioner`] for this configuration.
    ///
    /// Every call constructs an independent partitioner with fresh state, so a
    /// count-based partitioner always starts counting from the first test.
    /// The shard values are expected to satisfy `1 <= shard <= total_shards`,
    /// which the `FromStr` implementation guarantees for parsed builders.
    pub fn build(&self) -> Box<dyn Partitioner> {
        match *self {
            Self::Count { shard, total_shards } => {
                Box::new(CountPartitioner::new(shard, total_shards))
            }
            Self::Hash { shard, total_shards } => {
                Box::new(HashPartitioner::new(shard, total_shards))
            }
        }
    }
}
impl FromStr for PartitionerBuilder {
type Err = PartitionerBuilderParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if let Some(input) = s.strip_prefix("hash:") {
let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
Ok(PartitionerBuilder::Hash {
shard,
total_shards,
})
} else if let Some(input) = s.strip_prefix("count:") {
let (shard, total_shards) = parse_shards(input, "count:M/N")?;
Ok(PartitionerBuilder::Count {
shard,
total_shards,
})
} else {
Err(PartitionerBuilderParseError::new(
None,
format!("partition input '{s}' must begin with \"hash:\" or \"count:\""),
))
}
}
}
/// Parses the `M/N` portion of a partition specifier into `(shard, total_shards)`.
///
/// `expected_format` (for example `"hash:M/N"`) is attached to every error so
/// the user can see which syntax was expected.
///
/// # Errors
///
/// Returns an error if `input` contains no `/`, if either side fails to parse
/// as a `u64`, if `total_shards` is 0, or if `shard` is not in
/// `1..=total_shards`.
fn parse_shards(
    input: &str,
    expected_format: &'static str,
) -> Result<(u64, u64), PartitionerBuilderParseError> {
    // Split at the first '/'. Anything after it, including further '/'s, is
    // treated as the total and will then fail to parse as a `u64`.
    let (shard_str, total_shards_str) = input.split_once('/').ok_or_else(|| {
        PartitionerBuilderParseError::new(
            Some(expected_format),
            format!("expected input '{input}' to be in the format M/N"),
        )
    })?;
    let shard: u64 = shard_str.parse().map_err(|err| {
        PartitionerBuilderParseError::new(
            Some(expected_format),
            format!("failed to parse shard '{shard_str}' as u64: {err}"),
        )
    })?;
    let total_shards: u64 = total_shards_str.parse().map_err(|err| {
        PartitionerBuilderParseError::new(
            Some(expected_format),
            format!("failed to parse total_shards '{total_shards_str}' as u64: {err}"),
        )
    })?;
    // Report a zero total directly. Otherwise the range check below would
    // produce the confusing "between 1 and total shards 0" message.
    if total_shards == 0 {
        return Err(PartitionerBuilderParseError::new(
            Some(expected_format),
            "total shards must be at least 1".to_owned(),
        ));
    }
    if !(1..=total_shards).contains(&shard) {
        return Err(PartitionerBuilderParseError::new(
            Some(expected_format),
            format!(
                "shard {shard} must be a number between 1 and total shards {total_shards}, inclusive"
            ),
        ));
    }
    Ok((shard, total_shards))
}
/// Partitioner that assigns tests to shards in round-robin order.
///
/// The first test seen goes to shard 1, the second to shard 2, and so on,
/// wrapping back to shard 1 after `total_shards` tests. The test name is
/// ignored, so the resulting partition depends entirely on the order in which
/// tests are presented.
#[derive(Clone, Debug)]
struct CountPartitioner {
    /// Zero-based index of the shard this partitioner selects.
    shard_minus_one: u64,
    /// Total number of shards; at least 1.
    total_shards: u64,
    /// Zero-based shard index the next test will be assigned to; always less
    /// than `total_shards`.
    curr: u64,
}

impl CountPartitioner {
    /// Creates a partitioner selecting the 1-based `shard` out of `total_shards`.
    ///
    /// # Panics
    ///
    /// Panics if `shard` is not in `1..=total_shards`, which also rules out
    /// `total_shards == 0`. Without this check, `shard == 0` would underflow
    /// (wrapping in release builds so nothing ever matches), `shard >
    /// total_shards` would silently match nothing, and `total_shards == 0`
    /// would panic later with a remainder-by-zero in `test_matches`.
    fn new(shard: u64, total_shards: u64) -> Self {
        assert!(
            (1..=total_shards).contains(&shard),
            "shard {shard} must be between 1 and total shards {total_shards}, inclusive"
        );
        Self {
            shard_minus_one: shard - 1,
            total_shards,
            curr: 0,
        }
    }
}

impl Partitioner for CountPartitioner {
    fn test_matches(&mut self, _test_name: &str) -> bool {
        let matches = self.curr == self.shard_minus_one;
        // `curr < total_shards <= u64::MAX`, so `curr + 1` cannot overflow.
        self.curr = (self.curr + 1) % self.total_shards;
        matches
    }
}
/// Partitioner that assigns each test to a shard based on a hash of its name.
///
/// A test belongs to zero-based shard `xxh64(test_name, 0) % total_shards`, so
/// the assignment depends only on the test name and not on the order in which
/// tests are presented.
#[derive(Clone, Debug)]
struct HashPartitioner {
    /// Zero-based index of the shard this partitioner selects.
    shard_minus_one: u64,
    /// Total number of shards; at least 1.
    total_shards: u64,
}

impl HashPartitioner {
    /// Creates a partitioner selecting the 1-based `shard` out of `total_shards`.
    ///
    /// # Panics
    ///
    /// Panics if `shard` is not in `1..=total_shards`, which also rules out
    /// `total_shards == 0`. Without this check, `shard == 0` would underflow
    /// (wrapping in release builds so nothing ever matches), `shard >
    /// total_shards` would silently match nothing, and `total_shards == 0`
    /// would panic later with a remainder-by-zero in `test_matches`.
    fn new(shard: u64, total_shards: u64) -> Self {
        assert!(
            (1..=total_shards).contains(&shard),
            "shard {shard} must be between 1 and total shards {total_shards}, inclusive"
        );
        Self {
            shard_minus_one: shard - 1,
            total_shards,
        }
    }
}

impl Partitioner for HashPartitioner {
    fn test_matches(&mut self, test_name: &str) -> bool {
        xxh64(test_name.as_bytes(), 0) % self.total_shards == self.shard_minus_one
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `hash:` builder.
    fn hash(shard: u64, total_shards: u64) -> PartitionerBuilder {
        PartitionerBuilder::Hash {
            shard,
            total_shards,
        }
    }

    /// Shorthand for a `count:` builder.
    fn count(shard: u64, total_shards: u64) -> PartitionerBuilder {
        PartitionerBuilder::Count {
            shard,
            total_shards,
        }
    }

    #[test]
    fn partitioner_builder_from_str() {
        let successes = [
            ("hash:1/2", hash(1, 2)),
            ("hash:1/1", hash(1, 1)),
            ("hash:99/200", hash(99, 200)),
            ("count:1/2", count(1, 2)),
            ("count:1/1", count(1, 1)),
            ("count:99/200", count(99, 200)),
        ];
        let failures = [
            "foo",
            "",
            " hash:1/2",
            "HASH:1/2",
            "hash:1/2 ",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            "hash:1/0",
            "count",
            "count:",
            "count:1",
            "count:1/",
            "count:0/2",
            "count:3/2",
            "count:m/2",
            "count:1/n",
            "count:1/2/3",
            "count:1/0",
        ];
        for (input, output) in successes {
            assert_eq!(
                PartitionerBuilder::from_str(input).unwrap_or_else(|err| panic!(
                    "expected input '{input}' to succeed, failed with: {err}"
                )),
                output,
                "success case '{input}' matches",
            );
        }
        for input in failures {
            if let Ok(builder) = PartitionerBuilder::from_str(input) {
                panic!("expected input '{input}' to fail, but it parsed as {builder:?}");
            }
        }
    }

    #[test]
    fn count_partitioner_round_robin() {
        let mut partitioner = count(2, 3).build();
        let matches: Vec<bool> = (0..7)
            .map(|i| partitioner.test_matches(&format!("test_{i}")))
            .collect();
        assert_eq!(matches, [false, true, false, false, true, false, false]);
    }

    #[test]
    fn count_partitioner_single_shard_matches_everything() {
        let mut partitioner = count(1, 1).build();
        assert!((0..5).all(|_| partitioner.test_matches("any")));
    }

    #[test]
    fn hash_partitioner_assigns_each_test_to_exactly_one_shard() {
        let total_shards = 5;
        let mut partitioners: Vec<_> = (1..=total_shards)
            .map(|shard| hash(shard, total_shards).build())
            .collect();
        for name in ["", "a", "module::test", "foo::bar::baz", "tests::it_works"] {
            let hits = partitioners
                .iter_mut()
                .map(|p| p.test_matches(name))
                .filter(|&hit| hit)
                .count();
            assert_eq!(hits, 1, "test '{name}' should match exactly one shard");
        }
    }

    #[test]
    fn hash_partitioner_is_order_independent() {
        let mut partitioner = hash(1, 3).build();
        let first = partitioner.test_matches("some::test");
        partitioner.test_matches("other::test");
        assert_eq!(partitioner.test_matches("some::test"), first);
    }
}