From edea2ab8ea1aaa2e47489f630fa8a3d21db16655 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 2 Jul 2026 11:11:55 -0700 Subject: [PATCH] [mlir] Enable non-topological conversions. Make this lazy while checking for cycles. PiperOrigin-RevId: 941766336 --- .../xls_translate/xls_translate_from_mlir.cc | 93 ++++++++++++++++--- .../xls_translate/xls_translate_to_mlir.cc | 2 + 2 files changed, 81 insertions(+), 14 deletions(-) diff --git a/xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.cc b/xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.cc index 55cdc97885..f133079ca1 100644 --- a/xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.cc +++ b/xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.cc @@ -126,17 +126,26 @@ class TranslationState { public: explicit TranslationState(Package& package) : package_(package) {} + using TranslateCallback = std::function; + void setValueMap(DenseMap& value_map) { this->value_map_ = &value_map; } + void setTranslateCallback(TranslateCallback cb) { + this->translate_cb_ = std::move(cb); + } + BValue getXlsValue(Value value) const { auto it = value_map_->find(value); - if (it == value_map_->end()) { - llvm::errs() << "Failed to find value: " << value << "\n"; - return BValue(); + if (it != value_map_->end()) { + return it->second; } - return it->second; + if (translate_cb_) { + return translate_cb_(value); + } + llvm::errs() << "Failed to find value: " << value << "\n"; + return BValue(); } void setValue(Value value, BValue xls_value) { @@ -211,6 +220,7 @@ class TranslationState { private: DenseMap* value_map_; + TranslateCallback translate_cb_ = nullptr; DenseMap channel_map_; llvm::StringMap<::xls::Function*> func_map_; llvm::StringMap package_map_; @@ -1173,6 +1183,9 @@ FailureOr convertFunction(TranslationState& translation_state, // Function return value. BValue out; + llvm::SmallPtrSet active_conversions; + std::function get_or_translate; + // Walk over the function ops and construct the XLS function. auto convert = [&](Operation* op) { return TypeSwitch(op) @@ -1219,12 +1232,12 @@ FailureOr convertFunction(TranslationState& translation_state, [&](auto t) { return convertOp(t, translation_state, fb); }) .Case([&](auto ret) { if (ret.getNumOperands() == 1) { - return out = value_map[ret.getOperand(0)]; + return out = get_or_translate(ret.getOperand(0)); } std::vector operands; operands.reserve(ret.getNumOperands()); for (auto operand : ret.getOperands()) { - operands.push_back(value_map[operand]); + operands.push_back(get_or_translate(operand)); } return out = fb.Tuple(operands); }) @@ -1232,7 +1245,7 @@ FailureOr convertFunction(TranslationState& translation_state, // BlockOutputOp is handled separately by the block conversion. // Just return a dummy valid BValue. if (block_out.getNumOperands() > 0) { - return value_map[block_out.getOperand(0)]; + return get_or_translate(block_out.getOperand(0)); } return BValue(); }) @@ -1266,14 +1279,14 @@ FailureOr convertFunction(TranslationState& translation_state, << "register not found: " << write_op.getReg(); return BValue(); } - BValue data = value_map[write_op.getData()]; + BValue data = get_or_translate(write_op.getData()); std::optional load_enable; if (write_op.getLoadEnable()) { - load_enable = value_map[write_op.getLoadEnable()]; + load_enable = get_or_translate(write_op.getLoadEnable()); } std::optional reset; if (write_op.getReset()) { - reset = value_map[write_op.getReset()]; + reset = get_or_translate(write_op.getReset()); } bb->RegisterWrite(reg, data, load_enable, reset, translation_state.getLoc(write_op)); @@ -1308,13 +1321,17 @@ FailureOr convertFunction(TranslationState& translation_state, if (child_block->GetResetPort().has_value() && port == *child_block->GetResetPort()) { if (inst_op.getReset()) { - BValue rst_data = value_map[inst_op.getReset()]; + BValue rst_data = get_or_translate(inst_op.getReset()); + bb->InstantiationInput(inst, port->name(), rst_data, + translation_state.getLoc(inst_op)); + } else if (bb->block()->GetResetPort().has_value()) { + BValue rst_data = BValue(*bb->block()->GetResetPort(), bb); bb->InstantiationInput(inst, port->name(), rst_data, translation_state.getLoc(inst_op)); } continue; } - BValue data = value_map[inst_op.getInputs()[input_idx]]; + BValue data = get_or_translate(inst_op.getInputs()[input_idx]); bb->InstantiationInput(inst, port->name(), data, translation_state.getLoc(inst_op)); input_idx++; @@ -1376,7 +1393,7 @@ FailureOr convertFunction(TranslationState& translation_state, int input_idx = 0; auto connectInput = [&](std::string_view port_name) { if (input_idx < static_cast(inst_op.getInputs().size())) { - BValue data = value_map[inst_op.getInputs()[input_idx]]; + BValue data = get_or_translate(inst_op.getInputs()[input_idx]); bb->InstantiationInput(inst, port_name, data, translation_state.getLoc(inst_op)); input_idx++; @@ -1410,7 +1427,7 @@ FailureOr convertFunction(TranslationState& translation_state, .Case([&](NextValueOp next) { // We just skip the next value op here as its handled in the eproc // conversion. - return value_map[next.getValues()[0]]; + return get_or_translate(next.getValues()[0]); }) .Case([&](CallDslxOp call) { llvm::errs() << "Call remaining, call pass normalize-xls-calls " @@ -1423,6 +1440,44 @@ FailureOr convertFunction(TranslationState& translation_state, }); }; + // Use lazy resolution to enable non-topologically sorted ops. A callback is + // used to allow for recursive calls to the convert/translation function + // without tieing those all into translation state. + get_or_translate = [&](Value val) -> BValue { + auto it = value_map.find(val); + if (it != value_map.end()) { + return it->second; + } + if (llvm::isa(val)) { + return BValue(); + } + Operation* defining_op = val.getDefiningOp(); + if (!defining_op) { + return BValue(); + } + + if (!active_conversions.insert(defining_op).second) { + defining_op->emitOpError( + "combinational cycle detected during translation"); + return BValue(); + } + + BValue result = convert(defining_op); + active_conversions.erase(defining_op); + + if (!result.valid()) { + return BValue(); + } + + if (defining_op->getNumResults() == 1 && + !llvm::isa(defining_op)) { + value_map[val] = result; + } + + return result; + }; + translation_state.setTranslateCallback(get_or_translate); + // Walk over the function and construct the XLS function. auto result_walk = xls_region.walk([&](Operation* op) { if (op == xls_region) { @@ -1448,6 +1503,13 @@ FailureOr convertFunction(TranslationState& translation_state, return WalkResult::skip(); } + // Skip if all results are already translated. + if (!op->getResults().empty() && + llvm::all_of(op->getResults(), + [&](Value res) { return value_map.contains(res); })) { + return WalkResult::advance(); + } + if (BValue r = convert(op); r.valid()) { // BlockInstantiateOp/FifoInstantiateOp manage their own value_map // entries for multiple results, so skip automatic assignment. @@ -1464,6 +1526,9 @@ FailureOr convertFunction(TranslationState& translation_state, return WalkResult::interrupt(); }); + // Reset to avoid dangling reference to local lamba and value_map. + translation_state.setTranslateCallback(nullptr); + if (result_walk.wasInterrupted()) { return failure(); } diff --git a/xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.cc b/xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.cc index e0910b0126..660146590d 100644 --- a/xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.cc +++ b/xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.cc @@ -1698,6 +1698,7 @@ Attribute translateXlsValueToAttribute(const ::xls::Value& value, return nullptr; } llvm::SmallVector elements; + elements.reserve(value.elements().size()); for (const auto& elem : value.elements()) { elements.push_back(translateXlsValueToAttribute( elem, array_type.getElementType(), builder)); @@ -1710,6 +1711,7 @@ Attribute translateXlsValueToAttribute(const ::xls::Value& value, return nullptr; } llvm::SmallVector elements; + elements.reserve(value.elements().size()); for (auto [i, elem] : llvm::enumerate(value.elements())) { elements.push_back( translateXlsValueToAttribute(elem, tuple_type.getType(i), builder));