Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/circt/Dialect/HW/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,25 @@ def HWAggregateToComb : Pass<"hw-aggregate-to-comb", "hw::HWModuleOp"> {
let dependentDialects = ["comb::CombDialect"];
}

def HWConstProp : Pass<
"hw-constprop", "mlir::ModuleOp"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we name this as imconstprop for consistency with FIRRTL?

> {
let summary = "Inter-module constant propagation";
let description = [{
This pass performs inter-module constant propagation for HW modules.
It propagates constant values across module boundaries through instances,
and folds operations with constant operands.
}];
let dependentDialects = [
"comb::CombDialect",
"hw::HWDialect"
];
let statistics = [
Statistic<"numValuesFolded", "num-values-folded",
"Number of values folded to a constant">,
Statistic<"numOpsErased", "num-ops-erased",
"Number of dead ops erased">,
];
}

#endif // CIRCT_DIALECT_HW_PASSES_TD
1 change: 1 addition & 0 deletions lib/Dialect/HW/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_circt_dialect_library(CIRCTHWTransforms
HWAggregateToComb.cpp
HWConstProp.cpp
HWPrintInstanceGraph.cpp
HWSpecialize.cpp
PrintHWModuleGraph.cpp
Expand Down
334 changes: 334 additions & 0 deletions lib/Dialect/HW/Transforms/HWConstProp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
//===- HWConstProp.cpp - Inter-module constant propagation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `HWConstProp` pass.
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWInstanceGraph.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWPasses.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "hw-constprop"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When changing to imconstprop, make that change here.

Also, I figured out a way to not have to hard-code this, but it requires reorganizing the file a little bit. See how this works here: https://github.com/llvm/circt/blob/main/lib/Dialect/FIRRTL/Transforms/LowerDomains.cpp#L69


using namespace mlir;
using namespace circt;
using namespace hw;

//===----------------------------------------------------------------------===//
// Constant propagation helper
//===----------------------------------------------------------------------===//

namespace {
class ConstantPropagation {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please include comments for this and for all member functions/members.

I'm still grokking why there are three queues here and this would help with that!

public:
ConstantPropagation(hw::InstanceGraph &graph) : graph(graph) {}

void initialize(HWModuleOp module);
void propagate();
void markUnknownValuesOverdefined(hw::HWModuleOp module);
std::pair<unsigned, unsigned> fold();

public:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This public is unnecessary as this is already in a public scope.

void enqueue(Value value, IntegerAttr attr);
void mark(Value value, IntegerAttr attr);
void markOverdefined(Value value) { mark(value, IntegerAttr{}); }
void propagate(Operation *op);

/**
* Returns the lattice value associated with an SSA value.
*
* `std::nullopt` is unknown, `IntegerAttr{}` is overdefined.
*/
Comment on lines +47 to +51
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think this needs doxygen-style comments not Javadoc style, i.e., ///.

std::optional<IntegerAttr> map(Value value);

private:
hw::InstanceGraph &graph;
DenseMap<Value, IntegerAttr> values;
DenseSet<Operation *> inQueue;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling this a "queue" is a bit disingenuous.

SmallVector<Operation *> overdefQueue;
SmallVector<Operation *> valueQueue;
};
} // namespace

void ConstantPropagation::initialize(HWModuleOp module) {
if (module.isPublic()) {
// Mark public module inputs as overdefined.
for (auto arg : module.getBodyBlock()->getArguments())
markOverdefined(arg);
}

module.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<ConstantOp>([&](auto cst) {
// Constants are omitted from the mapping, but their
// users are enqueued for propagation.
enqueue(cst, cst.getValueAttr());
})
.Case<HWInstanceLike>([&](auto inst) {
// Mark external/generated module outputs as overdefined.
bool hasUnknownTarget = llvm::any_of(
inst.getReferencedModuleNamesAttr(), [&](Attribute ref) {
Operation *referencedOp =
graph.lookup(cast<StringAttr>(ref))->getModule();
auto module = dyn_cast_or_null<HWModuleOp>(referencedOp);
return !module;
Comment on lines +83 to +84
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be:

Suggested change
auto module = dyn_cast_or_null<HWModuleOp>(referencedOp);
return !module;
return !isa_and_nonnull<HWModuleOp>(referencedOp);

});

if (hasUnknownTarget) {
for (auto result : inst->getResults())
markOverdefined(result);
}
})
.Case<hw::WireOp>([&](auto wire) {
// Mark wires as overdefined since they can be targeted by force.
markOverdefined(wire.getResult());
Comment on lines +93 to +94
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a less-strict check that can work here? Like, just if this has an inner symbol?

})
.Default([&](auto op) {
if (op->getNumResults() == 0)
return;
// Mark all non-comb ops and non-integer types as overdefined.
bool isFoldable = hw::isCombinational(op);
for (auto result : op->getResults()) {
Type ty = result.getType();
if (!hw::type_isa<IntegerType>(ty) || !isFoldable)
markOverdefined(result);
}
});
});
}

void ConstantPropagation::mark(Value value, IntegerAttr attr) {
auto it = values.try_emplace(value, attr);
if (!it.second) {
if (it.first->second == attr)
return;
attr = it.first->second = IntegerAttr{};
}
enqueue(value, attr);
}

void ConstantPropagation::enqueue(Value value, IntegerAttr attr) {
for (Operation *user : value.getUsers()) {
if (inQueue.insert(user).second) {
if (attr) {
valueQueue.push_back(user);
} else {
overdefQueue.push_back(user);
}
Comment on lines +123 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
if (attr) {
valueQueue.push_back(user);
} else {
overdefQueue.push_back(user);
}
if (attr)
valueQueue.push_back(user);
else
overdefQueue.push_back(user);

}
}
}

std::optional<IntegerAttr> ConstantPropagation::map(Value value) {
if (auto constant = value.getDefiningOp<hw::ConstantOp>())
return constant.getValueAttr();

auto it = values.find(value);
if (it == values.end())
return std::nullopt;

return it->second;
}
Comment on lines +132 to +141
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling this map seems a bit odd given that it's "getMaybeConstant". (That is a terrible name!) Something describing what this is doing would be better.


void ConstantPropagation::propagate() {
while (!overdefQueue.empty() || !valueQueue.empty()) {
while (!overdefQueue.empty()) {
auto *op = overdefQueue.pop_back_val();
inQueue.erase(op);
propagate(op);
}
while (!valueQueue.empty()) {
auto *op = valueQueue.pop_back_val();
inQueue.erase(op);
propagate(op);
}
}
}

void ConstantPropagation::markUnknownValuesOverdefined(hw::HWModuleOp module) {
for (auto arg : module.getBodyBlock()->getArguments()) {
if (!map(arg))
markOverdefined(arg);
}
module.walk([&](Operation *op) {
for (auto result : op->getResults()) {
if (!map(result))
markOverdefined(result);
}
});
}

void ConstantPropagation::propagate(Operation *op) {
if (auto output = dyn_cast<OutputOp>(op)) {
auto module = op->getParentOfType<HWModuleOp>();
for (auto *node : graph[module]->uses()) {
Operation *instLike = node->getInstance();
if (!instLike)
continue;

auto inst = cast<HWInstanceLike>(instLike);
for (auto [op, res] :
llvm::zip(output.getOutputs(), inst->getResults())) {
if (auto attr = map(op))
mark(res, *attr);
}
}
return;
}

if (auto inst = dyn_cast<HWInstanceLike>(op)) {
for (auto ref : inst.getReferencedModuleNamesAttr()) {
Operation *referencedOp =
graph.lookup(cast<StringAttr>(ref))->getModule();
auto module = dyn_cast_or_null<HWModuleOp>(referencedOp);
if (!module)
continue;

Block *body = module.getBodyBlock();
for (auto [op, arg] :
llvm::zip(inst->getOperands(), body->getArguments())) {
if (auto attr = map(op))
mark(arg, *attr);
}
}
return;
}

SmallVector<Attribute> operands;
for (auto op : op->getOperands()) {
auto attr = map(op);
if (!attr)
return;
operands.push_back(*attr);
}

SmallVector<OpFoldResult, 1> results;
if (succeeded(op->fold(operands, results)) && !results.empty()) {
for (auto [res, value] : llvm::zip(op->getResults(), results)) {
if (auto attr = dyn_cast<Attribute>(value)) {
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
mark(res, intAttr);
continue;
}
}
mark(res, {});
}
} else {
for (auto res : op->getResults()) {
mark(res, {});
}
}
}

std::pair<unsigned, unsigned> ConstantPropagation::fold() {
// Cache new constants in each module. Traverse the circuit to
// populate the mapping with values to re-use.
DenseMap<std::pair<HWModuleOp, IntegerAttr>, Value> constants;
for (auto *node : graph) {
Operation *moduleOp = node->getModule();
if (!moduleOp)
continue;
auto module = dyn_cast<HWModuleOp>(moduleOp);
if (!module)
continue;
for (Operation &op : *module.getBodyBlock()) {
if (auto cst = dyn_cast<ConstantOp>(&op)) {
constants.try_emplace({module, cst.getValueAttr()}, cst);
}
}
}

// Traverse the mapping from values to lattices and replace with constants.
DenseSet<Operation *> toDelete;
unsigned numFolded = 0;
for (auto [value, attr] : values) {
if (!attr)
continue;

ImplicitLocOpBuilder builder(value.getLoc(), value.getContext());

HWModuleOp parent;
if (auto arg = dyn_cast<BlockArgument>(value)) {
parent = cast<HWModuleOp>(arg.getOwner()->getParentOp());
} else {
parent = value.getDefiningOp()->getParentOfType<HWModuleOp>();
}

auto it = constants.try_emplace({parent, attr}, Value{});
if (it.second) {
builder.setInsertionPointToStart(parent.getBodyBlock());
it.first->second = builder.create<ConstantOp>(value.getType(), attr);
}

value.replaceAllUsesWith(it.first->second);
LLVM_DEBUG({
llvm::dbgs() << "In " << parent.getModuleName() << ": Replace with "
<< attr << ": " << value << '\n';
});

++numFolded;

if (auto *op = value.getDefiningOp()) {
if (op->use_empty() && mlir::isMemoryEffectFree(op)) {
toDelete.insert(op);
}
}
}

for (Operation *op : toDelete)
op->erase();

return {numFolded, (unsigned)toDelete.size()};
}

//===----------------------------------------------------------------------===//
// Pass Infrastructure
//===----------------------------------------------------------------------===//

namespace circt {
namespace hw {
#define GEN_PASS_DEF_HWCONSTPROP
#include "circt/Dialect/HW/Passes.h.inc"
} // namespace hw
} // namespace circt

namespace {
struct HWConstPropPass
: public circt::hw::impl::HWConstPropBase<HWConstPropPass> {
void runOnOperation() override;
};
} // namespace

void HWConstPropPass::runOnOperation() {
ConstantPropagation prop(getAnalysis<hw::InstanceGraph>());

for (auto module : getOperation().getOps<HWModuleOp>())
prop.initialize(module);

prop.propagate();

// Lattice states may remain overly optimistic due to dependency cycles
// that can occur in non-Chisel designs. To address this, replace unknown
// values with overdefined ones.
for (auto module : getOperation().getOps<HWModuleOp>())
prop.markUnknownValuesOverdefined(module);

// Propagate again to fold constants that were overdefined before.
prop.propagate();

auto [numFolded, numErased] = prop.fold();
numValuesFolded += numFolded;
numOpsErased += numErased;

markAnalysesPreserved<hw::InstanceGraph>();
}
Loading