Skip to content

Commit af1a306

Browse files
committed
Add support for "out of order" domain references
1 parent 619804a commit af1a306

File tree

1 file changed

+37
-30
lines changed

1 file changed

+37
-30
lines changed

lib/Dialect/FIRRTL/Transforms/InferDomains.cpp

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -391,8 +391,7 @@ LogicalResult unify(Term *lhs, Term *rhs) {
391391
}
392392

393393
void solve(Term *lhs, Term *rhs) {
394-
auto result = unify(lhs, rhs);
395-
(void)result;
394+
[[maybe_unused]] auto result = unify(lhs, rhs);
396395
assert(result.succeeded());
397396
}
398397

@@ -532,6 +531,20 @@ Term *getTermForDomain(TermAllocator &allocator, DomainTable &table,
532531
return term;
533532
}
534533

534+
void processDomainDefinition(TermAllocator &allocator, DomainTable &table,
535+
Value domain) {
536+
assert(isa<DomainType>(domain.getType()));
537+
auto *newTerm = allocator.allocVal(domain);
538+
auto *oldTerm = table.getOptTermForDomain(domain);
539+
if (!oldTerm) {
540+
table.setTermForDomain(domain, newTerm);
541+
return;
542+
}
543+
544+
[[maybe_unused]] auto result = unify(oldTerm, newTerm);
545+
assert(result.succeeded());
546+
}
547+
535548
/// Get the row of domains that a hardware value in the IR is associated with.
536549
/// The returned term is forced to be at least a row.
537550
RowTerm *getDomainAssociationAsRow(const DomainInfo &info,
@@ -680,9 +693,8 @@ LogicalResult processModulePorts(const DomainInfo &info,
680693
if (isa<DomainType>(port.getType())) {
681694
auto typeID = info.getDomainTypeID(domainInfo, i);
682695
domainTypeIDTable[i] = typeID;
683-
if (module.getPortDirection(i) == Direction::In) {
684-
table.setTermForDomain(port, allocator.allocVal(port));
685-
}
696+
if (module.getPortDirection(i) == Direction::In)
697+
processDomainDefinition(allocator, table, port);
686698
continue;
687699
}
688700

@@ -723,12 +735,8 @@ LogicalResult processInstancePorts(const DomainInfo &info,
723735
if (isa<DomainType>(port.getType())) {
724736
auto typeID = info.getDomainTypeID(domainInfo, i);
725737
domainPortTypeIDTable[i] = typeID;
726-
if (op.getPortDirection(i) == Direction::Out) {
727-
table.setTermForDomain(port, allocator.allocVal(port));
728-
} else {
729-
table.setTermForDomain(port, allocator.allocVar());
730-
}
731-
continue;
738+
if (op.getPortDirection(i) == Direction::Out)
739+
processDomainDefinition(allocator, table, port);
732740
}
733741

734742
if (!isa<FIRRTLBaseType>(port.getType()))
@@ -1019,10 +1027,10 @@ void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns,
10191027
ip, loc, find(term), exports);
10201028
}
10211029

1022-
void getUpdatesForModulePorts(const DomainInfo &info, Namespace &ns,
1023-
TermAllocator &allocator,
1030+
void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator,
10241031
const ExportTable &exports, DomainTable &table,
1025-
FModuleOp module, PendingUpdates &pending) {
1032+
Namespace &ns, FModuleOp module,
1033+
PendingUpdates &pending) {
10261034
for (size_t i = 0, e = module.getNumPorts(); i < e; ++i) {
10271035
auto port = module.getArgument(i);
10281036
auto type = port.getType();
@@ -1038,16 +1046,17 @@ void getUpdatesForModulePorts(const DomainInfo &info, Namespace &ns,
10381046
/// is unsolved, solve the variable by adding an input port to the pending
10391047
/// updates.
10401048
template <typename T>
1041-
void getUpdatesForInstance(const DomainInfo &info, const DomainTable &table,
1042-
Namespace &ns, size_t ip, PendingUpdates &pending,
1043-
T op) {
1049+
void getUpdatesForInstance(const DomainInfo &info, TermAllocator &allocator,
1050+
DomainTable &table, Namespace &ns, size_t ip,
1051+
PendingUpdates &pending, T op) {
10441052
for (size_t i = 0, e = op.getNumResults(); i < e; ++i) {
10451053
auto result = op.getResult(i);
10461054
if (!isa<DomainType>(result.getType()) ||
10471055
op.getPortDirection(i) == Direction::Out)
10481056
continue;
10491057

1050-
auto *var = dyn_cast<VariableTerm>(table.getTermForDomain(result));
1058+
auto *term = getTermForDomain(allocator, table, result);
1059+
auto *var = dyn_cast<VariableTerm>(term);
10511060
if (!var)
10521061
continue;
10531062

@@ -1057,21 +1066,21 @@ void getUpdatesForInstance(const DomainInfo &info, const DomainTable &table,
10571066
}
10581067
}
10591068

1060-
void getUpdatesForOp(const DomainInfo &info, const DomainTable &table,
1061-
Namespace &ns, size_t ip, PendingUpdates &pending,
1062-
Operation *op) {
1069+
void getUpdatesForOp(const DomainInfo &info, TermAllocator &allocator,
1070+
DomainTable &table, Namespace &ns, size_t ip,
1071+
PendingUpdates &pending, Operation *op) {
10631072
if (auto inst = dyn_cast<InstanceOp>(op))
1064-
return getUpdatesForInstance(info, table, ns, ip, pending, inst);
1073+
return getUpdatesForInstance(info, allocator, table, ns, ip, pending, inst);
10651074
if (auto inst = dyn_cast<InstanceChoiceOp>(op))
1066-
return getUpdatesForInstance(info, table, ns, ip, pending, inst);
1075+
return getUpdatesForInstance(info, allocator, table, ns, ip, pending, inst);
10671076
}
10681077

1069-
void getUpdatesForModuleBody(const DomainInfo &info, const DomainTable &table,
1070-
Namespace &ns, FModuleOp mod,
1078+
void getUpdatesForModuleBody(const DomainInfo &info, TermAllocator &allocator,
1079+
DomainTable &table, Namespace &ns, FModuleOp mod,
10711080
PendingUpdates &pending) {
10721081
auto ip = mod.getNumPorts();
10731082
mod->walk([&](Operation *op) {
1074-
getUpdatesForOp(info, table, ns, ip, pending, op);
1083+
getUpdatesForOp(info, allocator, table, ns, ip, pending, op);
10751084
});
10761085
}
10771086

@@ -1082,9 +1091,8 @@ void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator,
10821091
auto names = mod.getPortNamesAttr();
10831092
for (auto name : names.getAsRange<StringAttr>())
10841093
ns.add(name);
1085-
1086-
getUpdatesForModulePorts(info, ns, allocator, exports, table, mod, pending);
1087-
getUpdatesForModuleBody(info, table, ns, mod, pending);
1094+
getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending);
1095+
getUpdatesForModuleBody(info, allocator, table, ns, mod, pending);
10881096
}
10891097

10901098
void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator,
@@ -1477,7 +1485,6 @@ LogicalResult checkAndInferModule(const DomainInfo &info,
14771485
return updateModuleBody(table, module);
14781486
}
14791487

1480-
14811488
LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info,
14821489
ModuleUpdateTable &updateTable, Operation *op) {
14831490
if (auto module = dyn_cast<FModuleOp>(op)) {

0 commit comments

Comments
 (0)