1use super::{NextestConfig, ToolConfigFile};
7use crate::errors::{ConfigParseError, ConfigParseErrorKind};
8use camino::Utf8Path;
9use semver::Version;
10use serde::{Deserialize, Deserializer};
11use std::{borrow::Cow, collections::BTreeSet, fmt, str::FromStr};
12
/// A minimal view of nextest configuration that reads only the `nextest-version` and
/// `experimental` settings.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct VersionOnlyConfig {
    /// The `nextest-version` requirements, accumulated across tool config files and
    /// the repository config file.
    nextest_version: NextestVersionConfig,

    /// Experimental features enabled by the repository config file. Tool config files
    /// are not allowed to enable experimental features.
    experimental: BTreeSet<ConfigExperimental>,
}
25
26impl VersionOnlyConfig {
27 pub fn from_sources<'a, I>(
31 workspace_root: &Utf8Path,
32 config_file: Option<&Utf8Path>,
33 tool_config_files: impl IntoIterator<IntoIter = I>,
34 ) -> Result<Self, ConfigParseError>
35 where
36 I: Iterator<Item = &'a ToolConfigFile> + DoubleEndedIterator,
37 {
38 let tool_config_files_rev = tool_config_files.into_iter().rev();
39
40 Self::read_from_sources(workspace_root, config_file, tool_config_files_rev)
41 }
42
43 pub fn nextest_version(&self) -> &NextestVersionConfig {
45 &self.nextest_version
46 }
47
48 pub fn experimental(&self) -> &BTreeSet<ConfigExperimental> {
50 &self.experimental
51 }
52
53 fn read_from_sources<'a>(
54 workspace_root: &Utf8Path,
55 config_file: Option<&Utf8Path>,
56 tool_config_files_rev: impl Iterator<Item = &'a ToolConfigFile>,
57 ) -> Result<Self, ConfigParseError> {
58 let mut nextest_version = NextestVersionConfig::default();
59 let mut experimental = BTreeSet::new();
60
61 for ToolConfigFile { config_file, tool } in tool_config_files_rev {
63 if let Some(v) = Self::read_and_deserialize(config_file, Some(tool))?.nextest_version {
64 nextest_version.accumulate(v, Some(tool));
65 }
66 }
67
68 let config_file = match config_file {
70 Some(file) => Some(Cow::Borrowed(file)),
71 None => {
72 let config_file = workspace_root.join(NextestConfig::CONFIG_PATH);
73 config_file.exists().then_some(Cow::Owned(config_file))
74 }
75 };
76 if let Some(config_file) = config_file {
77 let d = Self::read_and_deserialize(&config_file, None)?;
78 if let Some(v) = d.nextest_version {
79 nextest_version.accumulate(v, None);
80 }
81
82 let unknown: BTreeSet<_> = d
84 .experimental
85 .into_iter()
86 .filter(|feature| {
87 if let Ok(feature) = feature.parse::<ConfigExperimental>() {
88 experimental.insert(feature);
89 false
90 } else {
91 true
92 }
93 })
94 .collect();
95 if !unknown.is_empty() {
96 let known = ConfigExperimental::known().collect();
97 return Err(ConfigParseError::new(
98 config_file.into_owned(),
99 None,
100 ConfigParseErrorKind::UnknownExperimentalFeatures { unknown, known },
101 ));
102 }
103 }
104
105 Ok(Self {
106 nextest_version,
107 experimental,
108 })
109 }
110
111 fn read_and_deserialize(
112 config_file: &Utf8Path,
113 tool: Option<&str>,
114 ) -> Result<VersionOnlyDeserialize, ConfigParseError> {
115 let toml_str = std::fs::read_to_string(config_file.as_str()).map_err(|error| {
116 ConfigParseError::new(
117 config_file,
118 tool,
119 ConfigParseErrorKind::VersionOnlyReadError(error),
120 )
121 })?;
122 let toml_de = toml::de::Deserializer::new(&toml_str);
123 let v: VersionOnlyDeserialize =
124 serde_path_to_error::deserialize(toml_de).map_err(|error| {
125 ConfigParseError::new(
126 config_file,
127 tool,
128 ConfigParseErrorKind::VersionOnlyDeserializeError(Box::new(error)),
129 )
130 })?;
131 if tool.is_some() && !v.experimental.is_empty() {
132 return Err(ConfigParseError::new(
133 config_file,
134 tool,
135 ConfigParseErrorKind::ExperimentalFeaturesInToolConfig {
136 features: v.experimental,
137 },
138 ));
139 }
140
141 Ok(v)
142 }
143}
144
/// The raw form of a config file as read by [`VersionOnlyConfig`].
///
/// Only these two keys are deserialized. Other keys in the file are not represented here.
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "kebab-case")]
struct VersionOnlyDeserialize {
    /// The `nextest-version` key, if present.
    #[serde(default)]
    nextest_version: Option<NextestVersionDeserialize>,
    /// Names listed under the `experimental` key. They are kept as strings so that
    /// unrecognized names can be reported, rather than failing deserialization outright.
    #[serde(default)]
    experimental: BTreeSet<String>,
}
154
/// Nextest version requirements: a hard minimum and a softer recommendation.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct NextestVersionConfig {
    /// The minimum required version.
    ///
    /// If the current version is lower, [`NextestVersionConfig::eval`] reports an error
    /// (or an overridden error).
    pub required: NextestVersionReq,

    /// The recommended version.
    ///
    /// If the required version is satisfied but the current version is lower than this,
    /// [`NextestVersionConfig::eval`] reports a warning (or an overridden warning).
    pub recommended: NextestVersionReq,
}
171
172impl NextestVersionConfig {
173 pub(crate) fn accumulate(&mut self, v: NextestVersionDeserialize, v_tool: Option<&str>) {
175 if let Some(v) = v.required {
176 self.required.accumulate(v, v_tool);
177 }
178 if let Some(v) = v.recommended {
179 self.recommended.accumulate(v, v_tool);
180 }
181 }
182
183 pub fn eval(
185 &self,
186 current_version: &Version,
187 override_version_check: bool,
188 ) -> NextestVersionEval {
189 match self.required.satisfies(current_version) {
190 Ok(()) => {}
191 Err((required, tool)) => {
192 if override_version_check {
193 return NextestVersionEval::ErrorOverride {
194 required: required.clone(),
195 current: current_version.clone(),
196 tool: tool.map(|s| s.to_owned()),
197 };
198 } else {
199 return NextestVersionEval::Error {
200 required: required.clone(),
201 current: current_version.clone(),
202 tool: tool.map(|s| s.to_owned()),
203 };
204 }
205 }
206 }
207
208 match self.recommended.satisfies(current_version) {
209 Ok(()) => NextestVersionEval::Satisfied,
210 Err((recommended, tool)) => {
211 if override_version_check {
212 NextestVersionEval::WarnOverride {
213 recommended: recommended.clone(),
214 current: current_version.clone(),
215 tool: tool.map(|s| s.to_owned()),
216 }
217 } else {
218 NextestVersionEval::Warn {
219 recommended: recommended.clone(),
220 current: current_version.clone(),
221 tool: tool.map(|s| s.to_owned()),
222 }
223 }
224 }
225 }
226 }
227}
228
/// An experimental feature that can be enabled through the `experimental` key of the
/// repository config file.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[non_exhaustive]
pub enum ConfigExperimental {
    /// Setup scripts, spelled `setup-scripts` in config.
    SetupScripts,
}

impl ConfigExperimental {
    /// Returns an iterator over every known experimental feature.
    ///
    /// Used to list the valid names when a config mentions an unknown feature.
    fn known() -> impl Iterator<Item = Self> {
        // Iterate a fixed-size array by value instead of allocating a `Vec` only to
        // iterate over it. The fully-qualified call guarantees by-value iteration
        // regardless of the crate's edition.
        IntoIterator::into_iter([Self::SetupScripts])
    }
}

impl FromStr for ConfigExperimental {
    type Err = ();

    /// Parses a feature from the name used in config files. Unknown names yield `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "setup-scripts" => Ok(Self::SetupScripts),
            _ => Err(()),
        }
    }
}

impl fmt::Display for ConfigExperimental {
    /// Writes the name used for this feature in config files; this round-trips through
    /// [`FromStr`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SetupScripts => f.write_str("setup-scripts"),
        }
    }
}
261
/// A single version requirement, along with the tool that set it.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum NextestVersionReq {
    /// A minimum version is specified.
    Version {
        /// The minimum version. Any version greater than or equal to this satisfies the
        /// requirement.
        version: Version,

        /// The tool whose config file specified this version. It is `None` when the
        /// requirement was recorded without a tool; the repository config passes `None`.
        tool: Option<String>,
    },

    /// No requirement was specified. This is always satisfied.
    #[default]
    None,
}
278
279impl NextestVersionReq {
280 fn accumulate(&mut self, v: Version, v_tool: Option<&str>) {
281 match self {
282 NextestVersionReq::Version { version, tool } => {
283 if &v >= version {
286 *version = v;
287 *tool = v_tool.map(|s| s.to_owned());
288 }
289 }
290 NextestVersionReq::None => {
291 *self = NextestVersionReq::Version {
292 version: v,
293 tool: v_tool.map(|s| s.to_owned()),
294 };
295 }
296 }
297 }
298
299 fn satisfies(&self, version: &Version) -> Result<(), (&Version, Option<&str>)> {
300 match self {
301 NextestVersionReq::Version {
302 version: required,
303 tool,
304 } => {
305 if version >= required {
306 Ok(())
307 } else {
308 Err((required, tool.as_deref()))
309 }
310 }
311 NextestVersionReq::None => Ok(()),
312 }
313 }
314}
315
/// The outcome of [`NextestVersionConfig::eval`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NextestVersionEval {
    /// The current version meets both the required and the recommended versions
    /// (unset requirements are always met).
    Satisfied,

    /// The current version is lower than the required version.
    Error {
        /// The required version.
        required: Version,
        /// The current version.
        current: Version,
        /// The tool that specified the required version, if any.
        tool: Option<String>,
    },

    /// The current version meets the required version but is lower than the
    /// recommended version.
    Warn {
        /// The recommended version.
        recommended: Version,
        /// The current version.
        current: Version,
        /// The tool that specified the recommended version, if any.
        tool: Option<String>,
    },

    /// Like [`NextestVersionEval::Error`], but the version check was overridden.
    ErrorOverride {
        /// The required version.
        required: Version,
        /// The current version.
        current: Version,
        /// The tool that specified the required version, if any.
        tool: Option<String>,
    },

    /// Like [`NextestVersionEval::Warn`], but the version check was overridden.
    WarnOverride {
        /// The recommended version.
        recommended: Version,
        /// The current version.
        current: Version,
        /// The tool that specified the recommended version, if any.
        tool: Option<String>,
    },
}
364
/// The deserialized form of the `nextest-version` key.
///
/// Accepts either of two forms:
///
/// - a version string, which sets only the required version;
/// - a table with optional `required` and `recommended` keys.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct NextestVersionDeserialize {
    /// The minimum required version, if specified.
    required: Option<Version>,

    /// The recommended version, if specified.
    recommended: Option<Version>,
}
378
379impl<'de> Deserialize<'de> for NextestVersionDeserialize {
380 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
381 where
382 D: Deserializer<'de>,
383 {
384 struct V;
385
386 impl<'de2> serde::de::Visitor<'de2> for V {
387 type Value = NextestVersionDeserialize;
388
389 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
390 formatter.write_str(
391 "a table ({{ required = \"0.9.20\", recommended = \"0.9.30\" }}) or a string (\"0.9.50\")",
392 )
393 }
394
395 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
396 where
397 E: serde::de::Error,
398 {
399 let required = parse_version::<E>(s.to_owned())?;
400 Ok(NextestVersionDeserialize {
401 required: Some(required),
402 recommended: None,
403 })
404 }
405
406 fn visit_map<A>(self, map: A) -> std::result::Result<Self::Value, A::Error>
407 where
408 A: serde::de::MapAccess<'de2>,
409 {
410 #[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
411 struct NextestVersionMap {
412 #[serde(default, deserialize_with = "deserialize_version_opt")]
413 required: Option<Version>,
414 #[serde(default, deserialize_with = "deserialize_version_opt")]
415 recommended: Option<Version>,
416 }
417
418 let NextestVersionMap {
419 required,
420 recommended,
421 } = NextestVersionMap::deserialize(serde::de::value::MapAccessDeserializer::new(
422 map,
423 ))?;
424
425 if let (Some(required), Some(recommended)) = (&required, &recommended) {
426 if required > recommended {
427 return Err(serde::de::Error::custom(format!(
428 "required version ({required}) must not be greater than recommended version ({recommended})"
429 )));
430 }
431 }
432
433 Ok(NextestVersionDeserialize {
434 required,
435 recommended,
436 })
437 }
438 }
439
440 deserializer.deserialize_any(V)
441 }
442}
443
444fn deserialize_version_opt<'de, D>(
449 deserializer: D,
450) -> std::result::Result<Option<Version>, D::Error>
451where
452 D: Deserializer<'de>,
453{
454 let s = Option::<String>::deserialize(deserializer)?;
455 s.map(parse_version::<D::Error>).transpose()
456}
457
458fn parse_version<E>(mut s: String) -> std::result::Result<Version, E>
459where
460 E: serde::de::Error,
461{
462 for ch in s.chars() {
463 if ch == '-' {
464 return Err(E::custom(
465 "pre-release identifiers are not supported in nextest-version",
466 ));
467 } else if ch == '+' {
468 return Err(E::custom(
469 "build metadata is not supported in nextest-version",
470 ));
471 }
472 }
473
474 if s.matches('.').count() == 1 {
477 s.push_str(".0");
479 }
480
481 Version::parse(&s).map_err(E::custom)
482}
483
484#[cfg(test)]
485mod tests {
486 use super::*;
487 use test_case::test_case;
488
489 #[test_case(
490 r#"
491 nextest-version = "0.9"
492 "#,
493 NextestVersionDeserialize { required: Some("0.9.0".parse().unwrap()), recommended: None } ; "basic"
494 )]
495 #[test_case(
496 r#"
497 nextest-version = "0.9.30"
498 "#,
499 NextestVersionDeserialize { required: Some("0.9.30".parse().unwrap()), recommended: None } ; "basic with patch"
500 )]
501 #[test_case(
502 r#"
503 nextest-version = { recommended = "0.9.20" }
504 "#,
505 NextestVersionDeserialize { required: None, recommended: Some("0.9.20".parse().unwrap()) } ; "with warning"
506 )]
507 #[test_case(
508 r#"
509 nextest-version = { required = "0.9.20", recommended = "0.9.25" }
510 "#,
511 NextestVersionDeserialize {
512 required: Some("0.9.20".parse().unwrap()),
513 recommended: Some("0.9.25".parse().unwrap()),
514 } ; "with error and warning"
515 )]
516 fn test_valid_nextest_version(input: &str, expected: NextestVersionDeserialize) {
517 let actual: VersionOnlyDeserialize = toml::from_str(input).unwrap();
518 assert_eq!(actual.nextest_version.unwrap(), expected);
519 }
520
521 #[test_case(
522 r#"
523 nextest-version = 42
524 "#,
525 "a table ({{ required = \"0.9.20\", recommended = \"0.9.30\" }}) or a string (\"0.9.50\")" ; "empty"
526 )]
527 #[test_case(
528 r#"
529 nextest-version = "0.9.30-rc.1"
530 "#,
531 "pre-release identifiers are not supported in nextest-version" ; "pre-release"
532 )]
533 #[test_case(
534 r#"
535 nextest-version = "0.9.40+mybuild"
536 "#,
537 "build metadata is not supported in nextest-version" ; "build metadata"
538 )]
539 #[test_case(
540 r#"
541 nextest-version = { required = "0.9.20", recommended = "0.9.10" }
542 "#,
543 "required version (0.9.20) must not be greater than recommended version (0.9.10)" ; "error greater than warning"
544 )]
545 fn test_invalid_nextest_version(input: &str, error_message: &str) {
546 let err = toml::from_str::<VersionOnlyDeserialize>(input).unwrap_err();
547 assert!(
548 err.to_string().contains(error_message),
549 "error `{err}` contains `{error_message}`"
550 );
551 }
552
553 #[test]
554 fn test_accumulate() {
555 let mut nextest_version = NextestVersionConfig::default();
556 nextest_version.accumulate(
557 NextestVersionDeserialize {
558 required: Some("0.9.20".parse().unwrap()),
559 recommended: None,
560 },
561 Some("tool1"),
562 );
563 nextest_version.accumulate(
564 NextestVersionDeserialize {
565 required: Some("0.9.30".parse().unwrap()),
566 recommended: Some("0.9.35".parse().unwrap()),
567 },
568 Some("tool2"),
569 );
570 nextest_version.accumulate(
571 NextestVersionDeserialize {
572 required: None,
573 recommended: Some("0.9.25".parse().unwrap()),
576 },
577 Some("tool3"),
578 );
579 nextest_version.accumulate(
580 NextestVersionDeserialize {
581 required: Some("0.9.30".parse().unwrap()),
584 recommended: None,
585 },
586 Some("tool4"),
587 );
588
589 assert_eq!(
590 nextest_version,
591 NextestVersionConfig {
592 required: NextestVersionReq::Version {
593 version: "0.9.30".parse().unwrap(),
594 tool: Some("tool4".to_owned()),
595 },
596 recommended: NextestVersionReq::Version {
597 version: "0.9.35".parse().unwrap(),
598 tool: Some("tool2".to_owned()),
599 },
600 }
601 );
602 }
603}