Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 79 additions & 14 deletions xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,26 @@ class TranslationState {
public:
explicit TranslationState(Package& package) : package_(package) {}

using TranslateCallback = std::function<BValue(Value)>;

void setValueMap(DenseMap<Value, BValue>& value_map) {
this->value_map_ = &value_map;
}

void setTranslateCallback(TranslateCallback cb) {
this->translate_cb_ = std::move(cb);
}

BValue getXlsValue(Value value) const {
auto it = value_map_->find(value);
if (it == value_map_->end()) {
llvm::errs() << "Failed to find value: " << value << "\n";
return BValue();
if (it != value_map_->end()) {
return it->second;
}
return it->second;
if (translate_cb_) {
return translate_cb_(value);
}
llvm::errs() << "Failed to find value: " << value << "\n";
return BValue();
}

void setValue(Value value, BValue xls_value) {
Expand Down Expand Up @@ -211,6 +220,7 @@ class TranslationState {

private:
DenseMap<Value, BValue>* value_map_;
TranslateCallback translate_cb_ = nullptr;
DenseMap<ChanOp, ::xls::Channel*> channel_map_;
llvm::StringMap<::xls::Function*> func_map_;
llvm::StringMap<PackageInfo> package_map_;
Expand Down Expand Up @@ -1173,6 +1183,9 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
// Function return value.
BValue out;

llvm::SmallPtrSet<Operation*, 8> active_conversions;
std::function<BValue(Value)> get_or_translate;

// Walk over the function ops and construct the XLS function.
auto convert = [&](Operation* op) {
return TypeSwitch<Operation*, BValue>(op)
Expand Down Expand Up @@ -1219,20 +1232,20 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
[&](auto t) { return convertOp(t, translation_state, fb); })
.Case<func::ReturnOp, YieldOp>([&](auto ret) {
if (ret.getNumOperands() == 1) {
return out = value_map[ret.getOperand(0)];
return out = get_or_translate(ret.getOperand(0));
}
std::vector<BValue> operands;
operands.reserve(ret.getNumOperands());
for (auto operand : ret.getOperands()) {
operands.push_back(value_map[operand]);
operands.push_back(get_or_translate(operand));
}
return out = fb.Tuple(operands);
})
.Case<BlockOutputOp>([&](BlockOutputOp block_out) {
// BlockOutputOp is handled separately by the block conversion.
// Just return a dummy valid BValue.
if (block_out.getNumOperands() > 0) {
return value_map[block_out.getOperand(0)];
return get_or_translate(block_out.getOperand(0));
}
return BValue();
})
Expand Down Expand Up @@ -1266,14 +1279,14 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
<< "register not found: " << write_op.getReg();
return BValue();
}
BValue data = value_map[write_op.getData()];
BValue data = get_or_translate(write_op.getData());
std::optional<BValue> load_enable;
if (write_op.getLoadEnable()) {
load_enable = value_map[write_op.getLoadEnable()];
load_enable = get_or_translate(write_op.getLoadEnable());
}
std::optional<BValue> reset;
if (write_op.getReset()) {
reset = value_map[write_op.getReset()];
reset = get_or_translate(write_op.getReset());
}
bb->RegisterWrite(reg, data, load_enable, reset,
translation_state.getLoc(write_op));
Expand Down Expand Up @@ -1308,13 +1321,17 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
if (child_block->GetResetPort().has_value() &&
port == *child_block->GetResetPort()) {
if (inst_op.getReset()) {
BValue rst_data = value_map[inst_op.getReset()];
BValue rst_data = get_or_translate(inst_op.getReset());
bb->InstantiationInput(inst, port->name(), rst_data,
translation_state.getLoc(inst_op));
} else if (bb->block()->GetResetPort().has_value()) {
BValue rst_data = BValue(*bb->block()->GetResetPort(), bb);
bb->InstantiationInput(inst, port->name(), rst_data,
translation_state.getLoc(inst_op));
}
continue;
}
BValue data = value_map[inst_op.getInputs()[input_idx]];
BValue data = get_or_translate(inst_op.getInputs()[input_idx]);
bb->InstantiationInput(inst, port->name(), data,
translation_state.getLoc(inst_op));
input_idx++;
Expand Down Expand Up @@ -1376,7 +1393,7 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
int input_idx = 0;
auto connectInput = [&](std::string_view port_name) {
if (input_idx < static_cast<int>(inst_op.getInputs().size())) {
BValue data = value_map[inst_op.getInputs()[input_idx]];
BValue data = get_or_translate(inst_op.getInputs()[input_idx]);
bb->InstantiationInput(inst, port_name, data,
translation_state.getLoc(inst_op));
input_idx++;
Expand Down Expand Up @@ -1410,7 +1427,7 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
.Case<NextValueOp>([&](NextValueOp next) {
// We just skip the next value op here as its handled in the eproc
// conversion.
return value_map[next.getValues()[0]];
return get_or_translate(next.getValues()[0]);
})
.Case<CallDslxOp>([&](CallDslxOp call) {
llvm::errs() << "Call remaining, call pass normalize-xls-calls "
Expand All @@ -1423,6 +1440,44 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
});
};

// Use lazy resolution to enable non-topologically sorted ops. A callback is
// used to allow for recursive calls to the convert/translation function
// without tieing those all into translation state.
get_or_translate = [&](Value val) -> BValue {
auto it = value_map.find(val);
if (it != value_map.end()) {
return it->second;
}
if (llvm::isa<BlockArgument>(val)) {
return BValue();
}
Operation* defining_op = val.getDefiningOp();
if (!defining_op) {
return BValue();
}

if (!active_conversions.insert(defining_op).second) {
defining_op->emitOpError(
"combinational cycle detected during translation");
return BValue();
}

BValue result = convert(defining_op);
active_conversions.erase(defining_op);

if (!result.valid()) {
return BValue();
}

if (defining_op->getNumResults() == 1 &&
!llvm::isa<BlockInstantiateOp, FifoInstantiateOp>(defining_op)) {
value_map[val] = result;
}

return result;
};
translation_state.setTranslateCallback(get_or_translate);

// Walk over the function and construct the XLS function.
auto result_walk = xls_region.walk([&](Operation* op) {
if (op == xls_region) {
Expand All @@ -1448,6 +1503,13 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
return WalkResult::skip();
}

// Skip if all results are already translated.
if (!op->getResults().empty() &&
llvm::all_of(op->getResults(),
[&](Value res) { return value_map.contains(res); })) {
return WalkResult::advance();
}

if (BValue r = convert(op); r.valid()) {
// BlockInstantiateOp/FifoInstantiateOp manage their own value_map
// entries for multiple results, so skip automatic assignment.
Expand All @@ -1464,6 +1526,9 @@ FailureOr<BValue> convertFunction(TranslationState& translation_state,
return WalkResult::interrupt();
});

// Reset to avoid dangling reference to local lamba and value_map.
translation_state.setTranslateCallback(nullptr);

if (result_walk.wasInterrupted()) {
return failure();
}
Expand Down
2 changes: 2 additions & 0 deletions xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,7 @@ Attribute translateXlsValueToAttribute(const ::xls::Value& value,
return nullptr;
}
llvm::SmallVector<Attribute> elements;
elements.reserve(value.elements().size());
for (const auto& elem : value.elements()) {
elements.push_back(translateXlsValueToAttribute(
elem, array_type.getElementType(), builder));
Expand All @@ -1710,6 +1711,7 @@ Attribute translateXlsValueToAttribute(const ::xls::Value& value,
return nullptr;
}
llvm::SmallVector<Attribute> elements;
elements.reserve(value.elements().size());
for (auto [i, elem] : llvm::enumerate(value.elements())) {
elements.push_back(
translateXlsValueToAttribute(elem, tuple_type.getType(i), builder));
Expand Down
Loading