nextest_runner/config/elements/
test_threads.rs

1// Copyright (c) The nextest Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{config::core::get_num_cpus, errors::TestThreadsParseError};
5use serde::Deserialize;
6use std::{cmp::Ordering, fmt, str::FromStr};
7
/// Configured value of the `test-threads` key.
///
/// Holds either an explicit number of test threads, or a request to use one
/// thread per logical CPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TestThreads {
    /// Use exactly this many test threads.
    Count(usize),

    /// Use as many test threads as there are logical CPUs.
    NumCpus,
}
17
18impl TestThreads {
19    /// Gets the actual number of test threads computed at runtime.
20    pub fn compute(self) -> usize {
21        match self {
22            Self::Count(threads) => threads,
23            Self::NumCpus => get_num_cpus(),
24        }
25    }
26}
27
28impl FromStr for TestThreads {
29    type Err = TestThreadsParseError;
30
31    fn from_str(s: &str) -> Result<Self, Self::Err> {
32        if s == "num-cpus" {
33            return Ok(Self::NumCpus);
34        }
35
36        match s.parse::<isize>() {
37            Err(e) => Err(TestThreadsParseError::new(format!(
38                "Error: {e} parsing {s}"
39            ))),
40            Ok(0) => Err(TestThreadsParseError::new("jobs may not be 0")),
41            Ok(j) if j < 0 => Ok(TestThreads::Count(
42                (get_num_cpus() as isize + j).max(1) as usize
43            )),
44            Ok(j) => Ok(TestThreads::Count(j as usize)),
45        }
46    }
47}
48
49impl fmt::Display for TestThreads {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            Self::Count(threads) => write!(f, "{threads}"),
53            Self::NumCpus => write!(f, "num-cpus"),
54        }
55    }
56}
57
58impl<'de> Deserialize<'de> for TestThreads {
59    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60    where
61        D: serde::Deserializer<'de>,
62    {
63        struct V;
64
65        impl serde::de::Visitor<'_> for V {
66            type Value = TestThreads;
67
68            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
69                write!(formatter, "an integer or the string \"num-cpus\"")
70            }
71
72            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
73            where
74                E: serde::de::Error,
75            {
76                if v == "num-cpus" {
77                    Ok(TestThreads::NumCpus)
78                } else {
79                    Err(serde::de::Error::invalid_value(
80                        serde::de::Unexpected::Str(v),
81                        &self,
82                    ))
83                }
84            }
85
86            // Note that TOML uses i64, not u64.
87            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
88            where
89                E: serde::de::Error,
90            {
91                match v.cmp(&0) {
92                    Ordering::Greater => Ok(TestThreads::Count(v as usize)),
93                    Ordering::Less => Ok(TestThreads::Count(
94                        (get_num_cpus() as i64 + v).max(1) as usize
95                    )),
96                    Ordering::Equal => Err(serde::de::Error::invalid_value(
97                        serde::de::Unexpected::Signed(v),
98                        &self,
99                    )),
100                }
101            }
102        }
103
104        deserializer.deserialize_any(V)
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{core::NextestConfig, utils::test_helpers::*};
    use camino_tempfile::tempdir;
    use indoc::indoc;
    use nextest_filtering::ParseContext;
    use test_case::test_case;

    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = -1
        "#},
        // Negative values are clamped to at least one thread, so on a single-CPU
        // machine this resolves to 1, not 0.
        Some(get_num_cpus().saturating_sub(1).max(1))

        ; "negative"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = -1000000
        "#},
        Some(1)

        ; "negative clamped"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 2
        "#},
        Some(2)

        ; "positive"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 0
        "#},
        None

        ; "zero"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = "num-cpus"
        "#},
        Some(get_num_cpus())

        ; "num-cpus"
    )]
    fn parse_test_threads(config_contents: &str, n_threads: Option<usize>) {
        let workspace_dir = tempdir().unwrap();

        let graph = temp_workspace(&workspace_dir, config_contents);

        let pcx = ParseContext::new(&graph);
        let config = NextestConfig::from_sources(
            graph.workspace().root(),
            &pcx,
            None,
            [],
            &Default::default(),
        );
        match n_threads {
            None => assert!(config.is_err()),
            Some(n) => assert_eq!(
                config
                    .unwrap()
                    .profile("custom")
                    .unwrap()
                    .apply_build_platforms(&build_platforms())
                    .custom_profile()
                    .unwrap()
                    .test_threads()
                    .unwrap()
                    .compute(),
                n,
            ),
        }
    }

    /// Exercises the `FromStr` impl directly, plus a `Display` -> `FromStr` round trip.
    #[test]
    fn test_threads_from_str() {
        assert_eq!("num-cpus".parse::<TestThreads>().ok(), Some(TestThreads::NumCpus));
        assert_eq!("4".parse::<TestThreads>().ok(), Some(TestThreads::Count(4)));
        assert_eq!(
            "-1".parse::<TestThreads>().ok(),
            Some(TestThreads::Count(get_num_cpus().saturating_sub(1).max(1)))
        );
        assert!("0".parse::<TestThreads>().is_err());
        assert!("four".parse::<TestThreads>().is_err());

        for threads in [TestThreads::NumCpus, TestThreads::Count(7)] {
            assert_eq!(threads.to_string().parse::<TestThreads>().ok(), Some(threads));
        }
    }
}