diff --git a/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.cc b/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.cc index a4f8b18c22..f0913664da 100644 --- a/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.cc +++ b/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.cc @@ -164,10 +164,10 @@ void GenIrNodesPass::HandleUMul(const FuzzUMulProto& umul) { // If the bit width is set, use it. BValue lhs = GetBitsOperand(umul.lhs_idx()); BValue rhs = GetBitsOperand(umul.rhs_idx()); - // The bit width cannot exceed the sum of the bit widths of the operands. - int64_t right_bound = - std::min(IrFuzzHelpers::kMaxFuzzBitWidth, - lhs.BitCountOrDie() + rhs.BitCountOrDie()); + // The bit width cannot exceed the sum of the bit widths of the operands + // or the max bit width. + int64_t right_bound = std::min( + max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie()); int64_t bit_width = helpers_.BoundedWidth(umul.bit_width(), /*left_bound=*/1, right_bound); state().context_list().AppendElement( @@ -186,9 +186,8 @@ void GenIrNodesPass::HandleSMul(const FuzzSMulProto& smul) { if (smul.has_bit_width()) { BValue lhs = GetBitsOperand(smul.lhs_idx()); BValue rhs = GetBitsOperand(smul.rhs_idx()); - int64_t right_bound = - std::min(IrFuzzHelpers::kMaxFuzzBitWidth, - lhs.BitCountOrDie() + rhs.BitCountOrDie()); + int64_t right_bound = std::min( + max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie()); int64_t bit_width = helpers_.BoundedWidth(smul.bit_width(), /*left_bound=*/1, right_bound); state().context_list().AppendElement( @@ -205,9 +204,8 @@ void GenIrNodesPass::HandleUMulp(const FuzzUMulpProto& umulp) { if (umulp.has_bit_width()) { BValue lhs = GetBitsOperand(umulp.lhs_idx()); BValue rhs = GetBitsOperand(umulp.rhs_idx()); - int64_t right_bound = - std::min(IrFuzzHelpers::kMaxFuzzBitWidth, - lhs.BitCountOrDie() + rhs.BitCountOrDie()); + int64_t right_bound = std::min( + max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie()); int64_t bit_width = helpers_.BoundedWidth(umulp.bit_width(), /*left_bound=*/1, right_bound); state().context_list().AppendElement( @@ -224,9 +222,8 @@ void GenIrNodesPass::HandleSMulp(const FuzzSMulpProto& smulp) { if (smulp.has_bit_width()) { BValue lhs = GetBitsOperand(smulp.lhs_idx()); BValue rhs = GetBitsOperand(smulp.rhs_idx()); - int64_t right_bound = - std::min(IrFuzzHelpers::kMaxFuzzBitWidth, - lhs.BitCountOrDie() + rhs.BitCountOrDie()); + int64_t right_bound = std::min( + max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie()); int64_t bit_width = helpers_.BoundedWidth(smulp.bit_width(), /*left_bound=*/1, right_bound); state().context_list().AppendElement( @@ -281,12 +278,12 @@ void GenIrNodesPass::HandleConcat(const FuzzConcatProto& concat) { std::vector operands = GetBitsOperands(concat.operand_idxs(), /*min_operand_count=*/1); // Only use operands such that their summed bit width is less than or equal - // to IrFuzzHelpers::kMaxFuzzBitWidth. + // to the maximum allowed bit width. int64_t bit_width_sum = 0; int64_t operand_count = 0; for (const auto& operand : operands) { bit_width_sum += operand.BitCountOrDie(); - if (bit_width_sum > IrFuzzHelpers::kMaxFuzzBitWidth) { + if (bit_width_sum > max_bit_width_) { break; } operand_count += 1; @@ -410,13 +407,13 @@ void GenIrNodesPass::HandleSelect(const FuzzSelectProto& select) { void GenIrNodesPass::HandleOneHot(const FuzzOneHotProto& one_hot) { BValue operand = GetBitsOperand(one_hot.operand_idx()); - // If the operand bit width is kMaxFuzzBitWidth, decrease it by 1 because - // OneHot returns an operand with a bit width of 1 + the bit width of the - // operand. - if (operand.BitCountOrDie() == IrFuzzHelpers::kMaxFuzzBitWidth) { + // If the operand bit width is at the maximum allowed bit width, decrease it + // by 1 because OneHot returns an operand with a bit width of 1 + the bit + // width of the operand. + if (operand.BitCountOrDie() >= max_bit_width_) { auto operand_coercion_method = one_hot.operand_coercion_method(); operand = helpers_.ChangeBitWidth( - state().fb(), operand, IrFuzzHelpers::kMaxFuzzBitWidth - 1, + state().fb(), operand, max_bit_width_ - 1, operand_coercion_method.change_bit_width_method()); } // Convert the LsbOrMsb proto enum to the FunctionBuilder enum. @@ -481,11 +478,23 @@ void GenIrNodesPass::HandlePrioritySelect( void GenIrNodesPass::HandleClz(const FuzzClzProto& clz) { BValue operand = GetBitsOperand(clz.operand_idx()); + // Use max_bit_width_ - 1 because this operation internally generates a + // OneHot node which adds 1 to the bit count. See HandleOneHot. + if (operand.BitCountOrDie() >= max_bit_width_) { + operand = + helpers_.ChangeBitWidth(state().fb(), operand, max_bit_width_ - 1); + } state().context_list().AppendElement(state().fb()->Clz(operand)); } void GenIrNodesPass::HandleCtz(const FuzzCtzProto& ctz) { BValue operand = GetBitsOperand(ctz.operand_idx()); + // Use max_bit_width_ - 1 because this operation internally generates a + // OneHot node which adds 1 to the bit count. See HandleOneHot. + if (operand.BitCountOrDie() >= max_bit_width_) { + operand = + helpers_.ChangeBitWidth(state().fb(), operand, max_bit_width_ - 1); + } state().context_list().AppendElement(state().fb()->Ctz(operand)); } @@ -717,13 +726,13 @@ void GenIrNodesPass::HandleEncode(const FuzzEncodeProto& encode) { void GenIrNodesPass::HandleDecode(const FuzzDecodeProto& decode) { BValue operand = GetBitsOperand(decode.operand_idx()); // The decode bit width cannot exceed 2 ** operand_bit_width. - int64_t right_bound = IrFuzzHelpers::kMaxFuzzBitWidth; + int64_t right_bound = max_bit_width_; // We could choose a larger size for this check, but we're clamping to - // IrFuzzHelpers::kMaxFuzzBitWidth and 10 bits is sufficient for this while + // max_bit_width_ and 10 bits is sufficient for this while // completely avoiding potential overflow, unlike e.g. checking for 64 bits. if (operand.BitCountOrDie() < 10) { - right_bound = std::min(IrFuzzHelpers::kMaxFuzzBitWidth, - 1ULL << operand.BitCountOrDie()); + right_bound = + std::min(max_bit_width_, 1ULL << operand.BitCountOrDie()); } int64_t bit_width = helpers_.BoundedWidth(decode.bit_width(), /*left_bound=*/1, right_bound); @@ -762,7 +771,9 @@ void GenIrNodesPass::HandleDefineFunction( // to p_->functions() until after the function is fully built. p_->functions().size() + function_states_.size()), fuzz_program_.version(), define_function.combine_list_method(), - helpers_.Bounded(define_function.next_nodes_consumed(), 0, 1000)); + helpers_.Bounded(define_function.next_nodes_consumed(), /*left_bound=*/0, + /*right_bound=*/1000), + max_bit_width_); } void GenIrNodesPass::HandleInvoke(const FuzzInvokeProto& invoke) { diff --git a/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.h b/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.h index 0c33663c94..e9edd83e0d 100644 --- a/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.h +++ b/xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.h @@ -43,13 +43,16 @@ class GenIrNodesPass : public IrFuzzVisitor { public: GenIrNodesPass(Package* p, std::string_view top_name, const FuzzProgramProto& fuzz_program, - std::optional param_bits = std::nullopt) + std::optional param_bits = std::nullopt, + std::optional max_bit_width = std::nullopt) : p_(p), fuzz_program_(fuzz_program), - helpers_(fuzz_program_.version()), + helpers_(fuzz_program_.version(), max_bit_width), + max_bit_width_(max_bit_width.value_or(IrFuzzHelpers::kMaxFuzzBitWidth)), remaining_param_bits_(param_bits) { function_states_.emplace_back(p, top_name, fuzz_program_.version(), - fuzz_program_.combine_list_method()); + fuzz_program_.combine_list_method(), + std::nullopt, max_bit_width); } void GenIrNodes(); @@ -128,9 +131,10 @@ class GenIrNodesPass : public IrFuzzVisitor { public: FunctionState(Package* p, std::string_view top_name, FuzzVersion version, CombineListMethod combine_list_method, - std::optional nodes_to_consume = std::nullopt) + std::optional nodes_to_consume = std::nullopt, + std::optional max_bit_width = std::nullopt) : fb_(std::make_unique(top_name, p)), - context_list_(p, fb_.get(), IrFuzzHelpers(version)), + context_list_(p, fb_.get(), IrFuzzHelpers(version, max_bit_width)), combine_list_method_(combine_list_method), nodes_to_consume_(nodes_to_consume) {} @@ -208,6 +212,7 @@ class GenIrNodesPass : public IrFuzzVisitor { caller_to_callee_; std::vector function_states_; std::optional remaining_param_bits_ = std::nullopt; + int64_t max_bit_width_; }; } // namespace xls diff --git a/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.cc b/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.cc index f37d9b91b0..60ec5fe98e 100644 --- a/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.cc +++ b/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.cc @@ -243,8 +243,10 @@ std::vector GetRestrictedSet(absl::Span ops, bool only_bits, absl::Status BuildIr(Package* package, std::string_view top_name, const FuzzProgramProto& fuzz_program, - std::optional param_bits) { - GenIrNodesPass pass(package, top_name, fuzz_program, param_bits); + std::optional param_bits, + std::optional max_bit_width) { + GenIrNodesPass pass(package, top_name, fuzz_program, param_bits, + max_bit_width); pass.GenIrNodes(); return absl::OkStatus(); } @@ -275,17 +277,20 @@ fuzztest::Domain FuzzPackageDomainBuilder::Build() && { fuzztest::Just(*combine_list_method_)) : std::move(proto); return fuzztest::Map( - [](FuzzProgramProto fuzz_program, std::optional param_bits) { + [](FuzzProgramProto fuzz_program, std::optional param_bits, + std::optional max_bit_width) { // Create the package. std::unique_ptr p = std::make_unique(kFuzzTestName); - CHECK_OK(BuildIr(p.get(), kFuzzTestName, fuzz_program, param_bits)) + CHECK_OK(BuildIr(p.get(), kFuzzTestName, fuzz_program, param_bits, + max_bit_width)) << "Failed to build package from FuzzProgramProto: " << fuzz_program.DebugString(); // Create the FuzzPackage object as the domain export. return FuzzPackage(std::move(p), fuzz_program); }, - std::move(proto), fuzztest::Just>(param_bits_)); + std::move(proto), fuzztest::Just>(param_bits_), + fuzztest::Just>(max_bit_width_)); } fuzztest::Domain PackageWithArgsDomainBuilder::Build() && { diff --git a/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h b/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h index f18d579c6b..1a5bed5316 100644 --- a/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h +++ b/xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h @@ -84,6 +84,11 @@ class FuzzPackageDomainBuilder { param_bits_ = bits; return std::move(*this); } + // Limits the bit width of any individual node to this value. + FuzzPackageDomainBuilder WithMaxBitWidth(int64_t bits) && { + max_bit_width_ = bits; + return std::move(*this); + } // Forces the combine list method to be the given value. FuzzPackageDomainBuilder WithCombineListMethod( std::optional method) && { @@ -107,6 +112,7 @@ class FuzzPackageDomainBuilder { bool allow_define_function_ = true; bool allow_invoke_ = true; std::optional param_bits_; + std::optional max_bit_width_; std::optional combine_list_method_; }; @@ -145,6 +151,10 @@ class PackageWithArgsDomainBuilder { return PackageWithArgsDomainBuilder(std::move(base_).WithParamBits(bits), arg_set_count_); } + PackageWithArgsDomainBuilder WithMaxBitWidth(int64_t bits) && { + return PackageWithArgsDomainBuilder(std::move(base_).WithMaxBitWidth(bits), + arg_set_count_); + } PackageWithArgsDomainBuilder WithCombineListMethod( std::optional method) && { return PackageWithArgsDomainBuilder( @@ -197,6 +207,9 @@ class PackageDomainBuilder { PackageDomainBuilder WithParamBits(int64_t bits) && { return PackageDomainBuilder(std::move(base_).WithParamBits(bits)); } + PackageDomainBuilder WithMaxBitWidth(int64_t bits) && { + return PackageDomainBuilder(std::move(base_).WithMaxBitWidth(bits)); + } PackageDomainBuilder WithCombineListMethod( std::optional method) && { return PackageDomainBuilder( diff --git a/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.cc b/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.cc index 008d698461..c336d42a1e 100644 --- a/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.cc +++ b/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.cc @@ -14,6 +14,7 @@ #include "xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h" +#include #include #include #include @@ -127,7 +128,8 @@ BValue IrFuzzHelpers::CoercedArray(Package* p, FunctionBuilder* fb, // Coerce each array element and create a new array with the coerced // elements. for (int64_t i = 0; i < array_size; i += 1) { - BValue element = fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true); + BValue element = + fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true); BValue coerced_element = Coerced(p, fb, element, coerced_type.array_element(), target_array_type->element_type()); @@ -225,7 +227,8 @@ BValue IrFuzzHelpers::FittedArray(Package* p, FunctionBuilder* fb, std::vector fitted_elements; // Fit each array element and create a new array with the fitted elements. for (int64_t i = 0; i < target_array_type->size(); i += 1) { - BValue element = fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true); + BValue element = + fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true); BValue fitted_element = Fitted(p, fb, element, coercion_method, target_array_type->element_type()); fitted_elements.push_back(fitted_element); @@ -362,15 +365,16 @@ BValue IrFuzzHelpers::DecreaseArraySize( ArrayType* array_type = bvalue.GetType()->AsArrayOrDie(); switch (decrease_size_method) { case DecreaseArraySizeMethod::ARRAY_SLICE_METHOD: - return fb->ArraySlice(bvalue, fb->Literal(UBits(0, 64)), new_size); + return fb->ArraySlice(bvalue, fb->Literal(UBits(0, BoundedWidth(64))), + new_size); case DecreaseArraySizeMethod::SHRINK_ARRAY_METHOD: default: // Create a new array with only some of the elements from the original // array. std::vector elements; for (int64_t i = 0; i < new_size; i += 1) { - elements.push_back( - fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true)); + elements.push_back(fb->ArrayIndex( + bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true)); } return fb->Array(elements, array_type->element_type()); } @@ -387,8 +391,8 @@ BValue IrFuzzHelpers::IncreaseArraySize( // to expand the array to the specified size. std::vector elements; for (int64_t i = 0; i < array_type->size(); i += 1) { - elements.push_back( - fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true)); + elements.push_back(fb->ArrayIndex( + bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true)); } // Use a default value to fill the rest of the array. BValue default_value = @@ -430,6 +434,7 @@ int64_t IrFuzzHelpers::Bounded(int64_t value, int64_t left_bound, // out of memory or take too long. int64_t IrFuzzHelpers::BoundedWidth(int64_t bit_width, int64_t left_bound, int64_t right_bound) const { + right_bound = std::max(left_bound, std::min(right_bound, max_bit_width_)); return Bounded(bit_width, left_bound, right_bound); } @@ -455,12 +460,12 @@ BValue IrFuzzHelpers::DefaultValue(Package* p, FunctionBuilder* fb, return fb->Tuple({DefaultValue(p, fb, TypeCase::BITS_CASE)}); case TypeCase::ARRAY_CASE: return fb->Array({DefaultValue(p, fb, TypeCase::BITS_CASE)}, - p->GetBitsType(64)); + p->GetBitsType(BoundedWidth(64))); } } BValue IrFuzzHelpers::DefaultBitsValue(FunctionBuilder* fb) const { - return fb->Literal(UBits(0, 64)); + return fb->Literal(UBits(0, BoundedWidth(64))); } // Returns a default BValue of the specified type. diff --git a/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h b/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h index 4ef55699b2..4c453d27c1 100644 --- a/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h +++ b/xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h @@ -41,8 +41,10 @@ enum TypeCase { class IrFuzzHelpers { public: - constexpr explicit IrFuzzHelpers(FuzzVersion version) - : fuzz_version_(version) {} + constexpr explicit IrFuzzHelpers( + FuzzVersion version, std::optional max_bit_width = std::nullopt) + : fuzz_version_(version), + max_bit_width_(max_bit_width.value_or(kMaxFuzzBitWidth)) {} BValue Coerced(Package* p, FunctionBuilder* fb, BValue bvalue, const CoercedTypeProto& coerced_type, Type* target_type) const; @@ -135,7 +137,7 @@ class IrFuzzHelpers { case TypeCase::kArray: return ConvertArrayTypeProtoToType(p, type_proto.array()); default: - return p->GetBitsType(64); + return p->GetBitsType(BoundedWidth(64)); } } template @@ -171,6 +173,7 @@ class IrFuzzHelpers { private: FuzzVersion fuzz_version_; + int64_t max_bit_width_; }; constexpr std::array kFuzzHelpers = diff --git a/xls/fuzzer/ir_fuzzer/ir_fuzz_test.cc b/xls/fuzzer/ir_fuzzer/ir_fuzz_test.cc index 91d01495a1..649a4fd8fb 100644 --- a/xls/fuzzer/ir_fuzzer/ir_fuzz_test.cc +++ b/xls/fuzzer/ir_fuzzer/ir_fuzz_test.cc @@ -69,6 +69,20 @@ FUZZ_TEST(IrFuzzTest, BuilderRestrictsOp) .NoCtz() .Build()); +void BuilderRestrictsBitWidth(FuzzPackage fuzz_package) { + for (Node* node : fuzz_package.p->functions().front()->nodes()) { + if (node->GetType()->IsBits()) { + EXPECT_LE(node->GetType()->AsBitsOrDie()->bit_count(), 10) + << "Node exceeded max bit width: " << node->ToString() + << "\nGenerated by program:\n" + << fuzz_package.fuzz_program.DebugString(); + } + } +} + +FUZZ_TEST(IrFuzzTest, BuilderRestrictsBitWidth) + .WithDomains(FuzzPackageDomainBuilder().WithMaxBitWidth(10).Build()); + // This test makes sure that the OptimizationPassChangesOutputs test works for // this specific example. TEST(IrFuzzTest, PassChangesOutputsWithBitsParam) {