diff --git a/xls/ir/BUILD b/xls/ir/BUILD index eb076c71c5..31109706cf 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -1271,7 +1271,11 @@ cc_test( ":bits", ":value", ":value_test_util", + "//xls/common:proto_test_utils", "//xls/common:xls_gunit_main", + "//xls/common/file:filesystem", + "//xls/common/fuzzing:fuzztest", + "@abseil-cpp//absl/log:check", "@googletest//:gtest", ], ) @@ -1354,9 +1358,12 @@ cc_library( ":type_manager", ":value", ":value_flattening", + ":xls_ir_interface_cc_proto", ":xls_type_cc_proto", + "//xls/common/file:filesystem", "//xls/common/fuzzing:fuzztest", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:span", "@googletest//:gtest", ], ) diff --git a/xls/ir/value_test_util.cc b/xls/ir/value_test_util.cc index 0142eefe43..7497203995 100644 --- a/xls/ir/value_test_util.cc +++ b/xls/ir/value_test_util.cc @@ -15,10 +15,13 @@ #include "xls/ir/value_test_util.h" #include +#include +#include #include "gtest/gtest.h" #include "xls/common/fuzzing/fuzztest.h" #include "absl/log/check.h" +#include "xls/common/file/filesystem.h" #include "xls/ir/bits.h" #include "xls/ir/bits_test_utils.h" #include "xls/ir/fuzz_type_domain.h" @@ -26,6 +29,7 @@ #include "xls/ir/type_manager.h" #include "xls/ir/value.h" #include "xls/ir/value_flattening.h" +#include "xls/ir/xls_ir_interface.pb.h" #include "xls/ir/xls_type.pb.h" namespace xls { @@ -75,4 +79,25 @@ fuzztest::Domain ArbitraryValue(TypeProto type) { return ArbitraryValue(fuzztest::Just(type)); } +// Parses the binary-format serialized proto of type ElementOfProto and returns +// the proto. +ElementOfProto ParseElementOfProto(absl::Span bytes) { + ElementOfProto proto; + CHECK(proto.ParseFromArray(bytes.data(), bytes.size())); + return proto; +} + +// Note to self: ElementOfProto is a 'using' alias defined in the header. +fuzztest::Domain ElementOfDomain(ElementOfProto proto) { + std::vector values; + values.reserve(proto.values_size()); + for (const auto& value_proto : proto.values()) { + auto value_or = Value::FromProto(value_proto); + CHECK_OK(value_or.status()) << "Failed to parse Value from proto: " + << value_proto.ShortDebugString(); + values.push_back(std::move(value_or.value())); + } + return fuzztest::ElementOf(values); +} + } // namespace xls diff --git a/xls/ir/value_test_util.h b/xls/ir/value_test_util.h index 7d532fbed5..318e35d641 100644 --- a/xls/ir/value_test_util.h +++ b/xls/ir/value_test_util.h @@ -19,11 +19,15 @@ #include "gtest/gtest.h" #include "xls/common/fuzzing/fuzztest.h" +#include "absl/types/span.h" #include "xls/ir/value.h" +#include "xls/ir/xls_ir_interface.pb.h" #include "xls/ir/xls_type.pb.h" namespace xls { +using ElementOfProto = PackageInterfaceProto::FuzzTestDomain::ElementOf; + // Returns an assertion result indicating whether the given two values were // equal. If equal the return value is AssertionSuccess, otherwise // AssertionFailure. For large Values (arrays, tuples, and very wide bit widths) @@ -43,6 +47,13 @@ fuzztest::Domain ArbitraryValue(fuzztest::Domain type); // Create a domain for an arbitrary value which is of the given type. fuzztest::Domain ArbitraryValue(TypeProto type); +// Parses the binary-format serialized proto of type ElementOfProto and returns +// the proto. +ElementOfProto ParseElementOfProto(absl::Span bytes); + +// Create an element_of domain from a serialized ElementOf proto. +fuzztest::Domain ElementOfDomain(ElementOfProto proto); + } // namespace xls #endif // XLS_IR_VALUE_TEST_UTIL_H_ diff --git a/xls/ir/value_test_util_test.cc b/xls/ir/value_test_util_test.cc index 842baaa869..a04074aa67 100644 --- a/xls/ir/value_test_util_test.cc +++ b/xls/ir/value_test_util_test.cc @@ -14,7 +14,12 @@ #include "xls/ir/value_test_util.h" +#include + +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "xls/common/fuzzing/fuzztest.h" +#include "xls/common/proto_test_utils.h" #include "xls/ir/bits.h" #include "xls/ir/value.h" @@ -28,5 +33,26 @@ TEST(ValueTestUtilTest, ValuesEqual) { EXPECT_FALSE(ValuesEqual(Value(UBits(1, 1234)), Value(UBits(1, 10)))); } +ElementOfProto MakeTestProto() { + ElementOfProto proto; + *proto.add_values() = Value(UBits(1, 32)).AsProto().value(); + *proto.add_values() = Value(UBits(2, 32)).AsProto().value(); + return proto; +} + +void ElementOfDomainTestHelper(const Value& value) { + EXPECT_TRUE(value == Value(UBits(1, 32)) || value == Value(UBits(2, 32))); +} +FUZZ_TEST(ValueTestUtilTest, ElementOfDomainTestHelper) + .WithDomains(ElementOfDomain(MakeTestProto())); + +TEST(ValueTestUtilTest, ParseElementOfProtoBytes) { + ElementOfProto original = MakeTestProto(); + std::string serialized = original.SerializeAsString(); + std::vector bytes(serialized.begin(), serialized.end()); + ElementOfProto parsed = ParseElementOfProto(bytes); + EXPECT_THAT(parsed, proto_testing::EqualsProto(original)); +} + } // namespace } // namespace xls diff --git a/xls/jit/BUILD b/xls/jit/BUILD index 5c81c69e0e..97786db9bb 100644 --- a/xls/jit/BUILD +++ b/xls/jit/BUILD @@ -375,6 +375,7 @@ pytype_strict_library( "//xls/ir:xls_ir_interface_py_pb2", "//xls/ir:xls_type_py_pb2", "@abseil-py//absl:app", + "@protobuf//:protobuf_python", "@xls_pip_deps//jinja2", ], ) @@ -413,8 +414,8 @@ pytype_strict_contrib_test( "//xls/common:runfiles", "//xls/ir:xls_ir_interface_py_pb2", "//xls/ir:xls_type_py_pb2", - "@abseil-py//absl:app", "@abseil-py//absl/testing:absltest", + "@protobuf//:protobuf_python", "@xls_pip_deps//jinja2", ], ) diff --git a/xls/jit/jit_wrapper_generator.py b/xls/jit/jit_wrapper_generator.py index 5bfaf482a0..f999c76a99 100644 --- a/xls/jit/jit_wrapper_generator.py +++ b/xls/jit/jit_wrapper_generator.py @@ -21,6 +21,7 @@ from typing import Optional from absl import app +from google.protobuf import text_format import jinja2 from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2 @@ -32,7 +33,6 @@ class FuzzTestInfo: """FuzzTest specific information for a value.""" - domain_snippet: Optional[str] = None domain_proto: Optional[ ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain ] = None @@ -312,30 +312,89 @@ def can_use_uint64_range( return False +def _combine_tuple_domains( + t: type_pb2.TypeProto, + child_domains: list[tuple[Optional[str], bool]], + type_expr: str, +) -> tuple[str, bool]: + """Combines child domains of a tuple into a single FuzzTest domain snippet. + + Args: + t: The XLS type proto. + child_domains: The results of calling to_domain on each child type. + type_expr: C++ expression representing the TypeProto of this type. + + Returns: + A tuple (domain_snippet, is_native). + - domain_snippet: A string representing the combined C++ FuzzTest domain, or + None. + - is_native: True if the domain produces a native C++ type, False if it + produces xls::Value. + """ + if all(domain[1] and domain[0] is not None for domain in child_domains): + # All children are "native", so just use fuzztest::TupleOf. + elems = [domain[0] for domain in child_domains] + return f"fuzztest::TupleOf({', '.join(elems)})", True + + # Some children are not "native", so use fuzztest::Map to convert them to + # xls::Value. + elem_value_domains = [] + for i, (e, (elem_d, elem_is_native)) in enumerate( + zip(t.tuple_elements, child_domains) + ): + child_type_expr = f"{type_expr}.tuple_elements({i})" + if elem_d is None: + elem_value_domains.append(f"xls::ArbitraryValue({child_type_expr})") + elif elem_is_native: + c_type = to_c_type(e) + conv = to_value_conversion(e, "v") + elem_value_domains.append( + f"fuzztest::Map([]({c_type} v) {{ return {conv}; }}, {elem_d})" + ) + else: + elem_value_domains.append(elem_d) + + lambda_args = ", ".join( + f"const xls::Value& v{i}" for i in range(len(t.tuple_elements)) + ) + lambda_body = ", ".join(f"v{i}" for i in range(len(t.tuple_elements))) + snippet = ( + f"fuzztest::Map([]({lambda_args}) {{ " + f"return xls::Value::Tuple({{{lambda_body}}}); }}, " + f"{', '.join(elem_value_domains)})" + ) + return snippet, False + + def to_domain( t: type_pb2.TypeProto, d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain], -) -> Optional[str]: + type_expr: str, +) -> tuple[Optional[str], bool]: """Converts an XLS type and domain spec to a FuzzTest domain string. Args: t: The XLS type proto. d: The optional FuzzTest domain specification from the package interface. + type_expr: C++ expression representing the TypeProto of this type. Returns: - A string representing the C++ FuzzTest domain (e.g., - "fuzztest::Arbitrary()"), or None if it should fallback to - xls::ArbitraryValue. + A tuple (domain_snippet, is_native). + - domain_snippet: A string representing the C++ FuzzTest domain, or None. + - is_native: True if the domain produces the native C++ type, False if it + produces xls::Value. Raises: app.UsageError: If the domain specification is invalid or unsupported for the given type. """ if t.type_enum == type_pb2.TypeProto.ARRAY: - elem_domain = to_domain(t.array_element, None) - if elem_domain is None: - return None - return f"fuzztest::ArrayOf<{t.array_size}>({elem_domain})" + elem_domain, elem_is_native = to_domain( + t.array_element, None, f"{type_expr}.array_element()" + ) + if elem_domain is None or not elem_is_native: + return None, False + return f"fuzztest::ArrayOf<{t.array_size}>({elem_domain})", True if ( d is None @@ -349,55 +408,68 @@ def to_domain( if t.type_enum == type_pb2.TypeProto.BITS: c_type = to_specialized(t, int_only=True) if c_type is None: - return None + return None, False if t.bit_count in (8, 16, 32, 64): - return f"fuzztest::Arbitrary<{c_type}>()" + return f"fuzztest::Arbitrary<{c_type}>()", True else: max_val = (1 << t.bit_count) - 1 - return f"fuzztest::InRange<{c_type}>(0, {max_val})" + return f"fuzztest::InRange<{c_type}>(0, {max_val})", True elif t.type_enum == type_pb2.TypeProto.TUPLE: - elems = [to_domain(e, None) for e in t.tuple_elements] - if any(e is None for e in elems): - return None - return f"fuzztest::TupleOf({', '.join(elems)})" - + child_results = [ + to_domain(e, None, f"{type_expr}.tuple_elements({i})") + for i, e in enumerate(t.tuple_elements) + ] + return _combine_tuple_domains(t, child_results, type_expr) else: - return None + return None, False if d.HasField("range"): cpp_type = to_specialized(t, int_only=True) if cpp_type is None: if not can_use_uint64_range(t, d): - raise app.UsageError( - "Range domain is only supported for specializable bits types or" - " ranges fitting in 64 bits" - ) + return None, False cpp_type = "uint64_t" min_val = extract_int_from_bytes(d.range.min.bits.data) max_val = extract_int_from_bytes(d.range.max.bits.data) - return f"fuzztest::InRange<{cpp_type}>({min_val}, {max_val})" + return f"fuzztest::InRange<{cpp_type}>({min_val}, {max_val})", True if d.HasField("element_of"): c_type = to_specialized(t, int_only=True) - if c_type is None: - raise app.UsageError( - "ElementOf domain only supported for specializable bits types in" - " this CL" + if c_type is not None: + vals = [ + str(extract_int_from_bytes(v.bits.data)) for v in d.element_of.values + ] + return ( + f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})", + True, + ) + else: + proto_bytes = d.element_of.SerializeToString() + bytes_str = ", ".join(str(b) for b in proto_bytes) + # Proto in two formats: text (for the human-readable comment) and binary + # (to be parsed by the C++ helper). + proto_text = text_format.MessageToString(d.element_of).replace( + "\n", "\n// " + ) + return ( + ( + f"// {proto_text}\n" + f"xls::ElementOfDomain(xls::ParseElementOfProto(std::vector{{{bytes_str}}}))" + ), + False, ) - vals = [ - str(extract_int_from_bytes(v.bits.data)) for v in d.element_of.values - ] - return f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})" if d.HasField("tuple"): if t.type_enum != type_pb2.TypeProto.TUPLE: raise app.UsageError("Tuple domain requires Tuple type") if len(d.tuple.elements) != len(t.tuple_elements): raise app.UsageError("Tuple domain and type element count mismatch") - elems = [ - to_domain(te, de) for te, de in zip(t.tuple_elements, d.tuple.elements) + + child_results = [ + to_domain(te, de, f"{type_expr}.tuple_elements({i})") + for i, (te, de) in enumerate(zip(t.tuple_elements, d.tuple.elements)) ] - return f"fuzztest::TupleOf({', '.join(elems)})" + return _combine_tuple_domains(t, child_results, type_expr) raise app.UsageError(f"Unsupported domain: {d}") @@ -443,9 +515,7 @@ def to_param( unpacked_type=to_unpacked(p.type), specialized_type=to_specialized(p.type), type_proto=p.type, - fuzztest_info=FuzzTestInfo( - domain_snippet=to_domain(p.type, d), domain_proto=d - ), + fuzztest_info=FuzzTestInfo(domain_proto=d), ) @@ -622,30 +692,42 @@ def wrapped_to_fuzztest( ) -> PropertyFunction: """Converts a WrappedIr object to a dictionary for fuzztest template.""" params = [] + can_be_specialized = wrapped.can_be_specialized + if wrapped.params: for idx, p in enumerate(wrapped.params): - is_native = p.specialized_type is not None domain_proto = p.fuzztest_info.domain_proto if p.fuzztest_info else None - domain_snippet = ( - p.fuzztest_info.domain_snippet if p.fuzztest_info else None + type_expr = f"{lib_class_name}::GetParamType({idx}).value()" + domain_snippet, is_native_domain = to_domain( + p.type_proto, domain_proto, type_expr ) + is_native = p.specialized_type is not None if not is_native and can_use_uint64_range(p.type_proto, domain_proto): cpp_type = "uint64_t" is_native = True else: cpp_type = to_c_type(p.type_proto) - if cpp_type is None: + if not is_native_domain: + cpp_type = "xls::Value" + is_native = False + + if cpp_type is None or cpp_type == "xls::Value": cpp_type = "xls::Value" + is_native = False conversion_snippet = None else: conversion_snippet = to_value_conversion(p.type_proto, p.name) + if not is_native: + can_be_specialized = False + if ( len(wrapped.params) == 1 and p.type_proto.type_enum == type_pb2.TypeProto.TUPLE and domain_snippet is not None + and is_native_domain ): domain_snippet = f"fuzztest::TupleOf({domain_snippet})" @@ -676,7 +758,7 @@ def wrapped_to_fuzztest( namespace=wrapped.namespace, params=params, return_type=wrapped.result is not None, - can_be_specialized=wrapped.can_be_specialized, + can_be_specialized=can_be_specialized, result_width=result_width, ) diff --git a/xls/jit/jit_wrapper_generator_test.py b/xls/jit/jit_wrapper_generator_test.py index d241c648cf..df91a76be7 100644 --- a/xls/jit/jit_wrapper_generator_test.py +++ b/xls/jit/jit_wrapper_generator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl import app +from google.protobuf import text_format import jinja2 from absl.testing import absltest @@ -313,12 +313,7 @@ def test_function_tuple_param(self): packed_type='', unpacked_type='', specialized_type=None, - fuzztest_info=jit_wrapper_generator.FuzzTestInfo( - domain_snippet=( - 'fuzztest::TupleOf(fuzztest::Arbitrary(),' - ' fuzztest::Arbitrary())' - ) - ), + fuzztest_info=jit_wrapper_generator.FuzzTestInfo(), ), ], result=jit_wrapper_generator.XlsNamedValue( @@ -342,10 +337,75 @@ def test_function_tuple_param(self): ) self.assertEqual( prop_func.params[0].domain_snippet, - 'fuzztest::TupleOf(fuzztest::TupleOf(fuzztest::Arbitrary(),' + 'fuzztest::TupleOf(fuzztest::TupleOf(fuzztest::InRange(0, 1),' ' fuzztest::Arbitrary()))', ) + def test_wrapped_to_fuzztest_mixed_tuple_fallback(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) + tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u128] + ) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.tuple.elements.add().range.min.bits.bit_count = 32 + d.tuple.elements[0].range.min.bits.data = b'\x00' + d.tuple.elements[0].range.max.bits.bit_count = 32 + d.tuple.elements[0].range.max.bits.data = b'\x0a' + d_child2 = d.tuple.elements.add() + d_child2.range.min.bits.bit_count = 128 + d_child2.range.min.bits.data = b'\x01' + d_child2.range.max.bits.bit_count = 128 + d_child2.range.max.bits.data = b'\x00\x00\x00\x00\x00\x00\x00\x00\x01' + + wrapped_ir = jit_wrapper_generator.WrappedIr( + jit_type=jit_wrapper_generator.JitType.FUNCTION, + ir_text='', + function_name='mixed_tuple_func', + class_name='MixedTupleFuncJit', + header_guard='HEADER_GUARD', + header_filename='mixed_tuple_func_jit.h', + namespace='xls', + aot_entrypoint=None, + params=[ + jit_wrapper_generator.XlsNamedValue( + name='t', + type_proto=tup, + packed_type='', + unpacked_type='', + specialized_type=None, + fuzztest_info=jit_wrapper_generator.FuzzTestInfo( + domain_proto=d + ), + ), + ], + result=jit_wrapper_generator.XlsNamedValue( + name='res', + type_proto=u32, + packed_type='', + unpacked_type='', + specialized_type=None, + ), + ) + prop_func = jit_wrapper_generator.wrapped_to_fuzztest( + wrapped_ir, 'xls::MixedTupleFuncJit', 'mixed_tuple_func_jit.h' + ) + self.assertEqual(prop_func.fuzztest_name, 'mixed_tuple_func_fuzztest') + self.assertTrue(prop_func.return_type) + self.assertLen(prop_func.params, 1) + self.assertEqual(prop_func.params[0].name, 't') + self.assertEqual(prop_func.params[0].index, 0) + self.assertEqual(prop_func.params[0].cpp_type, 'xls::Value') + self.assertEqual( + prop_func.params[0].domain_snippet, + 'fuzztest::Map([](const xls::Value& v0, const xls::Value& v1) { ' + 'return xls::Value::Tuple({v0, v1}); }, ' + 'fuzztest::Map([](uint32_t v) { return xls::Value(xls::UBits(v,' + ' 32)); }, ' + 'fuzztest::InRange(0, 10)), ' + 'xls::ArbitraryValue(xls::MixedTupleFuncJit::GetParamType(0).value().tuple_elements(1)))', + ) + def test_function_no_result(self): u8 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=8) wrapped_ir = jit_wrapper_generator.WrappedIr( @@ -530,13 +590,8 @@ def test_render_fuzztest_basic(self): 'XLS_ASSERT_OK_AND_ASSIGN(xls::Value result, f->Run(', rendered_code ) self.assertIn('FUZZ_TEST(my_func_fuzztest, my_func)', rendered_code) - self.assertIn('xls::ArbitraryValue(', rendered_code) - self.assertIn( - 'xls::test::MyFuncJit::GetParamType(0).value()', rendered_code - ) - self.assertIn( - 'xls::test::MyFuncJit::GetParamType(1).value()', rendered_code - ) + self.assertIn('fuzztest::Arbitrary()', rendered_code) + self.assertIn('fuzztest::Arbitrary()', rendered_code) def test_render_fuzztest_array_of_int(self): u16 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=16) @@ -576,7 +631,9 @@ def test_render_fuzztest_array_of_int(self): self.assertIn( 'FUZZ_TEST(array_int_func_fuzztest, array_int_func)', rendered_code ) - self.assertIn('xls::test::ArrayIntFuncJit::GetParamType(0)', rendered_code) + self.assertIn( + 'fuzztest::ArrayOf<4>(fuzztest::Arbitrary())', rendered_code + ) def test_render_fuzztest_array_of_tuple(self): u8 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=8) @@ -622,7 +679,9 @@ def test_render_fuzztest_array_of_tuple(self): 'FUZZ_TEST(array_tuple_func_fuzztest, array_tuple_func)', rendered_code ) self.assertIn( - 'xls::test::ArrayTupleFuncJit::GetParamType(0).value()', rendered_code + 'fuzztest::ArrayOf<3>(fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::Arbitrary()))', + rendered_code, ) def test_render_fuzztest_tuple_of_int(self): @@ -665,7 +724,9 @@ def test_render_fuzztest_tuple_of_int(self): 'FUZZ_TEST(tuple_int_func_fuzztest, tuple_int_func)', rendered_code ) self.assertIn( - 'xls::test::TupleIntFuncJit::GetParamType(0).value()', rendered_code + 'fuzztest::TupleOf(fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::Arbitrary()))', + rendered_code, ) def test_render_fuzztest_tuple_mixed(self): @@ -714,7 +775,10 @@ def test_render_fuzztest_tuple_mixed(self): 'FUZZ_TEST(tuple_mixed_func_fuzztest, tuple_mixed_func)', rendered_code ) self.assertIn( - 'xls::test::TupleMixedFuncJit::GetParamType(0).value()', rendered_code + 'fuzztest::TupleOf(fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::Arbitrary())))', + rendered_code, ) def test_render_fuzztest_uses_property_param_filter(self): @@ -767,9 +831,7 @@ def test_render_fuzztest_default_domain(self): packed_type='xls::PackedBitsView<32>', unpacked_type='xls::BitsView<32>', specialized_type='uint32_t', - fuzztest_info=jit_wrapper_generator.FuzzTestInfo( - domain_snippet='fuzztest::Arbitrary()' - ), + fuzztest_info=jit_wrapper_generator.FuzzTestInfo(), ), ], result=None, @@ -797,20 +859,23 @@ def test_extract_int_from_bytes(self): def test_bits_domain_power_of_2(self): u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) self.assertEqual( - jit_wrapper_generator.to_domain(u32, None), - 'fuzztest::Arbitrary()', + jit_wrapper_generator.to_domain(u32, None, 'T'), + ('fuzztest::Arbitrary()', True), ) def test_bits_domain_non_power_of_2(self): u17 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=17) self.assertEqual( - jit_wrapper_generator.to_domain(u17, None), - 'fuzztest::InRange(0, 131071)', + jit_wrapper_generator.to_domain(u17, None, 'T'), + ('fuzztest::InRange(0, 131071)', True), ) def test_bits_domain_too_wide(self): u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) - self.assertIsNone(jit_wrapper_generator.to_domain(u128, None)) + self.assertEqual( + jit_wrapper_generator.to_domain(u128, None, 'T'), + (None, False), + ) def test_range_domain(self): u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) @@ -820,8 +885,8 @@ def test_range_domain(self): d.range.max.bits.bit_count = 32 d.range.max.bits.data = b'\x0a' self.assertEqual( - jit_wrapper_generator.to_domain(u32, d), - 'fuzztest::InRange(1, 10)', + jit_wrapper_generator.to_domain(u32, d, 'T'), + ('fuzztest::InRange(1, 10)', True), ) def test_range_domain_wide_bits_fits(self): @@ -832,8 +897,8 @@ def test_range_domain_wide_bits_fits(self): d.range.max.bits.bit_count = 128 d.range.max.bits.data = b'\x0a' self.assertEqual( - jit_wrapper_generator.to_domain(u128, d), - 'fuzztest::InRange(1, 10)', + jit_wrapper_generator.to_domain(u128, d, 'T'), + ('fuzztest::InRange(1, 10)', True), ) def test_element_of_domain(self): @@ -846,8 +911,8 @@ def test_element_of_domain(self): v2.bits.bit_count = 32 v2.bits.data = b'\x02' self.assertEqual( - jit_wrapper_generator.to_domain(u32, d), - 'fuzztest::ElementOf(std::vector{1, 2})', + jit_wrapper_generator.to_domain(u32, d, 'T'), + ('fuzztest::ElementOf(std::vector{1, 2})', True), ) def test_tuple_domain(self): @@ -863,9 +928,14 @@ def test_tuple_domain(self): d.tuple.elements.add().arbitrary = True self.assertEqual( - jit_wrapper_generator.to_domain(tup, d), - 'fuzztest::TupleOf(fuzztest::InRange(0, 10),' - ' fuzztest::Arbitrary())', + jit_wrapper_generator.to_domain(tup, d, 'T'), + ( + ( + 'fuzztest::TupleOf(fuzztest::InRange(0, 10),' + ' fuzztest::Arbitrary())' + ), + True, + ), ) def test_nested_tuple_domain(self): @@ -886,9 +956,14 @@ def test_nested_tuple_domain(self): inner_d.range.max.bits.data = b'\x05' self.assertEqual( - jit_wrapper_generator.to_domain(outer_tup, d), - 'fuzztest::TupleOf(fuzztest::Arbitrary(),' - ' fuzztest::TupleOf(fuzztest::InRange(0, 5)))', + jit_wrapper_generator.to_domain(outer_tup, d, 'T'), + ( + ( + 'fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::TupleOf(fuzztest::InRange(0, 5)))' + ), + True, + ), ) def test_array_domain(self): @@ -897,8 +972,8 @@ def test_array_domain(self): type_enum=type_pb2.TypeProto.ARRAY, array_size=3, array_element=u32 ) self.assertEqual( - jit_wrapper_generator.to_domain(arr, None), - 'fuzztest::ArrayOf<3>(fuzztest::Arbitrary())', + jit_wrapper_generator.to_domain(arr, None, 'T'), + ('fuzztest::ArrayOf<3>(fuzztest::Arbitrary())', True), ) def test_tuple_with_array_domain(self): @@ -915,12 +990,17 @@ def test_tuple_with_array_domain(self): d.tuple.elements.add().arbitrary = True self.assertEqual( - jit_wrapper_generator.to_domain(tup, d), - 'fuzztest::TupleOf(fuzztest::Arbitrary(),' - ' fuzztest::ArrayOf<3>(fuzztest::Arbitrary()))', + jit_wrapper_generator.to_domain(tup, d, 'T'), + ( + ( + 'fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::ArrayOf<3>(fuzztest::Arbitrary()))' + ), + True, + ), ) - def test_unsupported_domain_raises(self): + def test_unsupported_range_domain_returns_none(self): u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) tup = type_pb2.TypeProto( type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32] @@ -930,12 +1010,75 @@ def test_unsupported_domain_raises(self): d.range.min.bits.data = b'\x00' d.range.max.bits.bit_count = 32 d.range.max.bits.data = b'\x0a' + self.assertEqual( + jit_wrapper_generator.to_domain(tup, d, 'T'), + (None, False), + ) + + def test_element_of_domain_non_specializable(self): + u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + v1 = d.element_of.values.add() + v1.bits.bit_count = 128 + v1.bits.data = b'\x01' + v2 = d.element_of.values.add() + v2.bits.bit_count = 128 + v2.bits.data = b'\x02' + expected_proto_str = text_format.MessageToString(d.element_of) + proto_bytes = d.element_of.SerializeToString() + bytes_str = ', '.join(str(b) for b in proto_bytes) + proto_text = expected_proto_str.replace('\n', '\n// ') + expected_output = ( + f'// {proto_text}\n' + f'xls::ElementOfDomain(xls::ParseElementOfProto(std::vector{{{bytes_str}}}))' + ) + self.assertEqual( + jit_wrapper_generator.to_domain(u128, d, 'T'), + (expected_output, False), + ) - with self.assertRaisesRegex( - app.UsageError, - 'Range domain is only supported for specializable bits types', - ): - jit_wrapper_generator.to_domain(tup, d) + def test_range_domain_wide_bits_does_not_fit(self): + u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.range.min.bits.bit_count = 128 + d.range.min.bits.data = b'\x01' + d.range.max.bits.bit_count = 128 + d.range.max.bits.data = b'\x00\x00\x00\x00\x00\x00\x00\x00\x01' + self.assertEqual( + jit_wrapper_generator.to_domain(u128, d, 'T'), + (None, False), + ) + + def test_tuple_domain_with_unsupported_child_fallback(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) + tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u128] + ) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.tuple.elements.add().range.min.bits.bit_count = 32 + d.tuple.elements[0].range.min.bits.data = b'\x00' + d.tuple.elements[0].range.max.bits.bit_count = 32 + d.tuple.elements[0].range.max.bits.data = b'\x0a' + d_child2 = d.tuple.elements.add() + d_child2.range.min.bits.bit_count = 128 + d_child2.range.min.bits.data = b'\x01' + d_child2.range.max.bits.bit_count = 128 + d_child2.range.max.bits.data = b'\x00\x00\x00\x00\x00\x00\x00\x00\x01' + self.assertEqual( + jit_wrapper_generator.to_domain(tup, d, 'T'), + ( + ( + 'fuzztest::Map([](const xls::Value& v0, const xls::Value& v1) {' + ' return xls::Value::Tuple({v0, v1}); },' + ' fuzztest::Map([](uint32_t v) { return' + ' xls::Value(xls::UBits(v, 32)); },' + ' fuzztest::InRange(0, 10)),' + ' xls::ArbitraryValue(T.tuple_elements(1)))' + ), + False, + ), + ) class JitWrapperGeneratorToParamTest(absltest.TestCase): @@ -947,9 +1090,7 @@ def test_to_param_default_domain(self): self.assertEqual(xls_param.name, 'a') fuzztest_info = xls_param.fuzztest_info assert fuzztest_info is not None - self.assertEqual( - fuzztest_info.domain_snippet, 'fuzztest::Arbitrary()' - ) + self.assertIsNone(fuzztest_info.domain_proto) if __name__ == '__main__': diff --git a/xls/tests/fuzz_test/BUILD b/xls/tests/fuzz_test/BUILD index 35dd1023c9..c0a42df235 100644 --- a/xls/tests/fuzz_test/BUILD +++ b/xls/tests/fuzz_test/BUILD @@ -183,6 +183,12 @@ dslx_fuzz_test( test_function = "nested_big_array", ) +dslx_fuzz_test( + name = "array_of_tuples_with_wide_bits_fuzz_test", + library = ":array_tests_dslx", + test_function = "array_of_tuples_with_wide_bits", +) + dslx_fuzz_test( name = "inline_nested_struct_domain_fuzz_test", library = ":struct_tests_dslx", @@ -212,3 +218,21 @@ dslx_fuzz_test( library = ":struct_tests_dslx", test_function = "test_struct_domain_tuple", ) + +dslx_fuzz_test( + name = "wide_fuzz_test", + library = ":struct_tests_dslx", + test_function = "wide", +) + +dslx_fuzz_test( + name = "wide_element_of_fuzz_test", + library = ":struct_tests_dslx", + test_function = "wide_element_of", +) + +dslx_fuzz_test( + name = "wide_tuple_element_of_fuzz_test", + library = ":struct_tests_dslx", + test_function = "wide_tuple_element_of", +) diff --git a/xls/tests/fuzz_test/array_tests.x b/xls/tests/fuzz_test/array_tests.x index 7005dcf6d2..6bee6e87ba 100644 --- a/xls/tests/fuzz_test/array_tests.x +++ b/xls/tests/fuzz_test/array_tests.x @@ -41,3 +41,8 @@ fn tuple_with_big_array(x: (uN[128][2], u32)) -> bool { fn nested_big_array(x: uN[128][2][3]) -> bool { true } + +#[fuzz_test(domains=`()`)] +fn array_of_tuples_with_wide_bits(x: (uN[128], u32)[2]) -> bool { + true +} diff --git a/xls/tests/fuzz_test/struct_tests.x b/xls/tests/fuzz_test/struct_tests.x index 6c22fbdcee..b2ab878bc8 100644 --- a/xls/tests/fuzz_test/struct_tests.x +++ b/xls/tests/fuzz_test/struct_tests.x @@ -173,3 +173,24 @@ fn create_tuple_outer_domain() -> TupleOuterDomain { fn test_struct_domain_tuple(s: TupleOuter) -> bool { s.c.0.y < u32:10 && s.c.1 < u8:11 } + +struct WideStruct { w: uN[128], x: u32 } + +#[fuzz_test(domains=`WideStruct { x: u32:0..10 }`)] +fn wide(s: WideStruct) -> bool { + true +} + +#[fuzz_test(domains=`[uN[128]:5, 10, 15]`)] +fn wide_element_of(x: uN[128]) -> bool { + x == uN[128]:5 || x == uN[128]:10 || x == uN[128]:15 +} + +struct WideTupleStruct { w: uN[128], x: u32 } + +#[fuzz_test(domains=`WideTupleStruct { w: [uN[128]:1, 2], x: u32:0..10 }`)] +fn wide_tuple_element_of(s: WideTupleStruct) -> bool { + true +} + +