Skip to content

Commit

Permalink
[Handshake] Adding func instance op for integration
Browse files Browse the repository at this point in the history
Adds the ESIInstanceOp. This op allows a non-handshake design to
instantiate `handshake.func`s. Since handshake needs elastic inputs and
produces elastic outputs, this uses ESI channels.
  • Loading branch information
teqdruid committed Nov 14, 2024
1 parent b72a394 commit adc9288
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/circt/Dialect/Handshake/HandshakeOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "circt/Dialect/Handshake/HandshakeDialect.h"
#include "circt/Dialect/Handshake/HandshakeInterfaces.h"
#include "circt/Dialect/Seq/SeqTypes.h"
#include "circt/Support/LLVM.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down
51 changes: 51 additions & 0 deletions include/circt/Dialect/Handshake/HandshakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

include "circt/Dialect/ESI/ESITypes.td"
include "circt/Dialect/Seq/SeqTypes.td"

// @mortbopet: some kind of support for interfaces as parent ops is currently
// being tracked here: https://github.com/llvm/llvm-project/pull/66196
class HasParentInterface<string interface>
Expand Down Expand Up @@ -138,6 +141,54 @@ def FuncOp : Op<Handshake_Dialect, "func", [
let hasCustomAssemblyFormat = 1;
}

def ESIInstanceOp : Op<Handshake_Dialect, "esi_instance", [
CallOpInterface,
HasClock,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Instantiate a Handshake circuit";
let description = [{
Instantiate (call) a Handshake function in a non-Handshake design using ESI
channels as the outside connections.
}];
let arguments = (ins FlatSymbolRefAttr:$module, StrAttr:$instName,
ClockType:$clk, I1:$rst,
Variadic<ChannelType>:$opOperands);
let results = (outs Variadic<ChannelType>);

let assemblyFormat = [{
$module $instName `clk` $clk `rst` $rst
`(` $opOperands `)` attr-dict `:` functional-type($opOperands, results)
}];

let extraClassDeclaration = [{
// Account for `clk` and `rst` operands vs call arguments.
static constexpr int NumFixedOperands = 2;

/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}

operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }

/// Return the module of this operation.
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("module");
}

/// Set the callee for this operation.
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
(*this)->setAttr(getModuleAttrName(), callee.get<mlir::SymbolRefAttr>());
}

MutableOperandRange getArgOperandsMutable() {
return getOpOperandsMutable();
}
}];
}

// InstanceOp
def InstanceOp : Handshake_Op<"instance", [
CallOpInterface,
Expand Down
39 changes: 39 additions & 0 deletions lib/Conversion/HandshakeToHW/HandshakeToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/ESI/ESIOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWSymCache.h"
#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
Expand Down Expand Up @@ -1054,6 +1055,32 @@ class InstanceConversionPattern
}
};

class ESIInstanceConversionPattern
: public OpConversionPattern<handshake::ESIInstanceOp> {
public:
ESIInstanceConversionPattern(MLIRContext *context,
const HWSymbolCache &symCache)
: OpConversionPattern(context), symCache(symCache) {}

LogicalResult
matchAndRewrite(ESIInstanceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> operands;
for (size_t i = ESIInstanceOp::NumFixedOperands, e = op.getNumOperands();
i < e; ++i)
operands.push_back(adaptor.getOperands()[i]);
operands.push_back(adaptor.getClk());
operands.push_back(adaptor.getRst());
Operation *targetModule = symCache.getDefinition(op.getModuleAttr());
rewriter.replaceOpWithNewOp<hw::InstanceOp>(op, targetModule,
op.getInstNameAttr(), operands);
return success();
}

private:
const HWSymbolCache &symCache;
};

class ReturnConversionPattern
: public OpConversionPattern<handshake::ReturnOp> {
public:
Expand Down Expand Up @@ -1976,6 +2003,18 @@ class HandshakeToHWPass
for (auto hwModule : mod.getOps<hw::HWModuleOp>())
if (failed(convertExtMemoryOps(hwModule)))
return signalPassFailure();

// Run conversions which need see everything.
HWSymbolCache symbolCache;
symbolCache.addDefinitions(mod);
symbolCache.freeze();
RewritePatternSet patterns(mod.getContext());
patterns.insert<ESIInstanceConversionPattern>(mod.getContext(),
symbolCache);
if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
mod->emitOpError() << "error during conversion";
signalPassFailure();
}
}
};
} // end anonymous namespace
Expand Down
51 changes: 51 additions & 0 deletions lib/Dialect/Handshake/HandshakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "circt/Dialect/ESI/ESITypes.h"
#include "circt/Support/LLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -1335,6 +1336,56 @@ void JoinOp::print(OpAsmPrinter &p) {
p << " : " << getData().getTypes();
}

LogicalResult
ESIInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the module attribute was specified.
auto fnAttr = this->getModuleAttr();
assert(fnAttr && "requires a 'module' symbol reference attribute");

FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
if (!fn)
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid handshake function";

// Verify that the operand and result types match the callee.
auto fnType = fn.getFunctionType();
if (fnType.getNumInputs() != getNumOperands() - NumFixedOperands)
return emitOpError(
"incorrect number of operands for the referenced handshake function");

for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
Type operandType = getOperand(i + NumFixedOperands).getType();
auto channelType = dyn_cast<esi::ChannelType>(operandType);
if (!channelType)
return emitOpError("operand type mismatch: expected channel type, but "
"provided ")
<< operandType << " for operand number " << i;
if (channelType.getInner() != fnType.getInput(i))
return emitOpError("operand type mismatch: expected operand type ")
<< fnType.getInput(i) << ", but provided "
<< getOperand(i).getType() << " for operand number " << i;
}

if (fnType.getNumResults() != getNumResults())
return emitOpError(
"incorrect number of results for the referenced handshake function");

for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
Type resultType = getResult(i).getType();
auto channelType = dyn_cast<esi::ChannelType>(resultType);
if (!channelType)
return emitOpError("result type mismatch: expected channel type, but "
"provided ")
<< resultType << " for result number " << i;
if (channelType.getInner() != fnType.getResult(i))
return emitOpError("result type mismatch: expected result type ")
<< fnType.getResult(i) << ", but provided "
<< getResult(i).getType() << " for result number " << i;
}

return success();
}

/// Based on mlir::func::CallOp::verifySymbolUses
LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the module attribute was specified.
Expand Down
11 changes: 11 additions & 0 deletions test/Conversion/HandshakeToHW/test_instance.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,14 @@ handshake.func @bar(%in : i32) -> (i32) {
%out = handshake.instance @foo(%in) : (i32) -> (i32)
handshake.return %out : i32
}

// -----

handshake.func @foo(%ctrl : i32) -> i32 {
return %ctrl : i32
}

hw.module @outer(in %clk: !seq.clock, in %rst: i1, in %ctrl: !esi.channel<i32>, out out: !esi.channel<i32>) {
%ret = handshake.esi_instance @foo "foo_inst" clk %clk rst %rst (%ctrl) : (!esi.channel<i32>) -> (!esi.channel<i32>)
hw.output %ret : !esi.channel<i32>
}
21 changes: 21 additions & 0 deletions test/Dialect/Handshake/call.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,24 @@ handshake.func @invalid_instance_op(%arg0 : i32) -> i32 {
instance @foo(%arg0) : (i32) -> (i32)
return %arg0 : i32
}

// -----

// CHECK-LABEL: handshake.func @foo(
// CHECK-SAME: %[[VAL_0:.*]]: i32, ...) -> i32
// CHECK: return %[[VAL_0]] : i32
// CHECK: }

// CHECK-LABEL: hw.module @outer(in %clk : !seq.clock, in %rst : i1, in %ctrl : !esi.channel<i32>, out out : !esi.channel<i32>) {
// CHECK-NEXT: [[R0:%.+]] = handshake.esi_instance @foo "foo_inst" clk %clk rst %rst(%ctrl) : (!esi.channel<i32>) -> !esi.channel<i32>
// CHECK-NEXT: hw.output [[R0]] : !esi.channel<i32>


handshake.func @foo(%ctrl : i32) -> i32 {
return %ctrl : i32
}

hw.module @outer(in %clk: !seq.clock, in %rst: i1, in %ctrl: !esi.channel<i32>, out out: !esi.channel<i32>) {
%ret = handshake.esi_instance @foo "foo_inst" clk %clk rst %rst (%ctrl) : (!esi.channel<i32>) -> (!esi.channel<i32>)
hw.output %ret : !esi.channel<i32>
}

0 comments on commit adc9288

Please sign in to comment.