Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rust/private/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ load("//rust/private:stamp.bzl", "stamp_build_setting")
# Exported for docs
exports_files(["providers.bzl"])

# Test sharding wrapper scripts
exports_files([
"test_sharding_wrapper.sh",
"test_sharding_wrapper.bat",
])

bzl_library(
name = "bazel_tools_bzl_lib",
srcs = ["@bazel_tools//tools:bzl_srcs"],
Expand Down
73 changes: 73 additions & 0 deletions rust/private/rust.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,55 @@ def _rust_test_impl(ctx):
rust_flags = get_rust_test_flags(ctx.attr),
skip_expanding_rustc_env = True,
)

# If sharding is enabled and we're using libtest harness, wrap the test binary
# with a script that handles test enumeration and shard partitioning
if ctx.attr.experimental_enable_sharding and ctx.attr.use_libtest_harness:
# Find DefaultInfo and CrateInfo in the providers list
# DefaultInfo is first, CrateInfo follows (or TestCrateInfo for staticlib/cdylib)
default_info = providers[0]
default_info_index = 0

# Get the test binary from CrateInfo - it's the output of the compiled test
crate_info_provider = None
for p in providers:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do something like edit rustc_compile_action to return a dict of provider types to providers, and then do a key lookup, rather than sniffing an ordered list.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much better - done.

if hasattr(p, "output") and hasattr(p, "is_test"):
crate_info_provider = p
break

if crate_info_provider:
test_binary = crate_info_provider.output

# Select the appropriate wrapper template based on target OS
if toolchain.target_os == "windows":
wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.bat")
wrapper_template = ctx.file._test_sharding_wrapper_windows
else:
wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.sh")
wrapper_template = ctx.file._test_sharding_wrapper_unix

# Generate wrapper script with test binary path substituted
ctx.actions.expand_template(
template = wrapper_template,
output = wrapper,
substitutions = {
"{{TEST_BINARY}}": test_binary.short_path,
},
is_executable = True,
)

# Update runfiles to include both wrapper and test binary
new_runfiles = default_info.default_runfiles.merge(
ctx.runfiles(files = [test_binary]),
)

# Replace DefaultInfo with wrapper as executable
providers[default_info_index] = DefaultInfo(
files = default_info.files,
runfiles = new_runfiles,
executable = wrapper,
)

data = getattr(ctx.attr, "data", [])

env = expand_dict_value_locations(
Expand Down Expand Up @@ -918,6 +967,22 @@ _rust_test_attrs = {
E.g. `bazel test //src:rust_test --test_arg=foo::test::test_fn`.
"""),
),
"experimental_enable_sharding": attr.bool(
mandatory = False,
default = False,
doc = dedent("""\
If True, enable support for Bazel test sharding (shard_count attribute).

When enabled, tests are executed via a wrapper script that:
1. Enumerates tests using libtest's --list flag
2. Partitions tests across shards based on TEST_SHARD_INDEX/TEST_TOTAL_SHARDS
3. Runs only the tests assigned to the current shard

This attribute only has an effect when use_libtest_harness is True.

This is experimental and may change in future releases.
"""),
),
} | _coverage_attrs | _experimental_use_cc_common_link_attrs

rust_library = rule(
Expand Down Expand Up @@ -1452,6 +1517,14 @@ rust_test = rule(
"_allowlist_function_transition": attr.label(
default = "@bazel_tools//tools/allowlists/function_transition_allowlist",
),
"_test_sharding_wrapper_unix": attr.label(
default = Label("//rust/private:test_sharding_wrapper.sh"),
allow_single_file = True,
),
"_test_sharding_wrapper_windows": attr.label(
default = Label("//rust/private:test_sharding_wrapper.bat"),
allow_single_file = True,
),
},
executable = True,
fragments = ["cpp"],
Expand Down
73 changes: 73 additions & 0 deletions rust/private/test_sharding_wrapper.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
@REM Copyright 2024 The Bazel Authors. All rights reserved.
@REM
@REM Licensed under the Apache License, Version 2.0 (the "License");
@REM you may not use this file except in compliance with the License.
@REM You may obtain a copy of the License at
@REM
@REM http://www.apache.org/licenses/LICENSE-2.0
@REM
@REM Unless required by applicable law or agreed to in writing, software
@REM distributed under the License is distributed on an "AS IS" BASIS,
@REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@REM See the License for the specific language governing permissions and
@REM limitations under the License.

@REM Wrapper script for rust_test that enables Bazel test sharding support.
@REM This script intercepts test execution, enumerates tests using libtest's
@REM --list flag, partitions them by shard index, and runs only the relevant subset.

@ECHO OFF
SETLOCAL EnableDelayedExpansion

SET "TEST_BINARY={{TEST_BINARY}}"

@REM If sharding is not enabled, run test binary directly
IF "%TEST_TOTAL_SHARDS%"=="" (
"%TEST_BINARY%" %*
EXIT /B %ERRORLEVEL%
)

@REM Touch status file to advertise sharding support to Bazel
IF NOT "%TEST_SHARD_STATUS_FILE%"=="" (
ECHO.>"%TEST_SHARD_STATUS_FILE%"
)

@REM Create a temporary file for test list
SET "TEMP_LIST=%TEMP%\rust_test_list_%RANDOM%.txt"

@REM Enumerate all tests using libtest's --list flag
"%TEST_BINARY%" --list --format terse 2>NUL | FINDSTR /R ": test$" > "%TEMP_LIST%"

@REM Check if any tests were found
FOR %%A IN ("%TEMP_LIST%") DO IF %%~zA==0 (
DEL "%TEMP_LIST%" 2>NUL
EXIT /B 0
)

@REM Filter tests for this shard and build argument list
SET "INDEX=0"
SET "SHARD_TESTS="

FOR /F "usebackq delims=" %%T IN ("%TEMP_LIST%") DO (
SET "TEST_LINE=%%T"
@REM Strip ": test" suffix
SET "TEST_NAME=!TEST_LINE:: test=!"

@REM Calculate index % TEST_TOTAL_SHARDS
SET /A "MOD=INDEX %% TEST_TOTAL_SHARDS"

IF !MOD! EQU %TEST_SHARD_INDEX% (
SET "SHARD_TESTS=!SHARD_TESTS! "!TEST_NAME!""
)

SET /A "INDEX+=1"
)

DEL "%TEMP_LIST%" 2>NUL

@REM If no tests for this shard, exit successfully
IF "%SHARD_TESTS%"=="" EXIT /B 0

@REM Run the filtered tests with --exact to match exact test names
"%TEST_BINARY%" %SHARD_TESTS% --exact %*
EXIT /B %ERRORLEVEL%
60 changes: 60 additions & 0 deletions rust/private/test_sharding_wrapper.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Wrapper script for rust_test that enables Bazel test sharding support.
# This script intercepts test execution, enumerates tests using libtest's
# --list flag, partitions them by shard index, and runs only the relevant subset.

set -euo pipefail

TEST_BINARY="{{TEST_BINARY}}"

# If sharding is not enabled, run test binary directly
if [[ -z "${TEST_TOTAL_SHARDS:-}" ]]; then
exec "./${TEST_BINARY}" "$@"
fi

# Touch status file to advertise sharding support to Bazel
if [[ -n "${TEST_SHARD_STATUS_FILE:-}" ]]; then
touch "${TEST_SHARD_STATUS_FILE}"
fi

# Enumerate all tests using libtest's --list flag
# Output format: "test_name: test" - we need to strip the ": test" suffix
test_list=$("./${TEST_BINARY}" --list --format terse 2>/dev/null | grep ': test$' | sed 's/: test$//' || true)

# If no tests found, exit successfully
if [[ -z "$test_list" ]]; then
exit 0
fi

# Filter tests for this shard
# test_index % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX
shard_tests=()
index=0
while IFS= read -r test_name; do
if (( index % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX )); then
shard_tests+=("$test_name")
fi
((index++)) || true
done <<< "$test_list"

# If no tests for this shard, exit successfully
if [[ ${#shard_tests[@]} -eq 0 ]]; then
exit 0
fi

# Run the filtered tests with --exact to match exact test names
exec "./${TEST_BINARY}" "${shard_tests[@]}" --exact "$@"
4 changes: 4 additions & 0 deletions test/unit/test_sharding/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
load(":test_sharding.bzl", "test_sharding_test_suite")

############################ UNIT TESTS #############################
test_sharding_test_suite(name = "test_sharding_test_suite")
43 changes: 43 additions & 0 deletions test/unit/test_sharding/sharded_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Test file with multiple tests for verifying sharding support.
//!
//! These tests are intentionally trivial - their purpose is to provide multiple
//! enumerable test functions that can be partitioned across shards. The BUILD
//! file runs this with `shard_count = 3`, so these 6 tests should be split
//! ~2 per shard. The sharding wrapper script enumerates tests via `--list`,
//! assigns each to a shard based on `index % shard_count`, and runs only the
//! subset for the current shard.

#[cfg(test)]
mod tests {
#[test]
fn test_1() {}

#[test]
fn test_2() {}

#[test]
fn test_3() {}

#[test]
fn test_4() {}

#[test]
fn test_5() {}

#[test]
fn test_6() {}
}
102 changes: 102 additions & 0 deletions test/unit/test_sharding/test_sharding.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Tests for rust_test sharding support."""

load("@bazel_skylib//lib:unittest.bzl", "analysistest", "asserts")
load("//rust:defs.bzl", "rust_test")

def _sharding_enabled_test(ctx):
"""Test that sharding wrapper is generated when experimental_enable_sharding is True."""
env = analysistest.begin(ctx)
tut = analysistest.target_under_test(env)

# Get the executable from DefaultInfo
default_info = tut[DefaultInfo]
executable = default_info.files_to_run.executable

# When sharding is enabled, the executable should be a wrapper script
asserts.true(
env,
executable.basename.endswith("_sharding_wrapper.sh") or
executable.basename.endswith("_sharding_wrapper.bat"),
"Expected sharding wrapper script, got: " + executable.basename,
)

return analysistest.end(env)

sharding_enabled_test = analysistest.make(_sharding_enabled_test)

def _sharding_disabled_test(ctx):
"""Test that no wrapper is generated when experimental_enable_sharding is False."""
env = analysistest.begin(ctx)
tut = analysistest.target_under_test(env)

# Get the executable from DefaultInfo
default_info = tut[DefaultInfo]
executable = default_info.files_to_run.executable

# When sharding is disabled, the executable should be the test binary directly
asserts.false(
env,
executable.basename.endswith("_sharding_wrapper.sh") or
executable.basename.endswith("_sharding_wrapper.bat"),
"Expected test binary, not wrapper script: " + executable.basename,
)

return analysistest.end(env)

sharding_disabled_test = analysistest.make(_sharding_disabled_test)

def _test_sharding_targets():
"""Create test targets for sharding tests."""

# Test with sharding enabled
rust_test(
name = "sharded_test_enabled",
srcs = ["sharded_test.rs"],
edition = "2021",
experimental_enable_sharding = True,
)

sharding_enabled_test(
name = "sharding_enabled_test",
target_under_test = ":sharded_test_enabled",
)

# Test with sharding disabled (default)
rust_test(
name = "sharded_test_disabled",
srcs = ["sharded_test.rs"],
edition = "2021",
experimental_enable_sharding = False,
)

sharding_disabled_test(
name = "sharding_disabled_test",
target_under_test = ":sharded_test_disabled",
)

# Integration test: actually run a sharded test
rust_test(
name = "sharded_integration_test",
srcs = ["sharded_test.rs"],
edition = "2021",
experimental_enable_sharding = True,
shard_count = 3,
)

def test_sharding_test_suite(name):
"""Entry-point macro called from the BUILD file.

Args:
name: Name of the macro.
"""

_test_sharding_targets()

native.test_suite(
name = name,
tests = [
":sharding_enabled_test",
":sharding_disabled_test",
":sharded_integration_test",
],
)