diff --git a/extensions/prost/private/prost.bzl b/extensions/prost/private/prost.bzl index 4eecf85a06..71ed212486 100644 --- a/extensions/prost/private/prost.bzl +++ b/extensions/prost/private/prost.bzl @@ -125,25 +125,22 @@ def _compile_proto( return lib_rs, package_info_file def _get_crate_info(providers): - """Finds the CrateInfo provider in the list of providers.""" - for provider in providers: - if hasattr(provider, "name"): - return provider - fail("Couldn't find a CrateInfo in the list of providers") + """Finds the CrateInfo provider in the dict of providers.""" + if "crate_info" in providers: + return providers["crate_info"] + fail("Couldn't find a CrateInfo in the providers") def _get_dep_info(providers): - """Finds the DepInfo provider in the list of providers.""" - for provider in providers: - if hasattr(provider, "direct_crates"): - return provider - fail("Couldn't find a DepInfo in the list of providers") + """Finds the DepInfo provider in the dict of providers.""" + if "dep_info" in providers: + return providers["dep_info"] + fail("Couldn't find a DepInfo in the providers") def _get_cc_info(providers): - """Finds the CcInfo provider in the list of providers.""" - for provider in providers: - if hasattr(provider, "linking_context"): - return provider - fail("Couldn't find a CcInfo in the list of providers") + """Finds the CcInfo provider in the dict of providers.""" + if "CcInfo" in providers: + return providers["CcInfo"] + fail("Couldn't find a CcInfo in the providers") def _compile_rust( *, diff --git a/extensions/protobuf/proto.bzl b/extensions/protobuf/proto.bzl index 8c67a39ecb..83203547d3 100644 --- a/extensions/protobuf/proto.bzl +++ b/extensions/protobuf/proto.bzl @@ -246,8 +246,10 @@ def _rust_proto_compile(protos, descriptor_sets, imports, crate_name, ctx, is_gr ), output_hash = output_hash, ) - providers.append(OutputGroupInfo(rust_generated_srcs = srcs)) - return providers + # Convert dict to list and add OutputGroupInfo + result = list(providers.values()) + result.append(OutputGroupInfo(rust_generated_srcs = srcs)) + return result def _rust_protogrpc_library_impl(ctx, is_grpc): """Implementation of the rust_(proto|grpc)_library. diff --git a/extensions/wasm_bindgen/private/wasm_bindgen_test.bzl b/extensions/wasm_bindgen/private/wasm_bindgen_test.bzl index 8a2dcb5738..22b60adfcd 100644 --- a/extensions/wasm_bindgen/private/wasm_bindgen_test.bzl +++ b/extensions/wasm_bindgen/private/wasm_bindgen_test.bzl @@ -173,18 +173,21 @@ def _rust_wasm_bindgen_test_impl(ctx): # Force the use of a browser for now as there is no node integration. env["WASM_BINDGEN_USE_BROWSER"] = "1" + # Extract DefaultInfo and create a modified version with the wrapper as executable + default_info = crate_providers["DefaultInfo"] + files = default_info.files.to_list() + if len(files) != 1: + fail("Unexpected number of output files for `{}`: {}".format(ctx.label, files)) + wasm_file = files[0] + env["TEST_WASM_BINARY"] = _rlocationpath(files[0], ctx.workspace_name) + + # Build providers list from the dict, replacing DefaultInfo with our modified version providers = [] - - for prov in crate_providers: - if type(prov) == "DefaultInfo": - files = prov.files.to_list() - if len(files) != 1: - fail("Unexpected number of output files for `{}`: {}".format(ctx.label, files)) - wasm_file = files[0] - env["TEST_WASM_BINARY"] = _rlocationpath(files[0], ctx.workspace_name) + for key, prov in crate_providers.items(): + if key == "DefaultInfo": providers.append(DefaultInfo( - files = prov.files, - runfiles = prov.default_runfiles.merge(ctx.runfiles(files = [wasm_file], transitive_files = wb_toolchain.all_test_files)), + files = default_info.files, + runfiles = default_info.default_runfiles.merge(ctx.runfiles(files = [wasm_file], transitive_files = wb_toolchain.all_test_files)), executable = wrapper, )) else: diff --git a/rust/private/BUILD.bazel b/rust/private/BUILD.bazel index d18e895493..4efebc0e65 100644 --- a/rust/private/BUILD.bazel +++ b/rust/private/BUILD.bazel @@ -6,6 +6,12 @@ load("//rust/private:stamp.bzl", "stamp_build_setting") # Exported for docs exports_files(["providers.bzl"]) +# Test sharding wrapper scripts +exports_files([ + "test_sharding_wrapper.sh", + "test_sharding_wrapper.bat", +]) + bzl_library( name = "bazel_tools_bzl_lib", srcs = ["@bazel_tools//tools:bzl_srcs"], diff --git a/rust/private/rust.bzl b/rust/private/rust.bzl index 6ba4c361e1..5fa321cbd0 100644 --- a/rust/private/rust.bzl +++ b/rust/private/rust.bzl @@ -218,7 +218,7 @@ def _rust_library_common(ctx, crate_type): deps = transform_deps(ctx.attr.deps) proc_macro_deps = transform_deps(ctx.attr.proc_macro_deps + get_import_macro_deps(ctx)) - return rustc_compile_action( + return list(rustc_compile_action( ctx = ctx, attr = ctx.attr, toolchain = toolchain, @@ -246,7 +246,7 @@ def _rust_library_common(ctx, crate_type): owner = ctx.label, cfgs = _collect_cfgs(ctx, toolchain, crate_root, crate_type, crate_is_test = False), ), - ) + ).values()) def _rust_binary_impl(ctx): """The implementation of the `rust_binary` rule @@ -312,16 +312,16 @@ def _rust_binary_impl(ctx): ), ) - providers.append(RunEnvironmentInfo( + providers["RunEnvironmentInfo"] = RunEnvironmentInfo( environment = expand_dict_value_locations( ctx, ctx.attr.env, ctx.attr.data, {}, ), - )) + ) - return providers + return list(providers.values()) def get_rust_test_flags(attr): """Determine the desired rustc flags for test targets. @@ -516,6 +516,50 @@ def _rust_test_impl(ctx): rust_flags = get_rust_test_flags(ctx.attr), skip_expanding_rustc_env = True, ) + + # If sharding is enabled and we're using libtest harness, wrap the test binary + # with a script that handles test enumeration and shard partitioning + if ctx.attr.experimental_enable_sharding and ctx.attr.use_libtest_harness: + default_info = providers["DefaultInfo"] + + # Get the test binary from CrateInfo (or TestCrateInfo for staticlib/cdylib) + crate_info_provider = providers.get("crate_info") or providers.get("test_crate_info") + if crate_info_provider: + # TestCrateInfo wraps the actual CrateInfo + if hasattr(crate_info_provider, "crate"): + crate_info_provider = crate_info_provider.crate + test_binary = crate_info_provider.output + + # Select the appropriate wrapper template based on target OS + if toolchain.target_os == "windows": + wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.bat") + wrapper_template = ctx.file._test_sharding_wrapper_windows + else: + wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.sh") + wrapper_template = ctx.file._test_sharding_wrapper_unix + + # Generate wrapper script with test binary path substituted + ctx.actions.expand_template( + template = wrapper_template, + output = wrapper, + substitutions = { + "{{TEST_BINARY}}": test_binary.short_path, + }, + is_executable = True, + ) + + # Update runfiles to include both wrapper and test binary + new_runfiles = default_info.default_runfiles.merge( + ctx.runfiles(files = [test_binary]), + ) + + # Replace DefaultInfo with wrapper as executable + providers["DefaultInfo"] = DefaultInfo( + files = default_info.files, + runfiles = new_runfiles, + executable = wrapper, + ) + data = getattr(ctx.attr, "data", []) env = expand_dict_value_locations( @@ -544,12 +588,12 @@ def _rust_test_impl(ctx): env["RUST_LLVM_PROFDATA"] = llvm_profdata_path components = "{}/{}".format(ctx.label.workspace_root, ctx.label.package).split("/") env["CARGO_MANIFEST_DIR"] = "/".join([c for c in components if c]) - providers.append(RunEnvironmentInfo( + providers["RunEnvironmentInfo"] = RunEnvironmentInfo( environment = env, inherited_environment = ctx.attr.env_inherit, - )) + ) - return providers + return list(providers.values()) def _rust_library_group_impl(ctx): dep_variant_infos = [] @@ -918,6 +962,22 @@ _rust_test_attrs = { E.g. `bazel test //src:rust_test --test_arg=foo::test::test_fn`. """), ), + "experimental_enable_sharding": attr.bool( + mandatory = False, + default = False, + doc = dedent("""\ + If True, enable support for Bazel test sharding (shard_count attribute). + + When enabled, tests are executed via a wrapper script that: + 1. Enumerates tests using libtest's --list flag + 2. Partitions tests across shards based on TEST_SHARD_INDEX/TEST_TOTAL_SHARDS + 3. Runs only the tests assigned to the current shard + + This attribute only has an effect when use_libtest_harness is True. + + This is experimental and may change in future releases. + """), + ), } | _coverage_attrs | _experimental_use_cc_common_link_attrs rust_library = rule( @@ -1452,6 +1512,14 @@ rust_test = rule( "_allowlist_function_transition": attr.label( default = "@bazel_tools//tools/allowlists/function_transition_allowlist", ), + "_test_sharding_wrapper_unix": attr.label( + default = Label("//rust/private:test_sharding_wrapper.sh"), + allow_single_file = True, + ), + "_test_sharding_wrapper_windows": attr.label( + default = Label("//rust/private:test_sharding_wrapper.bat"), + allow_single_file = True, + ), }, executable = True, fragments = ["cpp"], diff --git a/rust/private/rustc.bzl b/rust/private/rustc.bzl index a28ad50b78..e4eb906528 100644 --- a/rust/private/rustc.bzl +++ b/rust/private/rustc.bzl @@ -1301,10 +1301,14 @@ def rustc_compile_action( include_coverage (bool, optional): Whether to generate coverage information or not. Returns: - list: A list of the following providers: - - (CrateInfo): info for the crate we just built; same as `crate_info` parameter. - - (DepInfo): The transitive dependencies of this crate. - - (DefaultInfo): The output file for this crate, and its runfiles. + dict: A dict mapping provider types to provider instances. Keys include: + - DefaultInfo: The output file for this crate, and its runfiles. + - CrateInfo: info for the crate we just built (or TestCrateInfo for staticlib/cdylib). + - DepInfo: The transitive dependencies of this crate. + - InstrumentedFilesInfo: Coverage information (if include_coverage is True). + - CcInfo: C/C++ interop info (if applicable). + - OutputGroupInfo: Additional output groups (if any). + Callers should convert to a list via `list(providers.values())` when returning from a rule. """ deps = crate_info_dict.pop("deps") proc_macro_deps = crate_info_dict.pop("proc_macro_deps") @@ -1695,25 +1699,24 @@ def rustc_compile_action( "metadata_files": coverage_runfiles + [executable] if executable else [], }) - providers = [ - DefaultInfo( + # Use string keys for providers since provider types are not hashable in all Bazel versions + providers = { + "DefaultInfo": DefaultInfo( # nb. This field is required for cc_library to depend on our output. files = depset(outputs), runfiles = runfiles, executable = executable, ), - ] + } # When invoked by aspects (and when running `bazel coverage`), the # baseline_coverage.dat created here will conflict with the baseline_coverage.dat of the # underlying target, which is a build failure. So we add an option to disable it so that this # function can be invoked from aspects for rules that have its own InstrumentedFilesInfo. if include_coverage: - providers.append( - coverage_common.instrumented_files_info( - ctx, - **instrumented_files_kwargs - ), + providers["InstrumentedFilesInfo"] = coverage_common.instrumented_files_info( + ctx, + **instrumented_files_kwargs ) if crate_info_dict != None: @@ -1732,11 +1735,20 @@ def rustc_compile_action( # as such they shouldn't provide a CrateInfo. However, one may still want to # write a rust_test for them, so we provide the CrateInfo wrapped in a provider # that rust_test understands. - providers.extend([rust_common.test_crate_info(crate = crate_info), dep_info]) + providers["test_crate_info"] = rust_common.test_crate_info(crate = crate_info) else: - providers.extend([crate_info, dep_info]) + providers["crate_info"] = crate_info + + providers["dep_info"] = dep_info - providers += establish_cc_info(ctx, attr, crate_info, toolchain, cc_toolchain, feature_configuration, interface_library) + cc_info_providers = establish_cc_info(ctx, attr, crate_info, toolchain, cc_toolchain, feature_configuration, interface_library) + for cc_provider in cc_info_providers: + # establish_cc_info returns CcInfo and optionally AllocatorLibrariesImplInfo + if type(cc_provider) == "CcInfo": + providers["CcInfo"] = cc_provider + else: + # AllocatorLibrariesImplInfo + providers["AllocatorLibrariesImplInfo"] = cc_provider output_group_info = {} @@ -1752,12 +1764,12 @@ def rustc_compile_action( output_group_info["rustc_output"] = depset([rustc_output]) if output_group_info: - providers.append(OutputGroupInfo(**output_group_info)) + providers["OutputGroupInfo"] = OutputGroupInfo(**output_group_info) # A bit unfortunate, but sidecar the lints info so rustdoc can access the # set of lints from the target it is documenting. if hasattr(ctx.attr, "lint_config") and ctx.attr.lint_config: - providers.append(ctx.attr.lint_config[LintsInfo]) + providers["LintsInfo"] = ctx.attr.lint_config[LintsInfo] return providers diff --git a/rust/private/test_sharding_wrapper.bat b/rust/private/test_sharding_wrapper.bat new file mode 100644 index 0000000000..e363aedf0d --- /dev/null +++ b/rust/private/test_sharding_wrapper.bat @@ -0,0 +1,118 @@ +@REM Copyright 2024 The Bazel Authors. All rights reserved. +@REM +@REM Licensed under the Apache License, Version 2.0 (the "License"); +@REM you may not use this file except in compliance with the License. +@REM You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, software +@REM distributed under the License is distributed on an "AS IS" BASIS, +@REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@REM See the License for the specific language governing permissions and +@REM limitations under the License. + +@REM Wrapper script for rust_test that enables Bazel test sharding support. +@REM This script intercepts test execution, enumerates tests using libtest's +@REM --list flag, partitions them by shard index, and runs only the relevant subset. + +@ECHO OFF +SETLOCAL EnableDelayedExpansion + +SET TEST_BINARY_RAW={{TEST_BINARY}} +SET TEST_BINARY_PATH=!TEST_BINARY_RAW:/=\! + +@REM Try to find the binary using RUNFILES_DIR if set +IF DEFINED RUNFILES_DIR ( + SET TEST_BINARY_IN_RUNFILES=!RUNFILES_DIR!\!TEST_BINARY_PATH! + IF EXIST "!TEST_BINARY_IN_RUNFILES!" ( + SET TEST_BINARY_PATH=!TEST_BINARY_IN_RUNFILES! + ) +) + +@REM The short_path is like: test/unit/test_sharding/test-2586318641/sharded_test_enabled.exe +@REM But on Windows, the binary is at grandparent/test-XXX/name.exe (sibling of runfiles dir) +@REM Extract just the last two components (test-XXX/name.exe) +FOR %%F IN ("!TEST_BINARY_PATH!") DO SET BINARY_NAME=%%~nxF +FOR %%F IN ("!TEST_BINARY_PATH!\..") DO SET BINARY_DIR=%%~nxF + +@REM Try various path resolutions +SET FOUND_BINARY=0 + +@REM Try 1: Direct path (might work in some configurations) +IF EXIST "!TEST_BINARY_PATH!" ( + SET FOUND_BINARY=1 +) + +@REM Try 2: Grandparent + last two path components +IF !FOUND_BINARY! EQU 0 ( + FOR %%F IN ("!TEST_BINARY_PATH!") DO ( + SET TEMP_PATH=%%~dpF + SET TEMP_PATH=!TEMP_PATH:~0,-1! + FOR %%D IN ("!TEMP_PATH!") DO SET PARENT_DIR=%%~nxD + ) + SET TEST_BINARY_GP=..\..\!PARENT_DIR!\!BINARY_NAME! + IF EXIST "!TEST_BINARY_GP!" ( + SET TEST_BINARY_PATH=!TEST_BINARY_GP! + SET FOUND_BINARY=1 + ) +) + +@REM Try 3: RUNFILES_DIR based path +IF !FOUND_BINARY! EQU 0 IF DEFINED RUNFILES_DIR ( + SET TEST_BINARY_RF=!RUNFILES_DIR!\_main\!TEST_BINARY_PATH! + SET TEST_BINARY_RF=!TEST_BINARY_RF:/=\! + IF EXIST "!TEST_BINARY_RF!" ( + SET TEST_BINARY_PATH=!TEST_BINARY_RF! + SET FOUND_BINARY=1 + ) +) + +IF !FOUND_BINARY! EQU 0 ( + ECHO ERROR: Could not find test binary at any expected location + EXIT /B 1 +) + +@REM If sharding is not enabled, run test binary directly +IF "%TEST_TOTAL_SHARDS%"=="" ( + !TEST_BINARY_PATH! %* + EXIT /B !ERRORLEVEL! +) + +@REM Touch status file to advertise sharding support to Bazel +IF NOT "%TEST_SHARD_STATUS_FILE%"=="" ( + TYPE NUL > "%TEST_SHARD_STATUS_FILE%" +) + +@REM Create a temporary file for test list +SET TEMP_LIST=%TEMP%\rust_test_list_%RANDOM%.txt + +@REM Enumerate all tests using libtest's --list flag +!TEST_BINARY_PATH! --list --format terse 2>NUL > "!TEMP_LIST!" + +@REM Count tests and filter for this shard +SET INDEX=0 +SET SHARD_TESTS= + +FOR /F "tokens=1 delims=:" %%T IN ('TYPE "!TEMP_LIST!" ^| FINDSTR /E ": test"') DO ( + SET /A MOD=!INDEX! %% %TEST_TOTAL_SHARDS% + IF !MOD! EQU %TEST_SHARD_INDEX% ( + IF "!SHARD_TESTS!"=="" ( + SET SHARD_TESTS=%%T + ) ELSE ( + SET SHARD_TESTS=!SHARD_TESTS! %%T + ) + ) + SET /A INDEX=!INDEX! + 1 +) + +DEL "!TEMP_LIST!" 2>NUL + +@REM If no tests for this shard, exit successfully +IF "!SHARD_TESTS!"=="" ( + EXIT /B 0 +) + +@REM Run the filtered tests with --exact to match exact test names +!TEST_BINARY_PATH! !SHARD_TESTS! --exact %* +EXIT /B !ERRORLEVEL! diff --git a/rust/private/test_sharding_wrapper.sh b/rust/private/test_sharding_wrapper.sh new file mode 100644 index 0000000000..e05970ba0a --- /dev/null +++ b/rust/private/test_sharding_wrapper.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +# Copyright 2024 The Bazel Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Wrapper script for rust_test that enables Bazel test sharding support. +# This script intercepts test execution, enumerates tests using libtest's +# --list flag, partitions them by shard index, and runs only the relevant subset. + +set -euo pipefail + +TEST_BINARY="{{TEST_BINARY}}" + +# If sharding is not enabled, run test binary directly +if [[ -z "${TEST_TOTAL_SHARDS:-}" ]]; then + exec "./${TEST_BINARY}" "$@" +fi + +# Touch status file to advertise sharding support to Bazel +if [[ -n "${TEST_SHARD_STATUS_FILE:-}" ]]; then + touch "${TEST_SHARD_STATUS_FILE}" +fi + +# Enumerate all tests using libtest's --list flag +# Output format: "test_name: test" - we need to strip the ": test" suffix +test_list=$("./${TEST_BINARY}" --list --format terse 2>/dev/null | grep ': test$' | sed 's/: test$//' || true) + +# If no tests found, exit successfully +if [[ -z "$test_list" ]]; then + exit 0 +fi + +# Filter tests for this shard +# test_index % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX +shard_tests=() +index=0 +while IFS= read -r test_name; do + if (( index % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX )); then + shard_tests+=("$test_name") + fi + ((index++)) || true +done <<< "$test_list" + +# If no tests for this shard, exit successfully +if [[ ${#shard_tests[@]} -eq 0 ]]; then + exit 0 +fi + +# Run the filtered tests with --exact to match exact test names +exec "./${TEST_BINARY}" "${shard_tests[@]}" --exact "$@" diff --git a/test/unit/consistent_crate_name/with_modified_crate_name.bzl b/test/unit/consistent_crate_name/with_modified_crate_name.bzl index 7148b0f0b8..3ec9df037a 100644 --- a/test/unit/consistent_crate_name/with_modified_crate_name.bzl +++ b/test/unit/consistent_crate_name/with_modified_crate_name.bzl @@ -31,7 +31,7 @@ def _with_modified_crate_name_impl(ctx): ) for dep in ctx.attr.deps] rust_lib = ctx.actions.declare_file(rust_lib_name) - return rustc_compile_action( + return list(rustc_compile_action( ctx = ctx, attr = ctx.attr, toolchain = toolchain, @@ -52,7 +52,7 @@ def _with_modified_crate_name_impl(ctx): is_test = False, ), output_hash = output_hash, - ) + ).values()) with_modified_crate_name = rule( implementation = _with_modified_crate_name_impl, diff --git a/test/unit/force_all_deps_direct/generator.bzl b/test/unit/force_all_deps_direct/generator.bzl index 09fba1f423..2314d28b70 100644 --- a/test/unit/force_all_deps_direct/generator.bzl +++ b/test/unit/force_all_deps_direct/generator.bzl @@ -47,7 +47,7 @@ def _generator_impl(ctx): ) for dep in ctx.attr.deps] rust_lib = ctx.actions.declare_file(rust_lib_name) - return rustc_compile_action( + return list(rustc_compile_action( ctx = ctx, attr = ctx.attr, toolchain = toolchain, @@ -69,7 +69,7 @@ def _generator_impl(ctx): ), output_hash = output_hash, force_all_deps_direct = True, - ) + ).values()) generator = rule( implementation = _generator_impl, diff --git a/test/unit/pipelined_compilation/wrap.bzl b/test/unit/pipelined_compilation/wrap.bzl index f24a0e421a..afbec2c783 100644 --- a/test/unit/pipelined_compilation/wrap.bzl +++ b/test/unit/pipelined_compilation/wrap.bzl @@ -59,7 +59,7 @@ def _wrap_impl(ctx): rust_metadata = None if ctx.attr.generate_metadata: rust_metadata = ctx.actions.declare_file(rust_metadata_name) - return rustc_compile_action( + return list(rustc_compile_action( ctx = ctx, attr = ctx.attr, toolchain = toolchain, @@ -83,7 +83,7 @@ def _wrap_impl(ctx): is_test = False, ), output_hash = output_hash, - ) + ).values()) wrap = rule( implementation = _wrap_impl, diff --git a/test/unit/test_sharding/BUILD.bazel b/test/unit/test_sharding/BUILD.bazel new file mode 100644 index 0000000000..0fbfffb43f --- /dev/null +++ b/test/unit/test_sharding/BUILD.bazel @@ -0,0 +1,4 @@ +load(":test_sharding.bzl", "test_sharding_test_suite") + +############################ UNIT TESTS ############################# +test_sharding_test_suite(name = "test_sharding_test_suite") diff --git a/test/unit/test_sharding/sharded_test.rs b/test/unit/test_sharding/sharded_test.rs new file mode 100644 index 0000000000..af8132dfee --- /dev/null +++ b/test/unit/test_sharding/sharded_test.rs @@ -0,0 +1,43 @@ +// Copyright 2024 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Test file with multiple tests for verifying sharding support. +//! +//! These tests are intentionally trivial - their purpose is to provide multiple +//! enumerable test functions that can be partitioned across shards. The BUILD +//! file runs this with `shard_count = 3`, so these 6 tests should be split +//! ~2 per shard. The sharding wrapper script enumerates tests via `--list`, +//! assigns each to a shard based on `index % shard_count`, and runs only the +//! subset for the current shard. + +#[cfg(test)] +mod tests { + #[test] + fn test_1() {} + + #[test] + fn test_2() {} + + #[test] + fn test_3() {} + + #[test] + fn test_4() {} + + #[test] + fn test_5() {} + + #[test] + fn test_6() {} +} diff --git a/test/unit/test_sharding/test_sharding.bzl b/test/unit/test_sharding/test_sharding.bzl new file mode 100644 index 0000000000..7ef3531023 --- /dev/null +++ b/test/unit/test_sharding/test_sharding.bzl @@ -0,0 +1,102 @@ +"""Tests for rust_test sharding support.""" + +load("@bazel_skylib//lib:unittest.bzl", "analysistest", "asserts") +load("//rust:defs.bzl", "rust_test") + +def _sharding_enabled_test(ctx): + """Test that sharding wrapper is generated when experimental_enable_sharding is True.""" + env = analysistest.begin(ctx) + tut = analysistest.target_under_test(env) + + # Get the executable from DefaultInfo + default_info = tut[DefaultInfo] + executable = default_info.files_to_run.executable + + # When sharding is enabled, the executable should be a wrapper script + asserts.true( + env, + executable.basename.endswith("_sharding_wrapper.sh") or + executable.basename.endswith("_sharding_wrapper.bat"), + "Expected sharding wrapper script, got: " + executable.basename, + ) + + return analysistest.end(env) + +sharding_enabled_test = analysistest.make(_sharding_enabled_test) + +def _sharding_disabled_test(ctx): + """Test that no wrapper is generated when experimental_enable_sharding is False.""" + env = analysistest.begin(ctx) + tut = analysistest.target_under_test(env) + + # Get the executable from DefaultInfo + default_info = tut[DefaultInfo] + executable = default_info.files_to_run.executable + + # When sharding is disabled, the executable should be the test binary directly + asserts.false( + env, + executable.basename.endswith("_sharding_wrapper.sh") or + executable.basename.endswith("_sharding_wrapper.bat"), + "Expected test binary, not wrapper script: " + executable.basename, + ) + + return analysistest.end(env) + +sharding_disabled_test = analysistest.make(_sharding_disabled_test) + +def _test_sharding_targets(): + """Create test targets for sharding tests.""" + + # Test with sharding enabled + rust_test( + name = "sharded_test_enabled", + srcs = ["sharded_test.rs"], + edition = "2021", + experimental_enable_sharding = True, + ) + + sharding_enabled_test( + name = "sharding_enabled_test", + target_under_test = ":sharded_test_enabled", + ) + + # Test with sharding disabled (default) + rust_test( + name = "sharded_test_disabled", + srcs = ["sharded_test.rs"], + edition = "2021", + experimental_enable_sharding = False, + ) + + sharding_disabled_test( + name = "sharding_disabled_test", + target_under_test = ":sharded_test_disabled", + ) + + # Integration test: actually run a sharded test + rust_test( + name = "sharded_integration_test", + srcs = ["sharded_test.rs"], + edition = "2021", + experimental_enable_sharding = True, + shard_count = 3, + ) + +def test_sharding_test_suite(name): + """Entry-point macro called from the BUILD file. + + Args: + name: Name of the macro. + """ + + _test_sharding_targets() + + native.test_suite( + name = name, + tests = [ + ":sharding_enabled_test", + ":sharding_disabled_test", + ":sharded_integration_test", + ], + )