nextest_runner/config/
threads_required.rs

1// Copyright (c) The nextest Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use super::get_num_cpus;
5use serde::Deserialize;
6use std::{cmp::Ordering, fmt};
7
/// Value of the `threads-required` config key: how many slots a test occupies while running.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadsRequired {
    /// Occupy a fixed number of slots, given explicitly as an integer.
    Count(usize),

    /// Occupy one slot per CPU.
    NumCpus,

    /// Occupy as many slots as the configured number of test threads.
    NumTestThreads,
}
20
21impl ThreadsRequired {
22    /// Gets the actual number of test threads computed at runtime.
23    pub fn compute(self, test_threads: usize) -> usize {
24        match self {
25            Self::Count(threads) => threads,
26            Self::NumCpus => get_num_cpus(),
27            Self::NumTestThreads => test_threads,
28        }
29    }
30}
31
32impl<'de> Deserialize<'de> for ThreadsRequired {
33    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
34    where
35        D: serde::Deserializer<'de>,
36    {
37        struct V;
38
39        impl serde::de::Visitor<'_> for V {
40            type Value = ThreadsRequired;
41
42            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
43                write!(
44                    formatter,
45                    "an integer, the string \"num-cpus\" or the string \"num-test-threads\""
46                )
47            }
48
49            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
50            where
51                E: serde::de::Error,
52            {
53                if v == "num-cpus" {
54                    Ok(ThreadsRequired::NumCpus)
55                } else if v == "num-test-threads" {
56                    Ok(ThreadsRequired::NumTestThreads)
57                } else {
58                    Err(serde::de::Error::invalid_value(
59                        serde::de::Unexpected::Str(v),
60                        &self,
61                    ))
62                }
63            }
64
65            // Note that TOML uses i64, not u64.
66            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
67            where
68                E: serde::de::Error,
69            {
70                match v.cmp(&0) {
71                    Ordering::Greater => Ok(ThreadsRequired::Count(v as usize)),
72                    // TODO: we don't currently support negative numbers here because it's not clear
73                    // whether num-cpus or num-test-threads is better. It would probably be better
74                    // to support a small expression syntax with +, -, * and /.
75                    //
76                    // I (Rain) checked out a number of the expression syntax crates and found that they
77                    // either support too much or too little. We want just this minimal set of operators,
78                    // plus. Probably worth just forking https://docs.rs/mexe or working with upstream
79                    // to add support for operators.
80                    Ordering::Equal | Ordering::Less => Err(serde::de::Error::invalid_value(
81                        serde::de::Unexpected::Signed(v),
82                        &self,
83                    )),
84                }
85            }
86        }
87
88        deserializer.deserialize_any(V)
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{NextestConfig, test_helpers::*};
    use camino_tempfile::tempdir;
    use indoc::indoc;
    use nextest_filtering::ParseContext;
    use test_case::test_case;

    // Each case pairs a config snippet with the expected outcome: `Some(n)` means the config
    // loads and `threads-required` resolves to `n`; `None` means loading the config must fail.
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = 2
        "#},
        Some(2)

        ; "positive"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = 0
        "#},
        None

        ; "zero"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = -1
        "#},
        None

        ; "negative"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = "num-cpus"
        "#},
        Some(get_num_cpus())

        ; "num-cpus"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 1
            threads-required = "num-cpus"
        "#},
        Some(get_num_cpus())

        ; "num-cpus-with-custom-test-threads"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = "num-test-threads"
        "#},
        Some(get_num_cpus())

        ; "num-test-threads"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 1
            threads-required = "num-test-threads"
        "#},
        Some(1)

        ; "num-test-threads-with-custom-test-threads"
    )]
    fn parse_threads_required(config_contents: &str, threads_required: Option<usize>) {
        // Build a throwaway workspace whose nextest config is `config_contents`.
        let workspace_dir = tempdir().unwrap();
        let graph = temp_workspace(&workspace_dir, config_contents);
        let pcx = ParseContext::new(&graph);

        let parsed = NextestConfig::from_sources(
            graph.workspace().root(),
            &pcx,
            None,
            [],
            &Default::default(),
        );

        // Cases without an expected value only require that the config is rejected.
        let expected = match threads_required {
            Some(expected) => expected,
            None => {
                assert!(parsed.is_err());
                return;
            }
        };

        let config = parsed.unwrap();
        let profile = config
            .profile("custom")
            .unwrap()
            .apply_build_platforms(&build_platforms());

        // `threads-required` is resolved against the profile's own computed test-thread count.
        let available = profile.test_threads().compute();
        assert_eq!(profile.threads_required().compute(available), expected);
    }
}