Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 36 additions & 25 deletions xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ void GenIrNodesPass::HandleUMul(const FuzzUMulProto& umul) {
// If the bit width is set, use it.
BValue lhs = GetBitsOperand(umul.lhs_idx());
BValue rhs = GetBitsOperand(umul.rhs_idx());
// The bit width cannot exceed the sum of the bit widths of the operands.
int64_t right_bound =
std::min<int64_t>(IrFuzzHelpers::kMaxFuzzBitWidth,
lhs.BitCountOrDie() + rhs.BitCountOrDie());
// The bit width cannot exceed the sum of the bit widths of the operands
// or the max bit width.
int64_t right_bound = std::min<int64_t>(
max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t bit_width =
helpers_.BoundedWidth(umul.bit_width(), /*left_bound=*/1, right_bound);
state().context_list().AppendElement(
Expand All @@ -186,9 +186,8 @@ void GenIrNodesPass::HandleSMul(const FuzzSMulProto& smul) {
if (smul.has_bit_width()) {
BValue lhs = GetBitsOperand(smul.lhs_idx());
BValue rhs = GetBitsOperand(smul.rhs_idx());
int64_t right_bound =
std::min<int64_t>(IrFuzzHelpers::kMaxFuzzBitWidth,
lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t right_bound = std::min<int64_t>(
max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t bit_width =
helpers_.BoundedWidth(smul.bit_width(), /*left_bound=*/1, right_bound);
state().context_list().AppendElement(
Expand All @@ -205,9 +204,8 @@ void GenIrNodesPass::HandleUMulp(const FuzzUMulpProto& umulp) {
if (umulp.has_bit_width()) {
BValue lhs = GetBitsOperand(umulp.lhs_idx());
BValue rhs = GetBitsOperand(umulp.rhs_idx());
int64_t right_bound =
std::min<int64_t>(IrFuzzHelpers::kMaxFuzzBitWidth,
lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t right_bound = std::min<int64_t>(
max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t bit_width =
helpers_.BoundedWidth(umulp.bit_width(), /*left_bound=*/1, right_bound);
state().context_list().AppendElement(
Expand All @@ -224,9 +222,8 @@ void GenIrNodesPass::HandleSMulp(const FuzzSMulpProto& smulp) {
if (smulp.has_bit_width()) {
BValue lhs = GetBitsOperand(smulp.lhs_idx());
BValue rhs = GetBitsOperand(smulp.rhs_idx());
int64_t right_bound =
std::min<int64_t>(IrFuzzHelpers::kMaxFuzzBitWidth,
lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t right_bound = std::min<int64_t>(
max_bit_width_, lhs.BitCountOrDie() + rhs.BitCountOrDie());
int64_t bit_width =
helpers_.BoundedWidth(smulp.bit_width(), /*left_bound=*/1, right_bound);
state().context_list().AppendElement(
Expand Down Expand Up @@ -281,12 +278,12 @@ void GenIrNodesPass::HandleConcat(const FuzzConcatProto& concat) {
std::vector<BValue> operands =
GetBitsOperands(concat.operand_idxs(), /*min_operand_count=*/1);
// Only use operands such that their summed bit width is less than or equal
// to IrFuzzHelpers::kMaxFuzzBitWidth.
// to the maximum allowed bit width.
int64_t bit_width_sum = 0;
int64_t operand_count = 0;
for (const auto& operand : operands) {
bit_width_sum += operand.BitCountOrDie();
if (bit_width_sum > IrFuzzHelpers::kMaxFuzzBitWidth) {
if (bit_width_sum > max_bit_width_) {
break;
}
operand_count += 1;
Expand Down Expand Up @@ -410,13 +407,13 @@ void GenIrNodesPass::HandleSelect(const FuzzSelectProto& select) {

void GenIrNodesPass::HandleOneHot(const FuzzOneHotProto& one_hot) {
BValue operand = GetBitsOperand(one_hot.operand_idx());
// If the operand bit width is kMaxFuzzBitWidth, decrease it by 1 because
// OneHot returns an operand with a bit width of 1 + the bit width of the
// operand.
if (operand.BitCountOrDie() == IrFuzzHelpers::kMaxFuzzBitWidth) {
// If the operand bit width is at the maximum allowed bit width, decrease it
// by 1 because OneHot returns an operand with a bit width of 1 + the bit
// width of the operand.
if (operand.BitCountOrDie() >= max_bit_width_) {
auto operand_coercion_method = one_hot.operand_coercion_method();
operand = helpers_.ChangeBitWidth(
state().fb(), operand, IrFuzzHelpers::kMaxFuzzBitWidth - 1,
state().fb(), operand, max_bit_width_ - 1,
operand_coercion_method.change_bit_width_method());
}
// Convert the LsbOrMsb proto enum to the FunctionBuilder enum.
Expand Down Expand Up @@ -481,11 +478,23 @@ void GenIrNodesPass::HandlePrioritySelect(

void GenIrNodesPass::HandleClz(const FuzzClzProto& clz) {
BValue operand = GetBitsOperand(clz.operand_idx());
// Use max_bit_width_ - 1 because this operation internally generates a
// OneHot node which adds 1 to the bit count. See HandleOneHot.
if (operand.BitCountOrDie() >= max_bit_width_) {
operand =
helpers_.ChangeBitWidth(state().fb(), operand, max_bit_width_ - 1);
}
state().context_list().AppendElement(state().fb()->Clz(operand));
}

void GenIrNodesPass::HandleCtz(const FuzzCtzProto& ctz) {
BValue operand = GetBitsOperand(ctz.operand_idx());
// Use max_bit_width_ - 1 because this operation internally generates a
// OneHot node which adds 1 to the bit count. See HandleOneHot.
if (operand.BitCountOrDie() >= max_bit_width_) {
operand =
helpers_.ChangeBitWidth(state().fb(), operand, max_bit_width_ - 1);
}
state().context_list().AppendElement(state().fb()->Ctz(operand));
}

Expand Down Expand Up @@ -717,13 +726,13 @@ void GenIrNodesPass::HandleEncode(const FuzzEncodeProto& encode) {
void GenIrNodesPass::HandleDecode(const FuzzDecodeProto& decode) {
BValue operand = GetBitsOperand(decode.operand_idx());
// The decode bit width cannot exceed 2 ** operand_bit_width.
int64_t right_bound = IrFuzzHelpers::kMaxFuzzBitWidth;
int64_t right_bound = max_bit_width_;
// We could choose a larger size for this check, but we're clamping to
// IrFuzzHelpers::kMaxFuzzBitWidth and 10 bits is sufficient for this while
// max_bit_width_ and 10 bits is sufficient for this while
// completely avoiding potential overflow, unlike e.g. checking for 64 bits.
if (operand.BitCountOrDie() < 10) {
right_bound = std::min<int64_t>(IrFuzzHelpers::kMaxFuzzBitWidth,
1ULL << operand.BitCountOrDie());
right_bound =
std::min<int64_t>(max_bit_width_, 1ULL << operand.BitCountOrDie());
}
int64_t bit_width =
helpers_.BoundedWidth(decode.bit_width(), /*left_bound=*/1, right_bound);
Expand Down Expand Up @@ -762,7 +771,9 @@ void GenIrNodesPass::HandleDefineFunction(
// to p_->functions() until after the function is fully built.
p_->functions().size() + function_states_.size()),
fuzz_program_.version(), define_function.combine_list_method(),
helpers_.Bounded(define_function.next_nodes_consumed(), 0, 1000));
helpers_.Bounded(define_function.next_nodes_consumed(), /*left_bound=*/0,
/*right_bound=*/1000),
max_bit_width_);
}

void GenIrNodesPass::HandleInvoke(const FuzzInvokeProto& invoke) {
Expand Down
15 changes: 10 additions & 5 deletions xls/fuzzer/ir_fuzzer/gen_ir_nodes_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ class GenIrNodesPass : public IrFuzzVisitor {
public:
GenIrNodesPass(Package* p, std::string_view top_name,
const FuzzProgramProto& fuzz_program,
std::optional<int64_t> param_bits = std::nullopt)
std::optional<int64_t> param_bits = std::nullopt,
std::optional<int64_t> max_bit_width = std::nullopt)
: p_(p),
fuzz_program_(fuzz_program),
helpers_(fuzz_program_.version()),
helpers_(fuzz_program_.version(), max_bit_width),
max_bit_width_(max_bit_width.value_or(IrFuzzHelpers::kMaxFuzzBitWidth)),
remaining_param_bits_(param_bits) {
function_states_.emplace_back(p, top_name, fuzz_program_.version(),
fuzz_program_.combine_list_method());
fuzz_program_.combine_list_method(),
std::nullopt, max_bit_width);
}

void GenIrNodes();
Expand Down Expand Up @@ -128,9 +131,10 @@ class GenIrNodesPass : public IrFuzzVisitor {
public:
FunctionState(Package* p, std::string_view top_name, FuzzVersion version,
CombineListMethod combine_list_method,
std::optional<int64_t> nodes_to_consume = std::nullopt)
std::optional<int64_t> nodes_to_consume = std::nullopt,
std::optional<int64_t> max_bit_width = std::nullopt)
: fb_(std::make_unique<FunctionBuilder>(top_name, p)),
context_list_(p, fb_.get(), IrFuzzHelpers(version)),
context_list_(p, fb_.get(), IrFuzzHelpers(version, max_bit_width)),
combine_list_method_(combine_list_method),
nodes_to_consume_(nodes_to_consume) {}

Expand Down Expand Up @@ -208,6 +212,7 @@ class GenIrNodesPass : public IrFuzzVisitor {
caller_to_callee_;
std::vector<FunctionState> function_states_;
std::optional<int64_t> remaining_param_bits_ = std::nullopt;
int64_t max_bit_width_;
};

} // namespace xls
Expand Down
15 changes: 10 additions & 5 deletions xls/fuzzer/ir_fuzzer/ir_fuzz_domain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ std::vector<Op> GetRestrictedSet(absl::Span<Op const> ops, bool only_bits,

absl::Status BuildIr(Package* package, std::string_view top_name,
const FuzzProgramProto& fuzz_program,
std::optional<int64_t> param_bits) {
GenIrNodesPass pass(package, top_name, fuzz_program, param_bits);
std::optional<int64_t> param_bits,
std::optional<int64_t> max_bit_width) {
GenIrNodesPass pass(package, top_name, fuzz_program, param_bits,
max_bit_width);
pass.GenIrNodes();
return absl::OkStatus();
}
Expand Down Expand Up @@ -275,17 +277,20 @@ fuzztest::Domain<FuzzPackage> FuzzPackageDomainBuilder::Build() && {
fuzztest::Just<int>(*combine_list_method_))
: std::move(proto);
return fuzztest::Map(
[](FuzzProgramProto fuzz_program, std::optional<int64_t> param_bits) {
[](FuzzProgramProto fuzz_program, std::optional<int64_t> param_bits,
std::optional<int64_t> max_bit_width) {
// Create the package.
std::unique_ptr<Package> p =
std::make_unique<VerifiedPackage>(kFuzzTestName);
CHECK_OK(BuildIr(p.get(), kFuzzTestName, fuzz_program, param_bits))
CHECK_OK(BuildIr(p.get(), kFuzzTestName, fuzz_program, param_bits,
max_bit_width))
<< "Failed to build package from FuzzProgramProto: "
<< fuzz_program.DebugString();
// Create the FuzzPackage object as the domain export.
return FuzzPackage(std::move(p), fuzz_program);
},
std::move(proto), fuzztest::Just<std::optional<int64_t>>(param_bits_));
std::move(proto), fuzztest::Just<std::optional<int64_t>>(param_bits_),
fuzztest::Just<std::optional<int64_t>>(max_bit_width_));
}

fuzztest::Domain<FuzzPackageWithArgs> PackageWithArgsDomainBuilder::Build() && {
Expand Down
13 changes: 13 additions & 0 deletions xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ class FuzzPackageDomainBuilder {
param_bits_ = bits;
return std::move(*this);
}
// Limits the bit width of any individual node to this value.
FuzzPackageDomainBuilder WithMaxBitWidth(int64_t bits) && {
max_bit_width_ = bits;
return std::move(*this);
}
// Forces the combine list method to be the given value.
FuzzPackageDomainBuilder WithCombineListMethod(
std::optional<CombineListMethod> method) && {
Expand All @@ -107,6 +112,7 @@ class FuzzPackageDomainBuilder {
bool allow_define_function_ = true;
bool allow_invoke_ = true;
std::optional<int64_t> param_bits_;
std::optional<int64_t> max_bit_width_;
std::optional<CombineListMethod> combine_list_method_;
};

Expand Down Expand Up @@ -145,6 +151,10 @@ class PackageWithArgsDomainBuilder {
return PackageWithArgsDomainBuilder(std::move(base_).WithParamBits(bits),
arg_set_count_);
}
PackageWithArgsDomainBuilder WithMaxBitWidth(int64_t bits) && {
return PackageWithArgsDomainBuilder(std::move(base_).WithMaxBitWidth(bits),
arg_set_count_);
}
PackageWithArgsDomainBuilder WithCombineListMethod(
std::optional<CombineListMethod> method) && {
return PackageWithArgsDomainBuilder(
Expand Down Expand Up @@ -197,6 +207,9 @@ class PackageDomainBuilder {
PackageDomainBuilder WithParamBits(int64_t bits) && {
return PackageDomainBuilder(std::move(base_).WithParamBits(bits));
}
PackageDomainBuilder WithMaxBitWidth(int64_t bits) && {
return PackageDomainBuilder(std::move(base_).WithMaxBitWidth(bits));
}
PackageDomainBuilder WithCombineListMethod(
std::optional<CombineListMethod> method) && {
return PackageDomainBuilder(
Expand Down
23 changes: 14 additions & 9 deletions xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>
Expand Down Expand Up @@ -127,7 +128,8 @@ BValue IrFuzzHelpers::CoercedArray(Package* p, FunctionBuilder* fb,
// Coerce each array element and create a new array with the coerced
// elements.
for (int64_t i = 0; i < array_size; i += 1) {
BValue element = fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true);
BValue element =
fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true);
BValue coerced_element =
Coerced(p, fb, element, coerced_type.array_element(),
target_array_type->element_type());
Expand Down Expand Up @@ -225,7 +227,8 @@ BValue IrFuzzHelpers::FittedArray(Package* p, FunctionBuilder* fb,
std::vector<BValue> fitted_elements;
// Fit each array element and create a new array with the fitted elements.
for (int64_t i = 0; i < target_array_type->size(); i += 1) {
BValue element = fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true);
BValue element =
fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true);
BValue fitted_element = Fitted(p, fb, element, coercion_method,
target_array_type->element_type());
fitted_elements.push_back(fitted_element);
Expand Down Expand Up @@ -362,15 +365,16 @@ BValue IrFuzzHelpers::DecreaseArraySize(
ArrayType* array_type = bvalue.GetType()->AsArrayOrDie();
switch (decrease_size_method) {
case DecreaseArraySizeMethod::ARRAY_SLICE_METHOD:
return fb->ArraySlice(bvalue, fb->Literal(UBits(0, 64)), new_size);
return fb->ArraySlice(bvalue, fb->Literal(UBits(0, BoundedWidth(64))),
new_size);
case DecreaseArraySizeMethod::SHRINK_ARRAY_METHOD:
default:
// Create a new array with only some of the elements from the original
// array.
std::vector<BValue> elements;
for (int64_t i = 0; i < new_size; i += 1) {
elements.push_back(
fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true));
elements.push_back(fb->ArrayIndex(
bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true));
}
return fb->Array(elements, array_type->element_type());
}
Expand All @@ -387,8 +391,8 @@ BValue IrFuzzHelpers::IncreaseArraySize(
// to expand the array to the specified size.
std::vector<BValue> elements;
for (int64_t i = 0; i < array_type->size(); i += 1) {
elements.push_back(
fb->ArrayIndex(bvalue, {fb->Literal(UBits(i, 64))}, true));
elements.push_back(fb->ArrayIndex(
bvalue, {fb->Literal(UBits(i, BoundedWidth(64)))}, true));
}
// Use a default value to fill the rest of the array.
BValue default_value =
Expand Down Expand Up @@ -430,6 +434,7 @@ int64_t IrFuzzHelpers::Bounded(int64_t value, int64_t left_bound,
// out of memory or take too long.
int64_t IrFuzzHelpers::BoundedWidth(int64_t bit_width, int64_t left_bound,
int64_t right_bound) const {
right_bound = std::max(left_bound, std::min(right_bound, max_bit_width_));
return Bounded(bit_width, left_bound, right_bound);
}

Expand All @@ -455,12 +460,12 @@ BValue IrFuzzHelpers::DefaultValue(Package* p, FunctionBuilder* fb,
return fb->Tuple({DefaultValue(p, fb, TypeCase::BITS_CASE)});
case TypeCase::ARRAY_CASE:
return fb->Array({DefaultValue(p, fb, TypeCase::BITS_CASE)},
p->GetBitsType(64));
p->GetBitsType(BoundedWidth(64)));
}
}

BValue IrFuzzHelpers::DefaultBitsValue(FunctionBuilder* fb) const {
return fb->Literal(UBits(0, 64));
return fb->Literal(UBits(0, BoundedWidth(64)));
}

// Returns a default BValue of the specified type.
Expand Down
9 changes: 6 additions & 3 deletions xls/fuzzer/ir_fuzzer/ir_fuzz_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ enum TypeCase {

class IrFuzzHelpers {
public:
constexpr explicit IrFuzzHelpers(FuzzVersion version)
: fuzz_version_(version) {}
constexpr explicit IrFuzzHelpers(
FuzzVersion version, std::optional<int64_t> max_bit_width = std::nullopt)
: fuzz_version_(version),
max_bit_width_(max_bit_width.value_or(kMaxFuzzBitWidth)) {}

BValue Coerced(Package* p, FunctionBuilder* fb, BValue bvalue,
const CoercedTypeProto& coerced_type, Type* target_type) const;
Expand Down Expand Up @@ -135,7 +137,7 @@ class IrFuzzHelpers {
case TypeCase::kArray:
return ConvertArrayTypeProtoToType(p, type_proto.array());
default:
return p->GetBitsType(64);
return p->GetBitsType(BoundedWidth(64));
}
}
template <typename BitsTypeProto>
Expand Down Expand Up @@ -171,6 +173,7 @@ class IrFuzzHelpers {

private:
FuzzVersion fuzz_version_;
int64_t max_bit_width_;
};

constexpr std::array<IrFuzzHelpers, 2> kFuzzHelpers =
Expand Down
14 changes: 14 additions & 0 deletions xls/fuzzer/ir_fuzzer/ir_fuzz_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ FUZZ_TEST(IrFuzzTest, BuilderRestrictsOp)
.NoCtz()
.Build());

void BuilderRestrictsBitWidth(FuzzPackage fuzz_package) {
for (Node* node : fuzz_package.p->functions().front()->nodes()) {
if (node->GetType()->IsBits()) {
EXPECT_LE(node->GetType()->AsBitsOrDie()->bit_count(), 10)
<< "Node exceeded max bit width: " << node->ToString()
<< "\nGenerated by program:\n"
<< fuzz_package.fuzz_program.DebugString();
}
}
}

FUZZ_TEST(IrFuzzTest, BuilderRestrictsBitWidth)
.WithDomains(FuzzPackageDomainBuilder().WithMaxBitWidth(10).Build());

// This test makes sure that the OptimizationPassChangesOutputs test works for
// this specific example.
TEST(IrFuzzTest, PassChangesOutputsWithBitsParam) {
Expand Down
Loading