Skip to content

Commit 94cd8d6

Browse files
committed
Add test sharding support
1 parent cdaf15f commit 94cd8d6

File tree

7 files changed

+361
-0
lines changed

7 files changed

+361
-0
lines changed

rust/private/BUILD.bazel

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ load("//rust/private:stamp.bzl", "stamp_build_setting")
66
# Exported for docs
77
exports_files(["providers.bzl"])
88

9+
# Test sharding wrapper scripts
10+
exports_files([
11+
"test_sharding_wrapper.sh",
12+
"test_sharding_wrapper.bat",
13+
])
14+
915
bzl_library(
1016
name = "bazel_tools_bzl_lib",
1117
srcs = ["@bazel_tools//tools:bzl_srcs"],

rust/private/rust.bzl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,55 @@ def _rust_test_impl(ctx):
516516
rust_flags = get_rust_test_flags(ctx.attr),
517517
skip_expanding_rustc_env = True,
518518
)
519+
520+
# If sharding is enabled and we're using libtest harness, wrap the test binary
521+
# with a script that handles test enumeration and shard partitioning
522+
if ctx.attr.experimental_enable_sharding and ctx.attr.use_libtest_harness:
523+
# Find DefaultInfo and CrateInfo in the providers list
524+
# DefaultInfo is first, CrateInfo follows (or TestCrateInfo for staticlib/cdylib)
525+
default_info = providers[0]
526+
default_info_index = 0
527+
528+
# Get the test binary from CrateInfo - it's the output of the compiled test
529+
crate_info_provider = None
530+
for p in providers:
531+
if hasattr(p, "output") and hasattr(p, "is_test"):
532+
crate_info_provider = p
533+
break
534+
535+
if crate_info_provider:
536+
test_binary = crate_info_provider.output
537+
538+
# Select the appropriate wrapper template based on target OS
539+
if toolchain.target_os == "windows":
540+
wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.bat")
541+
wrapper_template = ctx.file._test_sharding_wrapper_windows
542+
else:
543+
wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.sh")
544+
wrapper_template = ctx.file._test_sharding_wrapper_unix
545+
546+
# Generate wrapper script with test binary path substituted
547+
ctx.actions.expand_template(
548+
template = wrapper_template,
549+
output = wrapper,
550+
substitutions = {
551+
"{{TEST_BINARY}}": test_binary.short_path,
552+
},
553+
is_executable = True,
554+
)
555+
556+
# Update runfiles to include both wrapper and test binary
557+
new_runfiles = default_info.default_runfiles.merge(
558+
ctx.runfiles(files = [test_binary]),
559+
)
560+
561+
# Replace DefaultInfo with wrapper as executable
562+
providers[default_info_index] = DefaultInfo(
563+
files = default_info.files,
564+
runfiles = new_runfiles,
565+
executable = wrapper,
566+
)
567+
519568
data = getattr(ctx.attr, "data", [])
520569

521570
env = expand_dict_value_locations(
@@ -918,6 +967,22 @@ _rust_test_attrs = {
918967
E.g. `bazel test //src:rust_test --test_arg=foo::test::test_fn`.
919968
"""),
920969
),
970+
"experimental_enable_sharding": attr.bool(
971+
mandatory = False,
972+
default = False,
973+
doc = dedent("""\
974+
If True, enable support for Bazel test sharding (shard_count attribute).
975+
976+
When enabled, tests are executed via a wrapper script that:
977+
1. Enumerates tests using libtest's --list flag
978+
2. Partitions tests across shards based on TEST_SHARD_INDEX/TEST_TOTAL_SHARDS
979+
3. Runs only the tests assigned to the current shard
980+
981+
This attribute only has an effect when use_libtest_harness is True.
982+
983+
This is experimental and may change in future releases.
984+
"""),
985+
),
921986
} | _coverage_attrs | _experimental_use_cc_common_link_attrs
922987

923988
rust_library = rule(
@@ -1452,6 +1517,14 @@ rust_test = rule(
14521517
"_allowlist_function_transition": attr.label(
14531518
default = "@bazel_tools//tools/allowlists/function_transition_allowlist",
14541519
),
1520+
"_test_sharding_wrapper_unix": attr.label(
1521+
default = Label("//rust/private:test_sharding_wrapper.sh"),
1522+
allow_single_file = True,
1523+
),
1524+
"_test_sharding_wrapper_windows": attr.label(
1525+
default = Label("//rust/private:test_sharding_wrapper.bat"),
1526+
allow_single_file = True,
1527+
),
14551528
},
14561529
executable = True,
14571530
fragments = ["cpp"],
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
@REM Copyright 2024 The Bazel Authors. All rights reserved.
2+
@REM
3+
@REM Licensed under the Apache License, Version 2.0 (the "License");
4+
@REM you may not use this file except in compliance with the License.
5+
@REM You may obtain a copy of the License at
6+
@REM
7+
@REM http://www.apache.org/licenses/LICENSE-2.0
8+
@REM
9+
@REM Unless required by applicable law or agreed to in writing, software
10+
@REM distributed under the License is distributed on an "AS IS" BASIS,
11+
@REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
@REM See the License for the specific language governing permissions and
13+
@REM limitations under the License.
14+
15+
@REM Wrapper script for rust_test that enables Bazel test sharding support.
16+
@REM This script intercepts test execution, enumerates tests using libtest's
17+
@REM --list flag, partitions them by shard index, and runs only the relevant subset.
18+
19+
@ECHO OFF
20+
SETLOCAL EnableDelayedExpansion
21+
22+
SET "TEST_BINARY={{TEST_BINARY}}"
23+
24+
@REM If sharding is not enabled, run test binary directly
25+
IF "%TEST_TOTAL_SHARDS%"=="" (
26+
"%TEST_BINARY%" %*
27+
EXIT /B %ERRORLEVEL%
28+
)
29+
30+
@REM Touch status file to advertise sharding support to Bazel
31+
IF NOT "%TEST_SHARD_STATUS_FILE%"=="" (
32+
ECHO.>"%TEST_SHARD_STATUS_FILE%"
33+
)
34+
35+
@REM Create a temporary file for test list
36+
SET "TEMP_LIST=%TEMP%\rust_test_list_%RANDOM%.txt"
37+
38+
@REM Enumerate all tests using libtest's --list flag
39+
"%TEST_BINARY%" --list --format terse 2>NUL | FINDSTR /R ": test$" > "%TEMP_LIST%"
40+
41+
@REM Check if any tests were found
42+
FOR %%A IN ("%TEMP_LIST%") DO IF %%~zA==0 (
43+
DEL "%TEMP_LIST%" 2>NUL
44+
EXIT /B 0
45+
)
46+
47+
@REM Filter tests for this shard and build argument list
48+
SET "INDEX=0"
49+
SET "SHARD_TESTS="
50+
51+
FOR /F "usebackq delims=" %%T IN ("%TEMP_LIST%") DO (
52+
SET "TEST_LINE=%%T"
53+
@REM Strip ": test" suffix
54+
SET "TEST_NAME=!TEST_LINE:: test=!"
55+
56+
@REM Calculate index % TEST_TOTAL_SHARDS
57+
SET /A "MOD=INDEX %% TEST_TOTAL_SHARDS"
58+
59+
IF !MOD! EQU %TEST_SHARD_INDEX% (
60+
SET "SHARD_TESTS=!SHARD_TESTS! "!TEST_NAME!""
61+
)
62+
63+
SET /A "INDEX+=1"
64+
)
65+
66+
DEL "%TEMP_LIST%" 2>NUL
67+
68+
@REM If no tests for this shard, exit successfully
69+
IF "%SHARD_TESTS%"=="" EXIT /B 0
70+
71+
@REM Run the filtered tests with --exact to match exact test names
72+
"%TEST_BINARY%" %SHARD_TESTS% --exact %*
73+
EXIT /B %ERRORLEVEL%
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/env bash
2+
# Copyright 2024 The Bazel Authors. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Wrapper script for rust_test that enables Bazel test sharding support.
17+
# This script intercepts test execution, enumerates tests using libtest's
18+
# --list flag, partitions them by shard index, and runs only the relevant subset.
19+
20+
set -euo pipefail
21+
22+
TEST_BINARY="{{TEST_BINARY}}"
23+
24+
# If sharding is not enabled, run test binary directly
25+
if [[ -z "${TEST_TOTAL_SHARDS:-}" ]]; then
26+
exec "./${TEST_BINARY}" "$@"
27+
fi
28+
29+
# Touch status file to advertise sharding support to Bazel
30+
if [[ -n "${TEST_SHARD_STATUS_FILE:-}" ]]; then
31+
touch "${TEST_SHARD_STATUS_FILE}"
32+
fi
33+
34+
# Enumerate all tests using libtest's --list flag
35+
# Output format: "test_name: test" - we need to strip the ": test" suffix
36+
test_list=$("./${TEST_BINARY}" --list --format terse 2>/dev/null | grep ': test$' | sed 's/: test$//' || true)
37+
38+
# If no tests found, exit successfully
39+
if [[ -z "$test_list" ]]; then
40+
exit 0
41+
fi
42+
43+
# Filter tests for this shard
44+
# test_index % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX
45+
shard_tests=()
46+
index=0
47+
while IFS= read -r test_name; do
48+
if (( index % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX )); then
49+
shard_tests+=("$test_name")
50+
fi
51+
((index++)) || true
52+
done <<< "$test_list"
53+
54+
# If no tests for this shard, exit successfully
55+
if [[ ${#shard_tests[@]} -eq 0 ]]; then
56+
exit 0
57+
fi
58+
59+
# Run the filtered tests with --exact to match exact test names
60+
exec "./${TEST_BINARY}" "${shard_tests[@]}" --exact "$@"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
load(":test_sharding.bzl", "test_sharding_test_suite")
2+
3+
############################ UNIT TESTS #############################
4+
test_sharding_test_suite(name = "test_sharding_test_suite")
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2024 The Bazel Authors. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//! Test file with multiple tests for verifying sharding support.
16+
//!
17+
//! These tests are intentionally trivial - their purpose is to provide multiple
18+
//! enumerable test functions that can be partitioned across shards. The BUILD
19+
//! file runs this with `shard_count = 3`, so these 6 tests should be split
20+
//! ~2 per shard. The sharding wrapper script enumerates tests via `--list`,
21+
//! assigns each to a shard based on `index % shard_count`, and runs only the
22+
//! subset for the current shard.
23+
24+
#[cfg(test)]
25+
mod tests {
26+
#[test]
27+
fn test_1() {}
28+
29+
#[test]
30+
fn test_2() {}
31+
32+
#[test]
33+
fn test_3() {}
34+
35+
#[test]
36+
fn test_4() {}
37+
38+
#[test]
39+
fn test_5() {}
40+
41+
#[test]
42+
fn test_6() {}
43+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Tests for rust_test sharding support."""
2+
3+
load("@bazel_skylib//lib:unittest.bzl", "analysistest", "asserts")
4+
load("//rust:defs.bzl", "rust_test")
5+
6+
def _sharding_enabled_test(ctx):
7+
"""Test that sharding wrapper is generated when experimental_enable_sharding is True."""
8+
env = analysistest.begin(ctx)
9+
tut = analysistest.target_under_test(env)
10+
11+
# Get the executable from DefaultInfo
12+
default_info = tut[DefaultInfo]
13+
executable = default_info.files_to_run.executable
14+
15+
# When sharding is enabled, the executable should be a wrapper script
16+
asserts.true(
17+
env,
18+
executable.basename.endswith("_sharding_wrapper.sh") or
19+
executable.basename.endswith("_sharding_wrapper.bat"),
20+
"Expected sharding wrapper script, got: " + executable.basename,
21+
)
22+
23+
return analysistest.end(env)
24+
25+
sharding_enabled_test = analysistest.make(_sharding_enabled_test)
26+
27+
def _sharding_disabled_test(ctx):
28+
"""Test that no wrapper is generated when experimental_enable_sharding is False."""
29+
env = analysistest.begin(ctx)
30+
tut = analysistest.target_under_test(env)
31+
32+
# Get the executable from DefaultInfo
33+
default_info = tut[DefaultInfo]
34+
executable = default_info.files_to_run.executable
35+
36+
# When sharding is disabled, the executable should be the test binary directly
37+
asserts.false(
38+
env,
39+
executable.basename.endswith("_sharding_wrapper.sh") or
40+
executable.basename.endswith("_sharding_wrapper.bat"),
41+
"Expected test binary, not wrapper script: " + executable.basename,
42+
)
43+
44+
return analysistest.end(env)
45+
46+
sharding_disabled_test = analysistest.make(_sharding_disabled_test)
47+
48+
def _test_sharding_targets():
49+
"""Create test targets for sharding tests."""
50+
51+
# Test with sharding enabled
52+
rust_test(
53+
name = "sharded_test_enabled",
54+
srcs = ["sharded_test.rs"],
55+
edition = "2021",
56+
experimental_enable_sharding = True,
57+
)
58+
59+
sharding_enabled_test(
60+
name = "sharding_enabled_test",
61+
target_under_test = ":sharded_test_enabled",
62+
)
63+
64+
# Test with sharding disabled (default)
65+
rust_test(
66+
name = "sharded_test_disabled",
67+
srcs = ["sharded_test.rs"],
68+
edition = "2021",
69+
experimental_enable_sharding = False,
70+
)
71+
72+
sharding_disabled_test(
73+
name = "sharding_disabled_test",
74+
target_under_test = ":sharded_test_disabled",
75+
)
76+
77+
# Integration test: actually run a sharded test
78+
rust_test(
79+
name = "sharded_integration_test",
80+
srcs = ["sharded_test.rs"],
81+
edition = "2021",
82+
experimental_enable_sharding = True,
83+
shard_count = 3,
84+
)
85+
86+
def test_sharding_test_suite(name):
87+
"""Entry-point macro called from the BUILD file.
88+
89+
Args:
90+
name: Name of the macro.
91+
"""
92+
93+
_test_sharding_targets()
94+
95+
native.test_suite(
96+
name = name,
97+
tests = [
98+
":sharding_enabled_test",
99+
":sharding_disabled_test",
100+
":sharded_integration_test",
101+
],
102+
)

0 commit comments

Comments
 (0)