Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions extensions/prost/private/prost.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,22 @@ def _compile_proto(
return lib_rs, package_info_file

def _get_crate_info(providers):
"""Finds the CrateInfo provider in the list of providers."""
for provider in providers:
if hasattr(provider, "name"):
return provider
fail("Couldn't find a CrateInfo in the list of providers")
"""Finds the CrateInfo provider in the dict of providers."""
if "crate_info" in providers:
return providers["crate_info"]
fail("Couldn't find a CrateInfo in the providers")

def _get_dep_info(providers):
"""Finds the DepInfo provider in the list of providers."""
for provider in providers:
if hasattr(provider, "direct_crates"):
return provider
fail("Couldn't find a DepInfo in the list of providers")
"""Finds the DepInfo provider in the dict of providers."""
if "dep_info" in providers:
return providers["dep_info"]
fail("Couldn't find a DepInfo in the providers")

def _get_cc_info(providers):
"""Finds the CcInfo provider in the list of providers."""
for provider in providers:
if hasattr(provider, "linking_context"):
return provider
fail("Couldn't find a CcInfo in the list of providers")
"""Finds the CcInfo provider in the dict of providers."""
if "CcInfo" in providers:
return providers["CcInfo"]
fail("Couldn't find a CcInfo in the providers")

def _compile_rust(
*,
Expand Down
6 changes: 4 additions & 2 deletions extensions/protobuf/proto.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ def _rust_proto_compile(protos, descriptor_sets, imports, crate_name, ctx, is_gr
),
output_hash = output_hash,
)
providers.append(OutputGroupInfo(rust_generated_srcs = srcs))
return providers
# Convert dict to list and add OutputGroupInfo
result = list(providers.values())
result.append(OutputGroupInfo(rust_generated_srcs = srcs))
return result

def _rust_protogrpc_library_impl(ctx, is_grpc):
"""Implementation of the rust_(proto|grpc)_library.
Expand Down
23 changes: 13 additions & 10 deletions extensions/wasm_bindgen/private/wasm_bindgen_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,21 @@ def _rust_wasm_bindgen_test_impl(ctx):
# Force the use of a browser for now as there is no node integration.
env["WASM_BINDGEN_USE_BROWSER"] = "1"

# Extract DefaultInfo and create a modified version with the wrapper as executable
default_info = crate_providers["DefaultInfo"]
files = default_info.files.to_list()
if len(files) != 1:
fail("Unexpected number of output files for `{}`: {}".format(ctx.label, files))
wasm_file = files[0]
env["TEST_WASM_BINARY"] = _rlocationpath(files[0], ctx.workspace_name)

# Build providers list from the dict, replacing DefaultInfo with our modified version
providers = []

for prov in crate_providers:
if type(prov) == "DefaultInfo":
files = prov.files.to_list()
if len(files) != 1:
fail("Unexpected number of output files for `{}`: {}".format(ctx.label, files))
wasm_file = files[0]
env["TEST_WASM_BINARY"] = _rlocationpath(files[0], ctx.workspace_name)
for key, prov in crate_providers.items():
if key == "DefaultInfo":
providers.append(DefaultInfo(
files = prov.files,
runfiles = prov.default_runfiles.merge(ctx.runfiles(files = [wasm_file], transitive_files = wb_toolchain.all_test_files)),
files = default_info.files,
runfiles = default_info.default_runfiles.merge(ctx.runfiles(files = [wasm_file], transitive_files = wb_toolchain.all_test_files)),
executable = wrapper,
))
else:
Expand Down
6 changes: 6 additions & 0 deletions rust/private/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ load("//rust/private:stamp.bzl", "stamp_build_setting")
# Exported for docs
exports_files(["providers.bzl"])

# Test sharding wrapper scripts
exports_files([
"test_sharding_wrapper.sh",
"test_sharding_wrapper.bat",
])

bzl_library(
name = "bazel_tools_bzl_lib",
srcs = ["@bazel_tools//tools:bzl_srcs"],
Expand Down
84 changes: 76 additions & 8 deletions rust/private/rust.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _rust_library_common(ctx, crate_type):
deps = transform_deps(ctx.attr.deps)
proc_macro_deps = transform_deps(ctx.attr.proc_macro_deps + get_import_macro_deps(ctx))

return rustc_compile_action(
return list(rustc_compile_action(
ctx = ctx,
attr = ctx.attr,
toolchain = toolchain,
Expand Down Expand Up @@ -246,7 +246,7 @@ def _rust_library_common(ctx, crate_type):
owner = ctx.label,
cfgs = _collect_cfgs(ctx, toolchain, crate_root, crate_type, crate_is_test = False),
),
)
).values())

def _rust_binary_impl(ctx):
"""The implementation of the `rust_binary` rule
Expand Down Expand Up @@ -312,16 +312,16 @@ def _rust_binary_impl(ctx):
),
)

providers.append(RunEnvironmentInfo(
providers["RunEnvironmentInfo"] = RunEnvironmentInfo(
environment = expand_dict_value_locations(
ctx,
ctx.attr.env,
ctx.attr.data,
{},
),
))
)

return providers
return list(providers.values())

def get_rust_test_flags(attr):
"""Determine the desired rustc flags for test targets.
Expand Down Expand Up @@ -516,6 +516,50 @@ def _rust_test_impl(ctx):
rust_flags = get_rust_test_flags(ctx.attr),
skip_expanding_rustc_env = True,
)

# If sharding is enabled and we're using libtest harness, wrap the test binary
# with a script that handles test enumeration and shard partitioning
if ctx.attr.experimental_enable_sharding and ctx.attr.use_libtest_harness:
default_info = providers["DefaultInfo"]

# Get the test binary from CrateInfo (or TestCrateInfo for staticlib/cdylib)
crate_info_provider = providers.get("crate_info") or providers.get("test_crate_info")
if crate_info_provider:
# TestCrateInfo wraps the actual CrateInfo
if hasattr(crate_info_provider, "crate"):
crate_info_provider = crate_info_provider.crate
test_binary = crate_info_provider.output

# Select the appropriate wrapper template based on target OS
if toolchain.target_os == "windows":
wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.bat")
wrapper_template = ctx.file._test_sharding_wrapper_windows
else:
wrapper = ctx.actions.declare_file(ctx.label.name + "_sharding_wrapper.sh")
wrapper_template = ctx.file._test_sharding_wrapper_unix

# Generate wrapper script with test binary path substituted
ctx.actions.expand_template(
template = wrapper_template,
output = wrapper,
substitutions = {
"{{TEST_BINARY}}": test_binary.short_path,
},
is_executable = True,
)

# Update runfiles to include both wrapper and test binary
new_runfiles = default_info.default_runfiles.merge(
ctx.runfiles(files = [test_binary]),
)

# Replace DefaultInfo with wrapper as executable
providers["DefaultInfo"] = DefaultInfo(
files = default_info.files,
runfiles = new_runfiles,
executable = wrapper,
)

data = getattr(ctx.attr, "data", [])

env = expand_dict_value_locations(
Expand Down Expand Up @@ -544,12 +588,12 @@ def _rust_test_impl(ctx):
env["RUST_LLVM_PROFDATA"] = llvm_profdata_path
components = "{}/{}".format(ctx.label.workspace_root, ctx.label.package).split("/")
env["CARGO_MANIFEST_DIR"] = "/".join([c for c in components if c])
providers.append(RunEnvironmentInfo(
providers["RunEnvironmentInfo"] = RunEnvironmentInfo(
environment = env,
inherited_environment = ctx.attr.env_inherit,
))
)

return providers
return list(providers.values())

def _rust_library_group_impl(ctx):
dep_variant_infos = []
Expand Down Expand Up @@ -918,6 +962,22 @@ _rust_test_attrs = {
E.g. `bazel test //src:rust_test --test_arg=foo::test::test_fn`.
"""),
),
"experimental_enable_sharding": attr.bool(
mandatory = False,
default = False,
doc = dedent("""\
If True, enable support for Bazel test sharding (shard_count attribute).

When enabled, tests are executed via a wrapper script that:
1. Enumerates tests using libtest's --list flag
2. Partitions tests across shards based on TEST_SHARD_INDEX/TEST_TOTAL_SHARDS
3. Runs only the tests assigned to the current shard

This attribute only has an effect when use_libtest_harness is True.

This is experimental and may change in future releases.
"""),
),
} | _coverage_attrs | _experimental_use_cc_common_link_attrs

rust_library = rule(
Expand Down Expand Up @@ -1452,6 +1512,14 @@ rust_test = rule(
"_allowlist_function_transition": attr.label(
default = "@bazel_tools//tools/allowlists/function_transition_allowlist",
),
"_test_sharding_wrapper_unix": attr.label(
default = Label("//rust/private:test_sharding_wrapper.sh"),
allow_single_file = True,
),
"_test_sharding_wrapper_windows": attr.label(
default = Label("//rust/private:test_sharding_wrapper.bat"),
allow_single_file = True,
),
},
executable = True,
fragments = ["cpp"],
Expand Down
46 changes: 29 additions & 17 deletions rust/private/rustc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -1301,10 +1301,14 @@ def rustc_compile_action(
include_coverage (bool, optional): Whether to generate coverage information or not.

Returns:
list: A list of the following providers:
- (CrateInfo): info for the crate we just built; same as `crate_info` parameter.
- (DepInfo): The transitive dependencies of this crate.
- (DefaultInfo): The output file for this crate, and its runfiles.
dict: A dict mapping provider types to provider instances. Keys include:
- DefaultInfo: The output file for this crate, and its runfiles.
- CrateInfo: info for the crate we just built (or TestCrateInfo for staticlib/cdylib).
- DepInfo: The transitive dependencies of this crate.
- InstrumentedFilesInfo: Coverage information (if include_coverage is True).
- CcInfo: C/C++ interop info (if applicable).
- OutputGroupInfo: Additional output groups (if any).
Callers should convert to a list via `list(providers.values())` when returning from a rule.
"""
deps = crate_info_dict.pop("deps")
proc_macro_deps = crate_info_dict.pop("proc_macro_deps")
Expand Down Expand Up @@ -1695,25 +1699,24 @@ def rustc_compile_action(
"metadata_files": coverage_runfiles + [executable] if executable else [],
})

providers = [
DefaultInfo(
# Use string keys for providers since provider types are not hashable in all Bazel versions
providers = {
"DefaultInfo": DefaultInfo(
# nb. This field is required for cc_library to depend on our output.
files = depset(outputs),
runfiles = runfiles,
executable = executable,
),
]
}

# When invoked by aspects (and when running `bazel coverage`), the
# baseline_coverage.dat created here will conflict with the baseline_coverage.dat of the
# underlying target, which is a build failure. So we add an option to disable it so that this
# function can be invoked from aspects for rules that have its own InstrumentedFilesInfo.
if include_coverage:
providers.append(
coverage_common.instrumented_files_info(
ctx,
**instrumented_files_kwargs
),
providers["InstrumentedFilesInfo"] = coverage_common.instrumented_files_info(
ctx,
**instrumented_files_kwargs
)

if crate_info_dict != None:
Expand All @@ -1732,11 +1735,20 @@ def rustc_compile_action(
# as such they shouldn't provide a CrateInfo. However, one may still want to
# write a rust_test for them, so we provide the CrateInfo wrapped in a provider
# that rust_test understands.
providers.extend([rust_common.test_crate_info(crate = crate_info), dep_info])
providers["test_crate_info"] = rust_common.test_crate_info(crate = crate_info)
else:
providers.extend([crate_info, dep_info])
providers["crate_info"] = crate_info

providers["dep_info"] = dep_info

providers += establish_cc_info(ctx, attr, crate_info, toolchain, cc_toolchain, feature_configuration, interface_library)
cc_info_providers = establish_cc_info(ctx, attr, crate_info, toolchain, cc_toolchain, feature_configuration, interface_library)
for cc_provider in cc_info_providers:
# establish_cc_info returns CcInfo and optionally AllocatorLibrariesImplInfo
if type(cc_provider) == "CcInfo":
providers["CcInfo"] = cc_provider
else:
# AllocatorLibrariesImplInfo
providers["AllocatorLibrariesImplInfo"] = cc_provider

output_group_info = {}

Expand All @@ -1752,12 +1764,12 @@ def rustc_compile_action(
output_group_info["rustc_output"] = depset([rustc_output])

if output_group_info:
providers.append(OutputGroupInfo(**output_group_info))
providers["OutputGroupInfo"] = OutputGroupInfo(**output_group_info)

# A bit unfortunate, but sidecar the lints info so rustdoc can access the
# set of lints from the target it is documenting.
if hasattr(ctx.attr, "lint_config") and ctx.attr.lint_config:
providers.append(ctx.attr.lint_config[LintsInfo])
providers["LintsInfo"] = ctx.attr.lint_config[LintsInfo]

return providers

Expand Down
Loading