From 37b40b44dcf48f297721b9367bf2e75bfa055eb6 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:09:18 +0530 Subject: [PATCH 01/12] [MIPS] Add MIPS dialect with mips.matmul op Introduces a new MIPS dialect that acts as a semantic abstraction layer for hardware-specific matrix multiply operations inside the IREE compiler. Key components: - IR/MIPSBase.td / MIPSDialect.h/.cpp: dialect definition, namespace ::mlir::iree_compiler::IREE::MIPS, dependent on func/memref/tensor. - IR/MIPSOps.td / MIPSOps.h/.cpp: mips.matmul op, tensor-only, Destination-Passing-Style (DPS). Verifier checks 2-D tensor shapes (lhs[MxK], rhs[KxN], init[MxN]) and element-type consistency. Implements ReifyRankedShapedTypeOpInterface and MemoryEffectsOpInterface. - IR/MIPSBufferizableOpInterface.cpp: eliminates mips.matmul entirely during One-Shot Bufferize by emitting func.call @my_matmul_kernel directly with the memref (base_ptr, offset, stride0, stride1) ABI. Uses memref.memory_space_cast to strip IREE HAL memory spaces from the base pointers so the function declaration stays stable across eraseHALDescriptorTypeFromMemRef. - Transforms/Passes.td/.h/.cpp: pass registry for LowerMIPSToFuncCallPass (now a no-op; kept for pipeline compatibility since bufferization handles the lowering directly). - Transforms/LowerMIPSToFuncCall.cpp: no-op pass stub. - DispatchCreation/FormDispatchRegions.cpp: teach IREE's dispatch formation to treat mips.matmul as a compute-heavy op eligible for outlining into flow.dispatch.workgroups. - Tools/init_iree_dialects.h: register MIPSDialect and its BufferizableOpInterface external models. - Tools/init_iree_passes.h: register MIPS passes with the global pass registry. 
--- .../iree/compiler/Dialect/MIPS/CMakeLists.txt | 8 + .../compiler/Dialect/MIPS/IR/CMakeLists.txt | 70 ++++++ .../iree/compiler/Dialect/MIPS/IR/MIPSBase.td | 47 ++++ .../MIPS/IR/MIPSBufferizableOpInterface.cpp | 216 ++++++++++++++++++ .../compiler/Dialect/MIPS/IR/MIPSDialect.cpp | 56 +++++ .../compiler/Dialect/MIPS/IR/MIPSDialect.h | 28 +++ .../iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp | 98 ++++++++ .../iree/compiler/Dialect/MIPS/IR/MIPSOps.h | 25 ++ .../iree/compiler/Dialect/MIPS/IR/MIPSOps.td | 71 ++++++ .../Dialect/MIPS/Transforms/CMakeLists.txt | 56 +++++ .../MIPS/Transforms/LowerMIPSToFuncCall.cpp | 36 +++ .../Dialect/MIPS/Transforms/Passes.cpp | 18 ++ .../compiler/Dialect/MIPS/Transforms/Passes.h | 33 +++ .../Dialect/MIPS/Transforms/Passes.td | 49 ++++ .../MIPS/Transforms/test/CMakeLists.txt | 15 ++ .../Transforms/test/lower_to_func_call.mlir | 49 ++++ .../compiler/DispatchCreation/CMakeLists.txt | 1 + .../DispatchCreation/FormDispatchRegions.cpp | 4 + .../iree/compiler/Tools/init_iree_dialects.h | 3 + .../iree/compiler/Tools/init_iree_passes.h | 2 + 20 files changed, 885 insertions(+) create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp create mode 
100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir diff --git a/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt new file mode 100644 index 000000000000..3da5a7a85912 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Recursively picks up IR/ and Transforms/ subdirectories. +iree_add_all_subdirs() diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt new file mode 100644 index 000000000000..632b1ee5b190 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt @@ -0,0 +1,70 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +# ─── Tablegen ──────────────────────────────────────────────────────────────── +# Generate op declarations/definitions AND dialect declarations/definitions +# from a single MIPSOps.td source (same pattern as LinalgExt). 
+ +iree_tablegen_library( + NAME + MIPSOpsIncGen + TD_FILE + "MIPSOps.td" + OUTS + --gen-op-decls MIPSOps.h.inc + --gen-op-defs MIPSOps.cpp.inc + --dialect=mips --gen-dialect-decls MIPSDialect.h.inc + --dialect=mips --gen-dialect-defs MIPSDialect.cpp.inc +) + +# ─── C++ library ───────────────────────────────────────────────────────────── + +iree_cc_library( + NAME + IR + HDRS + "MIPSDialect.h" + "MIPSOps.h" + "MIPSDialect.h.inc" + TEXTUAL_HDRS + "MIPSOps.h.inc" + "MIPSOps.cpp.inc" + SRCS + "MIPSDialect.cpp" + "MIPSDialect.cpp.inc" + "MIPSOps.cpp" + "MIPSBufferizableOpInterface.cpp" + DEPS + ::MIPSOpsIncGen + LLVMSupport + MLIRIR + MLIRSupport + MLIRFuncDialect + MLIRMemRefDialect + MLIRTensorDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRDestinationStyleOpInterface + MLIRInferTypeOpInterface + MLIRSideEffectInterfaces + MLIRTensorUtils + MLIRTransforms + PUBLIC +) + +# ─── Documentation ─────────────────────────────────────────────────────────── + +iree_tablegen_doc( + NAME + MIPSDialectDocGen + CATEGORY "Dialects" + TD_FILE + "MIPSOps.td" + OUTS + --gen-dialect-doc -dialect=mips MIPSDialect.md +) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td new file mode 100644 index 000000000000..fd006c3bfcd7 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td @@ -0,0 +1,47 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_MIPS_BASE
+#define IREE_DIALECT_MIPS_BASE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// MIPS dialect definition
+//===----------------------------------------------------------------------===//
+
+def MIPS_Dialect : Dialect {
+  let name = "mips";
+  let cppNamespace = "::mlir::iree_compiler::IREE::MIPS";
+  let summary = "MIPS custom compute dialect for experimental matmul pipeline.";
+  let description = [{
+    The MIPS dialect defines a single `mips.matmul` operation that serves as
+    a custom intermediate representation between the Torch frontend and the
+    final `func.call @my_matmul_kernel` generated by the MIPS lowering pass.
+
+    Pipeline:
+      torch.aten.mm → mips.matmul → func.call @my_matmul_kernel
+  }];
+
+  let dependentDialects = [
+    "::mlir::func::FuncDialect",
+    "::mlir::memref::MemRefDialect",
+    "::mlir::tensor::TensorDialect"
+  ];
+
+  // No custom attribute types → do not declare parseAttribute/printAttribute
+  // overrides. The base Dialect class handles the fallback behavior.
+  let useDefaultAttributePrinterParser = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// Base op class
+//===----------------------------------------------------------------------===//
+
+class MIPS_Op<string mnemonic, list<Trait> traits = []>
+    : Op<MIPS_Dialect, mnemonic, traits>;
+
+#endif // IREE_DIALECT_MIPS_BASE
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp
new file mode 100644
index 000000000000..e5499939bcc3
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp
@@ -0,0 +1,216 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Implements BufferizableOpInterface for mips.matmul.
+//
+// mips.matmul is a tensor-only, Destination-Passing-Style (DPS) op. It is
+// eliminated *entirely* during One-Shot Bufferize: bufferize() obtains memref
+// buffers for all three operands, decomposes each 2-D memref into
+// (base_ptr, offset, stride0, stride1) via memref.extract_strided_metadata,
+// and emits a func.call @my_matmul_kernel directly. No memref form of
+// mips.matmul ever exists in the IR.
+//
+// Before bufferization:
+//   %C = mips.matmul %A, %B, %init
+//       : tensor<MxK>, tensor<KxN>, tensor<MxN> -> tensor<MxN>
+//
+// After bufferization (produced inside bufferize()):
+//   %A_meta = memref.extract_strided_metadata %A_buf -> (base, off, s0, s1)
+//   %B_meta = memref.extract_strided_metadata %B_buf -> (base, off, s0, s1)
+//   %C_meta = memref.extract_strided_metadata %C_buf -> (base, off, s0, s1)
+//   %M = memref.dim %A_buf, 0
+//   %N = memref.dim %B_buf, 1
+//   %K = memref.dim %A_buf, 1
+//   call @my_matmul_kernel(%A_base, %A_off, %A_s0, %A_s1,
+//                          %B_base, %B_off, %B_s0, %B_s1,
+//                          %C_base, %C_off, %C_s0, %C_s1,
+//                          %M, %N, %K)
+//   -- tensor result replaced by %C_buf via replaceOpWithBufferizedValues --
+
+#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h"
+#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace mlir::iree_compiler::IREE::MIPS {
+namespace {
+
+static constexpr StringLiteral kKernelName = "my_matmul_kernel";
+
+//===----------------------------------------------------------------------===//
+// Helper: ensure func.func private @my_matmul_kernel exists at module scope.
+//
+// The declaration carries {llvm.bareptr = true} so the LLVM backend passes
+// bare float* arguments instead of MLIR memref descriptor structs, matching
+// the C kernel ABI.
+//===----------------------------------------------------------------------===//
+
+static func::FuncOp ensureKernelDeclaration(RewriterBase &rewriter,
+                                            Operation *moduleOp,
+                                            FunctionType fnType,
+                                            Location loc) {
+  if (auto existing = dyn_cast_if_present<func::FuncOp>(
+          SymbolTable::lookupSymbolIn(moduleOp, kKernelName)))
+    return existing;
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
+  auto fnDecl = func::FuncOp::create(rewriter, loc, kKernelName, fnType);
+  SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private);
+  fnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true));
+  return fnDecl;
+}
+
+//===----------------------------------------------------------------------===//
+// Helper: decompose a 2-D memref into (base_ptr, offset, stride0, stride1).
+//
+// Uses memref.extract_strided_metadata. The base_ptr is always a rank-0
+// memref with DEFAULT address space, regardless of the source
+// memref's address space. Any IREE-specific memory space (e.g.
+// #hal.descriptor_type) is stripped via
+// memref.memory_space_cast so that:
+//
+// 1. The function declaration uses a plain memref, which is stable across
+//    all pipeline stages.
+// 2. eraseHALDescriptorTypeFromMemRefPass (which runs after bufferization and
+//    does NOT update external function declarations) cannot introduce a
+//    type mismatch between the call operands and the declaration.
+//
+// Combined with the {llvm.bareptr = true} attribute on the callee, the
+// rank-0 memref lowers to a bare float* matching the C ABI.
+//===----------------------------------------------------------------------===//
+
+static void decomposeMemref2D(RewriterBase &rewriter, Location loc,
+                              Value memref2D,
+                              SmallVectorImpl<Value> &callOperands,
+                              SmallVectorImpl<Type> &callArgTypes) {
+  Type indexType = IndexType::get(rewriter.getContext());
+
+  auto meta =
+      memref::ExtractStridedMetadataOp::create(rewriter, loc, memref2D);
+
+  // Strip any IREE-specific memory space from the base pointer so the
+  // function declaration stays in the default address space.
+  Value basePtr = meta.getBaseBuffer();
+  auto basePtrMemrefTy = cast<MemRefType>(basePtr.getType());
+  MemRefType plainBasePtrTy =
+      MemRefType::get(/*shape=*/{}, basePtrMemrefTy.getElementType());
+  if (basePtrMemrefTy != plainBasePtrTy) {
+    basePtr = memref::MemorySpaceCastOp::create(rewriter, loc, plainBasePtrTy,
+                                                basePtr);
+  }
+
+  callOperands.push_back(basePtr);
+  callArgTypes.push_back(plainBasePtrTy);
+
+  callOperands.push_back(meta.getOffset());
+  callArgTypes.push_back(indexType);
+
+  for (Value stride : meta.getStrides()) {
+    callOperands.push_back(stride);
+    callArgTypes.push_back(indexType);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// External model — BufferizableOpInterface for mips.matmul.
+//
+// Inherits from DstBufferizableOpInterfaceExternalModel which automatically
+// handles the DPS aliasing (init ↔ result) and write detection for the init
+// operand. We override bufferizesToMemoryRead to mark lhs and rhs as read,
+// and provide a custom bufferize() that emits func.call @my_matmul_kernel.
+//===----------------------------------------------------------------------===//
+
+struct MIPSMatmulBufferizableOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<
+          MIPSMatmulBufferizableOpInterface, MatmulOp> {
+
+  // All three operands are read by the kernel.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options,
+                          BufferizationState &state) const {
+    auto matmulOp = cast<MatmulOp>(op);
+    Location loc = matmulOp.getLoc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(matmulOp);
+
+    // Obtain memref buffers for all three tensor operands.
+    FailureOr<Value> lhsBuf =
+        getBuffer(rewriter, matmulOp.getLhs(), options, state);
+    if (failed(lhsBuf))
+      return failure();
+    FailureOr<Value> rhsBuf =
+        getBuffer(rewriter, matmulOp.getRhs(), options, state);
+    if (failed(rhsBuf))
+      return failure();
+    // init aliases with result — one-shot bufferize allocates the output buffer
+    // (via bufferization.alloc_tensor or in-place analysis) and gives it to us
+    // here as initBuf.
+    FailureOr<Value> initBuf =
+        getBuffer(rewriter, matmulOp.getInit(), options, state);
+    if (failed(initBuf))
+      return failure();
+
+    // Build the flattened argument list for func.call @my_matmul_kernel.
+    //   For each 2-D memref: (base_ptr, offset, stride0, stride1)
+    //   Then: M, N, K as index scalars.
+    SmallVector<Value> callOperands;
+    SmallVector<Type> callArgTypes;
+
+    decomposeMemref2D(rewriter, loc, *lhsBuf, callOperands, callArgTypes);
+    decomposeMemref2D(rewriter, loc, *rhsBuf, callOperands, callArgTypes);
+    decomposeMemref2D(rewriter, loc, *initBuf, callOperands, callArgTypes);
+
+    Type indexType = IndexType::get(ctx);
+    Value M = memref::DimOp::create(rewriter, loc, *lhsBuf, 0);
+    Value N = memref::DimOp::create(rewriter, loc, *rhsBuf, 1);
+    Value K = memref::DimOp::create(rewriter, loc, *lhsBuf, 1);
+    callOperands.append({M, N, K});
+    callArgTypes.append(3, indexType);
+
+    // Declare the kernel function in the enclosing module (idempotent).
+    Operation *moduleOp = SymbolTable::getNearestSymbolTable(matmulOp);
+    FunctionType fnType = rewriter.getFunctionType(callArgTypes, TypeRange{});
+    ensureKernelDeclaration(rewriter, moduleOp, fnType, loc);
+
+    // Emit the call — the kernel writes into *initBuf in place.
+    func::CallOp::create(rewriter, loc, kKernelName, TypeRange{}, callOperands);
+
+    // Replace the tensor result with the init buffer (DPS aliasing).
+    replaceOpWithBufferizedValues(rewriter, op, *initBuf);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public registration entry point
+//===----------------------------------------------------------------------===//
+
+void registerMIPSBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, MIPSDialect * /*dialect*/) {
+    MatmulOp::attachInterface<MIPSMatmulBufferizableOpInterface>(*ctx);
+  });
+}
+
+} // namespace mlir::iree_compiler::IREE::MIPS
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp
new file mode 100644
index 000000000000..f136f11594fa
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp
@@ -0,0 +1,56 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h"
+
+#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::MIPS;
+
+//===----------------------------------------------------------------------===//
+// Inliner interface — allow MIPS ops to be inlined unconditionally.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct MIPSInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  bool isLegalToInline(Operation *call, Operation *callable,
+                       bool wouldBeCloned) const final {
+    return true;
+  }
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return true;
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Dialect initialize
+//===----------------------------------------------------------------------===//
+
+void MIPSDialect::initialize() {
+  addInterfaces<MIPSInlinerInterface>();
+
+#define GET_OP_LIST
+  addOperations<
+#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp.inc"
+      >();
+}
+
+#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp.inc"
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h
new file mode 100644
index 000000000000..aa99d0a63005
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h
@@ -0,0 +1,28 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_
+#define IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+// clang-format off
+// MIPSDialect.h.inc is generated from MIPSOps.td via:
+//   --dialect=mips --gen-dialect-decls MIPSDialect.h.inc
+#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h.inc" // IWYU pragma: keep
+// clang-format on
+
+namespace mlir::iree_compiler::IREE::MIPS {
+
+// Register external BufferizableOpInterface models for MIPS ops.
+// Call this from registerIreeDialects() before bufferization runs.
+void registerMIPSBufferizableOpInterfaceExternalModels(
+    mlir::DialectRegistry &registry);
+
+} // namespace mlir::iree_compiler::IREE::MIPS
+
+#endif // IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp
new file mode 100644
index 000000000000..c60c0b1d6c6f
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp
@@ -0,0 +1,98 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::MIPS; + +//===----------------------------------------------------------------------===// +// MatmulOp — ReifyRankedShapedTypeOpInterface +// +// Returns the output shape [M, N] so IREE's dispatch formation can compute +// the workload when wrapping mips.matmul in a flow.dispatch.workgroups region. +//===----------------------------------------------------------------------===// + +LogicalResult MatmulOp::reifyResultShapes( + OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + // Result is always tensor. M from lhs dim 0, N from rhs dim 1. + reifiedReturnShapes.push_back({tensor::getMixedSize(b, getLoc(), getLhs(), 0), + tensor::getMixedSize(b, getLoc(), getRhs(), 1)}); + return success(); +} + +//===----------------------------------------------------------------------===// +// MatmulOp — MemoryEffectsOpInterface +// +// In the tensor domain, ops are nominally pure (tensors are values, not memory). +// However mips.matmul uses DPS — the init operand logically "carries" the +// result. We declare read on lhs/rhs and read+write on init so that alias +// analyses outside of bufferization correctly treat init as modified. 
+//===----------------------------------------------------------------------===//
+
+void MatmulOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Read::get(), &getInitMutable(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), &getInitMutable(),
+                       SideEffects::DefaultResource::get());
+}
+
+//===----------------------------------------------------------------------===//
+// MatmulOp — Verifier
+//===----------------------------------------------------------------------===//
+
+LogicalResult MatmulOp::verify() {
+  auto shape = [](Value v) {
+    return cast<RankedTensorType>(v.getType()).getShape();
+  };
+  auto elemTy = [](Value v) {
+    return cast<RankedTensorType>(v.getType()).getElementType();
+  };
+
+  // All operands must be 2-D tensors.
+  for (Value v : {getLhs(), getRhs(), getInit()}) {
+    if (cast<RankedTensorType>(v.getType()).getRank() != 2)
+      return emitOpError("all operands must be 2-D ranked tensors");
+  }
+
+  // Dimension compatibility: lhs[M x K], rhs[K x N], init[M x N].
+  // Only validate static dimensions; dynamic dims are checked at runtime.
+  auto compat = [](int64_t a, int64_t b) {
+    return ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b;
+  };
+  if (!compat(shape(getLhs())[0], shape(getInit())[0]))
+    return emitOpError("lhs dim 0 (M) must match init dim 0 (M)");
+  if (!compat(shape(getLhs())[1], shape(getRhs())[0]))
+    return emitOpError("lhs dim 1 (K) must match rhs dim 0 (K)");
+  if (!compat(shape(getRhs())[1], shape(getInit())[1]))
+    return emitOpError("rhs dim 1 (N) must match init dim 1 (N)");
+
+  // All element types must match.
+ if (elemTy(getLhs()) != elemTy(getRhs()) || elemTy(getLhs()) != elemTy(getInit())) + return emitOpError("element types of all operands must match"); + + // Result type must match init type (both tensor). + if (getResult().getType() != getInit().getType()) + return emitOpError("result type must match init type"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen generated op definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h new file mode 100644 index 000000000000..dc2881aca2e3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h @@ -0,0 +1,25 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_
+#define IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_
+
+#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+// clang-format off
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h.inc" // IWYU pragma: export
+
+// clang-format on
+
+#endif // IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td
new file mode 100644
index 000000000000..e87073fe3b99
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td
@@ -0,0 +1,71 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_MIPS_OPS
+#define IREE_DIALECT_MIPS_OPS
+
+include "iree/compiler/Dialect/MIPS/IR/MIPSBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// mips.matmul — tensor-only semantic op
+//
+// This op exists exclusively in the tensor domain. It is eliminated during
+// One-Shot Bufferize: the BufferizableOpInterface implementation obtains the
+// bufferized operands and emits func.call @my_matmul_kernel(...) directly, so
+// no memref form of this op ever exists in the IR.
+//===----------------------------------------------------------------------===//
+
+def MIPS_MatmulOp : MIPS_Op<"matmul", [
+  DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DestinationStyleOpInterface
+]> {
+  let summary = "MIPS matrix multiplication (tensor domain): result = lhs * rhs";
+  let description = [{
+    Computes a 2-D matrix multiplication in the tensor domain using
+    destination-passing style (DPS). The caller provides an `init` tensor that
+    One-Shot Bufferize uses to determine the output buffer — typically
+    `bufferization.alloc_tensor` for a fresh allocation.

+    The semantic is: `result[m, n] = sum_k(lhs[m, k] * rhs[k, n])`.

+    This op is created by the Torch -> MIPS conversion pass from `torch.aten.mm`
+    and is eliminated entirely during bufferization: the BufferizableOpInterface
+    implementation emits `func.call @my_matmul_kernel` directly with the
+    bufferized memref operands. No memref-form `mips.matmul` is ever produced.

+    Example:
+    ```mlir
+    %result = mips.matmul %A, %B, %init
+        : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$lhs,   // [M x K]
+    AnyRankedTensor:$rhs,   // [K x N]
+    AnyRankedTensor:$init   // [M x N] — DPS destination (typically alloc_tensor)
+  );
+
+  let results = (outs AnyRankedTensor:$result); // [M x N]
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs `,` $init attr-dict `:`
+    type($lhs) `,` type($rhs) `,` type($init) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    // DestinationStyleOpInterface: init is the DPS output operand.
+ MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } + }]; + + let hasVerifier = 1; +} + +#endif // IREE_DIALECT_MIPS_OPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..b37cc85e3b29 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +# ─── Tablegen: generate Passes.h.inc from Passes.td ────────────────────────── + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +# ─── Header-only pass declarations ─────────────────────────────────────────── + +iree_cc_library( + NAME + PassHeaders + HDRS + "Passes.h" + "Passes.h.inc" + DEPS + ::PassesIncGen + MLIRPass + MLIRTransforms + PUBLIC +) + +# ─── Full transforms library ───────────────────────────────────────────────── + +iree_cc_library( + NAME + Transforms + HDRS + "Passes.h" + SRCS + "LowerMIPSToFuncCall.cpp" + "Passes.cpp" + DEPS + ::PassHeaders + ::PassesIncGen + iree::compiler::Dialect::MIPS::IR + MLIRFuncDialect + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTransformUtils + MLIRTransforms + PUBLIC +) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp new file mode 100644 index 000000000000..df9caa9608f1 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp @@ -0,0 +1,36 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// mips.matmul is a tensor-only op that is eliminated entirely during
+// One-Shot Bufferize: the BufferizableOpInterface implementation in
+// MIPSBufferizableOpInterface.cpp emits func.call @my_matmul_kernel directly.
+//
+// This pass is therefore a no-op and exists only for registration purposes
+// (so that --iree-mips-lower-to-func-call can be specified on the command line
+// without error, and so that any pipeline that references it still compiles).
+
+#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // IWYU pragma: keep
+#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::MIPS {
+
+#define GEN_PASS_DEF_LOWERMIPSTOFUNCCALLPASS
+#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc"
+
+namespace {
+
+struct LowerMIPSToFuncCallPass
+    : impl::LowerMIPSToFuncCallPassBase<LowerMIPSToFuncCallPass> {
+  void runOnOperation() override {
+    // mips.matmul is eliminated during One-Shot Bufferize (see
+    // MIPSBufferizableOpInterface.cpp). No work to do here.
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::IREE::MIPS
diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp
new file mode 100644
index 000000000000..751ea73789d8
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp
@@ -0,0 +1,18 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" + +namespace mlir::iree_compiler::IREE::MIPS { + +namespace { +#define GEN_PASS_REGISTRATION +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" +} // namespace + +void registerMIPSPasses() { registerPasses(); } + +} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h new file mode 100644 index 000000000000..8b4390167407 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h @@ -0,0 +1,33 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ +#define IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::MIPS { + +//===----------------------------------------------------------------------===// +// Pass factory functions (generated by tablegen + implemented in Passes.cpp) +//===----------------------------------------------------------------------===// + +/// Creates the pass that lowers memref-form mips.matmul to +/// func.call @my_matmul_kernel. +std::unique_ptr createLowerMIPSToFuncCallPass(); + +/// Registers all MIPS passes with the global pass registry so they can be +/// invoked from the command line (e.g. `iree-opt --iree-mips-lower-to-func-call`). 
+void registerMIPSPasses(); + +} // namespace mlir::iree_compiler::IREE::MIPS + +// clang-format off +#define GEN_PASS_DECL +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" // IWYU pragma: keep +// clang-format on + +#endif // IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td new file mode 100644 index 000000000000..b081e3a58a7c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td @@ -0,0 +1,49 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES +#define IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// LowerMIPSToFuncCallPass +// +// Converts memref-form mips.matmul operations (produced after one-shot +// bufferization in the LLVMCPU codegen pipeline) to: +// +// func.func private @my_matmul_kernel(...) attributes {llvm.bareptr = true} +// func.call @my_matmul_kernel(A_base, A_off, A_s0, A_s1, +// B_base, B_off, B_s0, B_s1, +// C_base, C_off, C_s0, C_s1, +// M, N, K) +// +// The pass uses memref.extract_strided_metadata to decompose each memref into +// a (base_pointer, offset, strides...) tuple matching the C ABI of the kernel. +//===----------------------------------------------------------------------===// + +def LowerMIPSToFuncCallPass : + Pass<"iree-mips-lower-to-func-call", "ModuleOp"> { + let summary = "Lower mips.matmul (memref form) to func.call @my_matmul_kernel"; + let description = [{ + Walks all mips.matmul ops in the module and replaces each one with a call + to the external C kernel `my_matmul_kernel`. 
The memref operands are + decomposed via `memref.extract_strided_metadata` into base-pointer + offset + + stride arguments, matching the ABI declared in my_matmul_kernel.h. + + The pass creates a `func.func private @my_matmul_kernel` declaration with + `{llvm.bareptr = true}` so that the LLVM backend passes bare float* pointers + rather than MLIR memref descriptor structs. + + This pass runs after one-shot bufferization in the LLVMCPU codegen pipeline. + }]; + let dependentDialects = [ + "::mlir::func::FuncDialect", + "::mlir::memref::MemRefDialect" + ]; +} + +#endif // IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt new file mode 100644 index 000000000000..725370463654 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_lit_test_suite( + NAME + lit + SRCS + "lower_to_func_call.mlir" + TOOLS + FileCheck + iree-opt +) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir new file mode 100644 index 000000000000..582419acf589 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir @@ -0,0 +1,49 @@ +// RUN: iree-opt --split-input-file --iree-mips-lower-to-func-call %s \ +// RUN: | FileCheck %s + +// ───────────────────────────────────────────────────────────────────────────── +// Basic static-shape: memref-form mips.matmul → func.call @my_matmul_kernel +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK: func.func private @my_matmul_kernel +// CHECK-SAME: {llvm.bareptr = true} +// +// CHECK-LABEL: func.func @lower_mips_matmul +// CHECK-NOT: mips.matmul +// CHECK: memref.extract_strided_metadata +// CHECK: call @my_matmul_kernel +module { + func.func @lower_mips_matmul(%A: memref<4x8xf32>, + %B: memref<8x4xf32>, + %C: memref<4x4xf32>) { + mips.matmul %A, %B, %C + : memref<4x8xf32>, memref<8x4xf32>, memref<4x4xf32> + return + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Multiple matmuls reuse the same @my_matmul_kernel declaration. +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK: func.func private @my_matmul_kernel +// Check that there is exactly one declaration (not two). 
+// CHECK-NOT: func.func private @my_matmul_kernel +// +// CHECK-LABEL: func.func @two_matmuls +// CHECK: call @my_matmul_kernel +// CHECK: call @my_matmul_kernel +module { + func.func @two_matmuls(%A: memref<2x4xf32>, + %B: memref<4x2xf32>, + %C: memref<2x2xf32>, + %D: memref<2x4xf32>, + %E: memref<4x2xf32>, + %F: memref<2x2xf32>) { + mips.matmul %A, %B, %C + : memref<2x4xf32>, memref<4x2xf32>, memref<2x2xf32> + mips.matmul %D, %E, %F + : memref<2x4xf32>, memref<4x2xf32>, memref<2x2xf32> + return + } +} diff --git a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt index 4b124334e1ef..01cc625df2af 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt @@ -89,6 +89,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Dialect::MIPS::IR iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::TensorExt::IR iree::compiler::Dialect::TensorExt::Transforms diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 6ba8a4bd8db2..f6a5467aee5c 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h" @@ -369,6 +370,9 @@ static bool isRootLikeOp(Operation *op) { return !isa(op); } + // MIPS: mips.matmul is a dispatch root (lowered to a custom C 
kernel call). + if (isa(op)) + return true; return isa(op); } diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h index c47ae6cb4368..78f1d3768f68 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h @@ -23,6 +23,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" @@ -50,6 +51,7 @@ inline void registerIreeDialects(DialectRegistry ®istry) { IREE::HAL::Loader::HALLoaderDialect, IREE::IO::Parameters::IOParametersDialect, IREE::LinalgExt::IREELinalgExtDialect, + IREE::MIPS::MIPSDialect, IREE::PCF::PCFDialect, IREE::Encoding::IREEEncodingDialect, IREE::Stream::StreamDialect, @@ -65,6 +67,7 @@ inline void registerIreeDialects(DialectRegistry ®istry) { registerCodegenInterfaces(registry); registerGlobalOptimizationInterfaces(registry); registerUKernelBufferizationInterface(registry); + IREE::MIPS::registerMIPSBufferizableOpInterfaceExternalModels(registry); // Register transform dialect extensions. 
registerTransformDialectPreprocessingExtension(registry); diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h index 6f7de0752f45..d113b8d61d3d 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_passes.h +++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h @@ -20,6 +20,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" #include "iree/compiler/Dialect/Stream/Transforms/Passes.h" #include "iree/compiler/Dialect/TensorExt/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" @@ -62,6 +63,7 @@ inline void registerAllIreePasses() { IREE::HAL::Loader::registerHALLoaderPasses(); IREE::IO::Parameters::registerParametersPasses(); IREE::LinalgExt::registerPasses(); + IREE::MIPS::registerMIPSPasses(); IREE::Stream::registerStreamPasses(); IREE::TensorExt::registerPasses(); IREE::Util::registerUtilPasses(); From f3c7ba519ff06a101f8378aba30a3a39feb99aac Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:10:58 +0530 Subject: [PATCH 02/12] [MIPS] Add Torch-to-MIPS conversion pass and LLVMCPU pipeline wiring Adds the frontend conversion from torch.aten.mm to mips.matmul and wires the MIPS dialect into the IREE LLVMCPU codegen pipeline. Torch InputConversion changes: - ConvertTorchToMIPS.cpp: ConvertAtenMmToMIPSMatmul pattern rewrites torch.aten.mm to mips.matmul with a zero-initialized init tensor (bufferization.alloc_tensor). The pass runs before the standard ConvertTorchToLinalgPass so mips.matmul takes precedence over the generic linalg.matmul path. - Passes.td / Passes.h / Passes.cpp: declare ConvertTorchToMIPSPass and add it to the torch-to-iree pipeline under the use-mips-matmul option. 
- test/convert_torch_to_mips.mlir: FileCheck test verifying that torch.aten.mm is replaced by mips.matmul after the pass. LLVMCPU codegen pipeline changes: - Passes.cpp: insert LowerMIPSToFuncCallPass (no-op stub) in the post-bufferize section of buildLLVMCPUCodegenPassPipeline. The actual lowering to func.call is performed during One-Shot Bufferize by MIPSBufferizableOpInterface; this stub ensures the pass slot is reserved for future use and keeps the pipeline definition explicit. - CMakeLists.txt: add iree_compiler_Dialect_MIPS_Transforms_Transforms dependency to the LLVMCPU codegen target. --- .../Torch/InputConversion/CMakeLists.txt | 2 + .../InputConversion/ConvertTorchToMIPS.cpp | 155 ++++++++++++++++++ .../input/Torch/InputConversion/Passes.cpp | 5 + .../input/Torch/InputConversion/Passes.h | 6 + .../input/Torch/InputConversion/Passes.td | 10 ++ .../Torch/InputConversion/test/CMakeLists.txt | 1 + .../test/convert_torch_to_mips.mlir | 37 +++++ .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 + .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 5 + 9 files changed, 222 insertions(+) create mode 100644 compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp create mode 100644 compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt index 2738c8d9fa99..d9d0fa598477 100644 --- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt @@ -37,6 +37,7 @@ iree_cc_library( "BindSymbolicShapes.cpp" "BitCastTensor.cpp" "ConvertTMTensorToLinalgExt.cpp" + "ConvertTorchToMIPS.cpp" "ConvertTorchUnstructuredToLinalgExt.cpp" "FuncConversion.cpp" "SetStrictSymbolicShapes.cpp" @@ -57,6 +58,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR + iree::compiler::Dialect::MIPS::IR 
iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::TensorExt::IR PUBLIC diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp new file mode 100644 index 000000000000..1999f8695dfd --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp @@ -0,0 +1,155 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Converts torch.aten.mm → mips.matmul. +// +// The pattern runs inside the Torch input-conversion pipeline, BEFORE +// createConvertTorchToLinalgPass(), so it intercepts aten.mm first. +// +// Since torch ops carry ValueTensorType (torch's tensor type), the pattern: +// 1. Casts operands to builtin RankedTensorType via ToBuiltinTensorOp. +// 2. Creates a zero-initialised init tensor (Destination Passing Style). +// 3. Emits mips.matmul on builtin tensors. +// 4. Casts the result back to ValueTensorType via FromBuiltinTensorOp. +// +// This mirrors the approach in ConvertTorchUnstructuredToLinalgExt.cpp. 
+ +#include "compiler/plugins/input/Torch/InputConversion/Passes.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +namespace mlir::iree_compiler::TorchInput { + +#define GEN_PASS_DEF_CONVERTTORCHTOMIPSPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Helper: create a zero-filled tensor of a given shape and element type. +// Accepts (M, N) as dynamic Value dimensions. +//===----------------------------------------------------------------------===// + +static Value createZeroTensor(PatternRewriter &rewriter, Location loc, + RankedTensorType ty, ValueRange dynSizes) { + Value empty = tensor::EmptyOp::create(rewriter, loc, ty, dynSizes); + Attribute zeroAttr = rewriter.getZeroAttr(ty.getElementType()); + Value zero = arith::ConstantOp::create(rewriter, loc, cast(zeroAttr)); + return linalg::FillOp::create(rewriter, loc, zero, empty).result(); +} + +//===----------------------------------------------------------------------===// +// Pattern: torch.aten.mm → mips.matmul +//===----------------------------------------------------------------------===// + +struct ConvertAtenMmToMIPSMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(torch::Torch::AtenMmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // ---------------------------------------------------------------- + // 1. Verify that we have supported tensor types. 
+ // ---------------------------------------------------------------- + auto lhsTorchTy = + dyn_cast(op.getSelf().getType()); + auto rhsTorchTy = + dyn_cast(op.getMat2().getType()); + auto resultTorchTy = + dyn_cast(op.getType()); + + if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy) + return rewriter.notifyMatchFailure(op, "expected ValueTensorType"); + + // Only handle f32 for now (extensible). + if (!lhsTorchTy.getDtype().isF32()) + return rewriter.notifyMatchFailure(op, "only f32 supported"); + + // ---------------------------------------------------------------- + // 2. Cast operands from torch ValueTensorType → builtin RankedTensorType. + // ---------------------------------------------------------------- + auto lhsBuiltinTy = + dyn_cast_or_null(lhsTorchTy.toBuiltinTensor()); + auto rhsBuiltinTy = + dyn_cast_or_null(rhsTorchTy.toBuiltinTensor()); + auto resultBuiltinTy = + dyn_cast_or_null(resultTorchTy.toBuiltinTensor()); + + if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy || + lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors"); + + Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, lhsBuiltinTy, op.getSelf()); + Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, rhsBuiltinTy, op.getMat2()); + + // ---------------------------------------------------------------- + // 3. Collect dynamic dimension values for the result tensor (M, N). + // ---------------------------------------------------------------- + SmallVector dynSizes; + if (resultBuiltinTy.isDynamicDim(0)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, lhs, 0)); + if (resultBuiltinTy.isDynamicDim(1)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, rhs, 1)); + + // ---------------------------------------------------------------- + // 4. Create a zero-initialised init tensor for DPS output. 
+ // ---------------------------------------------------------------- + Value init = createZeroTensor(rewriter, loc, resultBuiltinTy, dynSizes); + + // ---------------------------------------------------------------- + // 5. Emit mips.matmul on builtin tensors. + // ---------------------------------------------------------------- + Value result = + IREE::MIPS::MatmulOp::create(rewriter, loc, TypeRange{resultBuiltinTy}, + lhs, rhs, init) + .getResult(); + + // ---------------------------------------------------------------- + // 6. Cast result back to ValueTensorType so downstream torch passes can + // still operate on it until the type finalisation pass runs. + // ---------------------------------------------------------------- + Value torchResult = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, resultTorchTy, result); + + rewriter.replaceOp(op, torchResult); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +struct ConvertTorchToMIPSPass + : impl::ConvertTorchToMIPSPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index ce0cb9c34f36..82fdf118b9d5 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -63,6 +63,11 @@ void createTorchToIREEPipeline( pm.addNestedPass(torch::createConvertTorchToTensorPass()); pm.addNestedPass( 
TorchInput::createConvertTorchUnstructuredToLinalgExtPass()); + // MIPS: When enabled, intercept aten.mm before the standard torch->linalg + // pass and route it through mips.matmul -> func.call @my_matmul_kernel. + if (options.useMIPSMatmul) { + pm.addNestedPass(TorchInput::createConvertTorchToMIPSPass()); + } pm.addNestedPass(torch::createConvertTorchToLinalgPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(torch::createConvertTorchToSCFPass()); diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.h b/compiler/plugins/input/Torch/InputConversion/Passes.h index 23995b21a943..8eb22c3a6034 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.h +++ b/compiler/plugins/input/Torch/InputConversion/Passes.h @@ -43,6 +43,12 @@ struct TorchToIREELoweringPipelineOptions "program inputs. This buffer will be used for storing transient " "memory and must be provided by the user."), llvm::cl::init(false)}; + Option useMIPSMatmul{ + *this, "use-mips-matmul", + llvm::cl::desc("If enabled, lowers torch.aten.mm through the MIPS " + "custom dialect (mips.matmul) instead of the standard " + "torch->linalg path."), + llvm::cl::init(false)}; }; // Creates a pipeline that lowers from the torch backend contract to IREE. diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td index a868d4bb8354..fb55ba7d189f 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/InputConversion/Passes.td @@ -29,6 +29,16 @@ def ConvertTorchUnstructuredToLinalgExtPass : let summary = "Convert unstructured Torch ops to LinalgExt ops"; } +def ConvertTorchToMIPSPass : + InterfacePass<"torch-iree-to-mips-matmul", "mlir::FunctionOpInterface"> { + let summary = "Convert torch.aten.mm to mips.matmul"; + let description = [{ + Intercepts torch.aten.mm before the standard torch->linalg conversion and + replaces it with mips.matmul. 
The mips.matmul op is later lowered to a + func.call @my_matmul_kernel by LowerMIPSToFuncCallPass after bufferization. + }]; +} + def SetStrictSymbolicShapesPass : InterfacePass<"torch-iree-set-strict-symbolic-shapes", "mlir::FunctionOpInterface"> { let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR"; diff --git a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt index b1785a708878..3d451b68f13a 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt @@ -9,6 +9,7 @@ iree_lit_test_suite( "attention.mlir" "bind_symbolic_shapes.mlir" "bitcast_tensor.mlir" + "convert_torch_to_mips.mlir" "func_conversion.mlir" "func_conversion_invalid.mlir" "func_conversion_transients.mlir" diff --git a/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir b/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir new file mode 100644 index 000000000000..98d5407fefc5 --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir @@ -0,0 +1,37 @@ +// RUN: iree-opt --split-input-file \ +// RUN: --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \ +// RUN: %s | FileCheck %s + +// ───────────────────────────────────────────────────────────────────────────── +// Static-shape: torch.aten.mm on f32 tensors → mips.matmul +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK-LABEL: func.func @mm_static +// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<4x8xf32> +// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<8x4xf32> +// CHECK: mips.matmul {{.*}} : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32> +// CHECK-NOT: torch.aten.mm +func.func @mm_static(%A: !torch.vtensor<[4,8],f32>, + %B: !torch.vtensor<[8,4],f32>) + -> 
!torch.vtensor<[4,4],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,4],f32> + -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} + +// ───────────────────────────────────────────────────────────────────────────── +// Non-f32 (i32) should be left untouched (pattern rejects non-f32 dtypes). +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK-LABEL: func.func @mm_i32_unchanged +// CHECK-NOT: mips.matmul +// CHECK: torch.aten.mm +func.func @mm_i32_unchanged(%A: !torch.vtensor<[4,8],si32>, + %B: !torch.vtensor<[8,4],si32>) + -> !torch.vtensor<[4,4],si32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,8],si32>, !torch.vtensor<[8,4],si32> + -> !torch.vtensor<[4,4],si32> + return %0 : !torch.vtensor<[4,4],si32> +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index ebd8d7d93755..173e4cdc06db 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -166,6 +166,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Dialect::MIPS::Transforms iree::compiler::Dialect::TensorExt::IR iree::compiler::Dialect::Util::IR iree::compiler::Dialect::Util::Transforms diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 771839bd3400..c0d9ed9981e4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "iree/compiler/Codegen/Utils/CodegenOptions.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" #include 
"iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "llvm/ADT/TypeSwitch.h" @@ -518,6 +519,10 @@ static void addLowerToLLVMPasses(OpPassManager &modulePassManager, FunctionLikeNest(modulePassManager) .addPass(createEraseHALDescriptorTypeFromMemRefPass); + // mips.matmul is eliminated during One-Shot Bufferize (func.call emitted + // directly by MIPSBufferizableOpInterface). This pass is now a no-op. + modulePassManager.addPass(IREE::MIPS::createLowerMIPSToFuncCallPass()); + // Lower `ukernel.*` ops to function calls modulePassManager.addPass(createLowerUKernelOpsToCallsPass()); From 2958952bf1787aabf339c8de5b635c0749ff62a3 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:11:49 +0530 Subject: [PATCH 03/12] [MIPS] Add my_matmul_kernel IREE executable plugin library Provides the runtime kernel that backs mips.matmul dispatches. The kernel is packaged as an IREE executable plugin (implements iree_hal_executable_plugin_query) rather than a plain shared library because IREE's LLVMCPU system-dylib dispatch format resolves external function references through an internal import table (not standard ELF dynamic linking). The plugin's resolve() callback maps the symbol name "my_matmul_kernel" to the import-ABI wrapper. Usage: iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... 
--- runtime/src/iree/builtins/mips/CMakeLists.txt | 31 ++++ .../src/iree/builtins/mips/my_matmul_kernel.c | 141 ++++++++++++++++++ .../src/iree/builtins/mips/my_matmul_kernel.h | 40 +++++ 3 files changed, 212 insertions(+) create mode 100644 runtime/src/iree/builtins/mips/CMakeLists.txt create mode 100644 runtime/src/iree/builtins/mips/my_matmul_kernel.c create mode 100644 runtime/src/iree/builtins/mips/my_matmul_kernel.h diff --git a/runtime/src/iree/builtins/mips/CMakeLists.txt b/runtime/src/iree/builtins/mips/CMakeLists.txt new file mode 100644 index 000000000000..19574a5148f7 --- /dev/null +++ b/runtime/src/iree/builtins/mips/CMakeLists.txt @@ -0,0 +1,31 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Shared library containing the MIPS custom matmul kernel. +# +# The kernel is exposed as an IREE executable plugin (implements +# iree_hal_executable_plugin_query) so that iree-run-module can load it via: +# iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... +# +# The IREE runtime resolves the @my_matmul_kernel import from the dispatch +# executable through this plugin's resolve() function. 
+ +add_library(my_matmul_kernel SHARED + my_matmul_kernel.c +) + +target_include_directories(my_matmul_kernel + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + PRIVATE + # For iree/hal/local/executable_plugin.h (standalone C header, no deps) + ${PROJECT_SOURCE_DIR}/runtime/src +) + +set_target_properties(my_matmul_kernel PROPERTIES + C_VISIBILITY_PRESET default + POSITION_INDEPENDENT_CODE ON +) diff --git a/runtime/src/iree/builtins/mips/my_matmul_kernel.c b/runtime/src/iree/builtins/mips/my_matmul_kernel.c new file mode 100644 index 000000000000..e952e72b22e0 --- /dev/null +++ b/runtime/src/iree/builtins/mips/my_matmul_kernel.c @@ -0,0 +1,141 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Naive triple-loop matmul kernel exposed as an IREE executable plugin. +// +// The IREE-compiled dispatch calls @my_matmul_kernel through the IREE import +// mechanism (not direct dynamic linking). At runtime the plugin is registered +// via: +// iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... +// +// IREE import calling convention (for all imports): +// int fn(void* params_ptr, void* context, void* reserved) +// +// The params_ptr points to a packed struct whose fields mirror the +// func.call arguments emitted by LowerMIPSToFuncCall, in order: +// float* A, int64_t A_off, A_s0, A_s1, +// float* B, int64_t B_off, B_s0, B_s1, +// float* C, int64_t C_off, C_s0, C_s1, +// int64_t M, int64_t N, int64_t K + +#include +#include + +// IREE standalone plugin header — only requires C99 standard headers. 
+#include "iree/hal/local/executable_plugin.h" + +//===----------------------------------------------------------------------===// +// Kernel implementation +//===----------------------------------------------------------------------===// + +// Packed argument struct mirroring the func.call arguments from +// LowerMIPSToFuncCall.cpp. +typedef struct { + float *A; + int64_t A_off, A_s0, A_s1; + float *B; + int64_t B_off, B_s0, B_s1; + float *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} my_matmul_kernel_args_t; + +// Import thunk wrapper — called by the IREE runtime with the packed args. +static int my_matmul_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const my_matmul_kernel_args_t *a = (const my_matmul_kernel_args_t *)params_ptr; + + float *A = a->A + a->A_off; + float *B = a->B + a->B_off; + float *C = a->C + a->C_off; + int64_t M = a->M, N = a->N, K = a->K; + int64_t A_s0 = a->A_s0, A_s1 = a->A_s1; + int64_t B_s0 = a->B_s0, B_s1 = a->B_s1; + int64_t C_s0 = a->C_s0, C_s1 = a->C_s1; + + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (int64_t k = 0; k < K; ++k) + acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } + } + return 0; +} + +//===----------------------------------------------------------------------===// +// IREE Executable Plugin interface +//===----------------------------------------------------------------------===// + +static iree_hal_executable_plugin_status_t plugin_load( + const iree_hal_executable_plugin_environment_v0_t *environment, + size_t param_count, + const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { + (void)environment; + (void)param_count; + (void)params; + *out_self = NULL; // stateless plugin + return iree_hal_executable_plugin_ok_status(); +} + +static void plugin_unload(void *self) { (void)self; } + +static iree_hal_executable_plugin_status_t 
plugin_resolve( + void *self, const iree_hal_executable_plugin_resolve_params_v0_t *params, + iree_hal_executable_plugin_resolution_t *out_resolution) { + (void)self; + *out_resolution = 0; + bool any_required_not_found = false; + + for (size_t i = 0; i < params->count; ++i) { + if (params->out_fn_ptrs[i]) continue; // already resolved + const char *name = params->symbol_names[i]; + bool optional = iree_hal_executable_plugin_import_is_optional(name); + if (optional) ++name; // skip the leading '?' + + if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { + params->out_fn_ptrs[i] = my_matmul_kernel_import; + params->out_fn_contexts[i] = NULL; + } else { + if (optional) { + *out_resolution |= + IREE_HAL_EXECUTABLE_PLUGIN_RESOLUTION_MISSING_OPTIONAL; + } else { + any_required_not_found = true; + } + } + } + + return any_required_not_found + ? iree_hal_executable_plugin_status_from_code( + IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND) + : iree_hal_executable_plugin_ok_status(); +} + +// Exported entry point queried by the IREE runtime (via dlsym). +IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t ** +iree_hal_executable_plugin_query( + iree_hal_executable_plugin_version_t max_version, void *reserved) { + static const iree_hal_executable_plugin_header_t header = { + .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, + .name = "mips_matmul", + .description = "MIPS custom matmul kernel plugin", + .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, + .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, + }; + static const iree_hal_executable_plugin_v0_t plugin = { + .header = &header, + .load = plugin_load, + .unload = plugin_unload, + .resolve = plugin_resolve, + }; + return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST + ? 
(const iree_hal_executable_plugin_header_t **)&plugin + : NULL; +} diff --git a/runtime/src/iree/builtins/mips/my_matmul_kernel.h b/runtime/src/iree/builtins/mips/my_matmul_kernel.h new file mode 100644 index 000000000000..8f8319d6dad9 --- /dev/null +++ b/runtime/src/iree/builtins/mips/my_matmul_kernel.h @@ -0,0 +1,40 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// C ABI for the MIPS custom matmul kernel. +// +// The kernel computes: +// C[m, n] = sum_k A[m, k] * B[k, n] +// +// Each matrix is passed as a base pointer plus explicit strided-layout +// parameters that match what MLIR's memref.extract_strided_metadata produces. +// This matches the calling convention emitted by LowerMIPSToFuncCallPass. + +#ifndef IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_ +#define IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_ + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// void my_matmul_kernel( +// float *A, int64_t A_offset, int64_t A_stride0, int64_t A_stride1, +// float *B, int64_t B_offset, int64_t B_stride0, int64_t B_stride1, +// float *C, int64_t C_offset, int64_t C_stride0, int64_t C_stride1, +// int64_t M, int64_t N, int64_t K); +void my_matmul_kernel(float *A, int64_t A_offset, int64_t A_stride0, + int64_t A_stride1, float *B, int64_t B_offset, + int64_t B_stride0, int64_t B_stride1, float *C, + int64_t C_offset, int64_t C_stride0, int64_t C_stride1, + int64_t M, int64_t N, int64_t K); + +#ifdef __cplusplus +} +#endif + +#endif // IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_ From a459d7854a5b1723b2ceedb7641aeb22a5282f17 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:12:17 +0530 Subject: [PATCH 04/12] [MIPS] Add compile_mips.zsh and run_mips.zsh end-to-end test scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Provides
two convenience scripts for exercising the full torch.aten.mm → mips.matmul → func.call → vmfb → kernel plugin pipeline. compile_mips.zsh: Step 1 — verifies torch.aten.mm is converted to mips.matmul by the ConvertTorchToMIPSPass (iree-opt smoke check). Step 2 — runs the full torch-to-iree pipeline with use-mips-matmul=true, producing the IREE input IR (/tmp/mm_iree.mlir). Step 3 — compiles the IREE input IR to a vmfb (/tmp/mm_mips.vmfb) with --mlir-print-ir-after-all, writing a per-pass IR dump to /tmp/mm_mips_ir_dump.mlir for debugging. run_mips.zsh: Runs iree-run-module with --executable_plugin pointing at the built libmy_matmul_kernel.dylib. Tests A * I = A (matrix multiplied by identity) and prints the result for visual verification. --- compile_mips.zsh | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ run_mips.zsh | 42 ++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100755 compile_mips.zsh create mode 100755 run_mips.zsh diff --git a/compile_mips.zsh b/compile_mips.zsh new file mode 100755 index 000000000000..104fc98332d5 --- /dev/null +++ b/compile_mips.zsh @@ -0,0 +1,62 @@ +#!/bin/zsh +# compile_mips.zsh +# +# Compiles a simple torch.aten.mm through the MIPS custom matmul path: +# torch.aten.mm +# -> mips.matmul (ConvertTorchToMIPSPass) +# -> flow.dispatch(...) 
(IREE dispatch formation) +# -> func.call @my_matmul_kernel (bufferize via BufferizableOpInterface) +# -> LLVM / vmfb (iree-compile LLVMCPU backend) +# +# Output: /tmp/mm_mips.vmfb +# IR dump: /tmp/mm_mips_ir_dump.mlir (--mlir-print-ir-after-all) + +set -e # exit on first error + +BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build +IREE_OPT=$BUILD/tools/iree-opt +IREE_COMPILE=$BUILD/tools/iree-compile + +# ── Input: 4x4 f32 matrix multiply ──────────────────────────────────────────── +cat > /tmp/mm_torch.mlir << 'EOF' +module { + func.func @mm(%A: !torch.vtensor<[4,4],f32>, + %B: !torch.vtensor<[4,4],f32>) + -> !torch.vtensor<[4,4],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> + -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> + } +} +EOF + +# ── Step 1: Verify torch.aten.mm → mips.matmul ──────────────────────────────── +echo "==> Step 1: verifying torch.aten.mm → mips.matmul" +$IREE_OPT \ + --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \ + /tmp/mm_torch.mlir \ + | grep -q "mips.matmul" && echo " [OK] mips.matmul found in IR" + +# ── Step 2: Full torch → IREE input IR (with MIPS path enabled) ─────────────── +echo "==> Step 2: torch → IREE input IR (use-mips-matmul=true)" +$IREE_OPT \ + --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \ + /tmp/mm_torch.mlir \ + -o /tmp/mm_iree.mlir + +# ── Step 3: IREE input IR → vmfb (dispatch + bufferize + LLVM) ─────────────── +IR_DUMP=/tmp/mm_mips_ir_dump.mlir +echo "==> Step 3: IREE input IR → vmfb (IR dump → $IR_DUMP)" +$IREE_COMPILE \ + --iree-hal-target-backends=llvm-cpu \ + --iree-llvmcpu-link-embedded=false \ + --mlir-print-ir-after-all \ + /tmp/mm_iree.mlir \ + -o /tmp/mm_mips.vmfb \ + 2>"$IR_DUMP" + +echo "" +echo "==> Compiled successfully: /tmp/mm_mips.vmfb" +echo " IR dump written to: $IR_DUMP" +echo " Run with: ./run_mips.zsh" diff --git a/run_mips.zsh b/run_mips.zsh new file mode 100755 index 
000000000000..d215f1e8efdd --- /dev/null +++ b/run_mips.zsh @@ -0,0 +1,42 @@ +#!/bin/zsh +# run_mips.zsh +# +# Runs the vmfb produced by compile_mips.zsh. +# +# The vmfb's dispatch executable calls @my_matmul_kernel through the IREE +# import mechanism (not direct dynamic linking). The kernel is provided by +# libmy_matmul_kernel.dylib which implements the IREE executable plugin API +# (exports iree_hal_executable_plugin_query). +# +# Test: A * I = A (multiply by 4x4 identity → expect the same matrix back) + +BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build +KERNEL_LIB=$BUILD/runtime/src/iree/builtins/mips/libmy_matmul_kernel.dylib +IREE_RUN=$BUILD/tools/iree-run-module +VMFB=/tmp/mm_mips.vmfb + +if [[ ! -f $VMFB ]]; then + echo "ERROR: $VMFB not found. Run compile_mips.zsh first." + exit 1 +fi + +if [[ ! -f $KERNEL_LIB ]]; then + echo "ERROR: $KERNEL_LIB not found. Build my_matmul_kernel target first." + exit 1 +fi + +# A = 1..16 (row-major 4x4), B = 4x4 identity +A="4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16" +B="4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" + +echo "==> Running mm(A, I) via MIPS kernel" +echo " Kernel plugin : $KERNEL_LIB" +echo " Expected : A * I = A (rows: [1 2 3 4], [5 6 7 8], ...)" +echo "" + +$IREE_RUN \ + --executable_plugin=$KERNEL_LIB \ + --module=$VMFB \ + --function=mm \ + --input="$A" \ + --input="$B" From c14a0d20c529ae12de268e762fa4b37d727ec320 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 16:19:24 +0530 Subject: [PATCH 05/12] [MIPS] Add iree-build.zsh build configuration script --- iree-build.zsh | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100755 iree-build.zsh diff --git a/iree-build.zsh b/iree-build.zsh new file mode 100755 index 000000000000..8738503ba098 --- /dev/null +++ b/iree-build.zsh @@ -0,0 +1,44 @@ +#!/usr/bin/env zsh + +# Exit if any command fails +set -e + +# Paths (customize as needed) +SRC_DIR="$HOME/MLIR_Work/mips" 
+IREE_SRC_DIR="$SRC_DIR/iree" # Path to your cloned iree +BUILD_DIR="$SRC_DIR/iree-build" + +# Configuration +CMAKE_GENERATOR="Ninja" # Change to "Unix Makefiles" if preferred +BUILD_TYPE="RelWithDebInfo" # "Debug" or "RelWithDebInfo" are alternatives +NUM_JOBS=$(sysctl -n hw.logicalcpu) # Uses all CPU cores + +# Prepare directories +mkdir -p "${BUILD_DIR}" + +# CMake configuration +cmake -S "${IREE_SRC_DIR}" -B "${BUILD_DIR}" \ + -G "${CMAKE_GENERATOR}" \ + -DCMAKE_BUILD_TYPE="${BUILD_TYPE}" \ + -DIREE_ENABLE_ASSERTIONS=ON \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_C_COMPILER=/usr/bin/clang \ + -DCMAKE_CXX_COMPILER=/usr/bin/clang++ \ + -DIREE_ENABLE_SPLIT_DWARF=ON \ + -DIREE_ENABLE_LLD=ON \ + -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ + -DIREE_TARGET_BACKEND_LLVM_CPU=ON \ + -DIREE_HAL_DRIVER_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ + -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ + -DIREE_BUILD_PYTHON_BINDINGS=ON \ + -DPython3_EXECUTABLE="$(which python)" + +# Build +ninja -C "${BUILD_DIR}" -j"${NUM_JOBS}" + +# Or combine all steps using a utility target +#cmake --build ../iree-build --target iree-run-tests + +echo "✅ IREE built and executed tests successfully!" From ff06dd7cc871fb76456ec4f09b83636b70ff03cf Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 9 Mar 2026 03:53:04 -0700 Subject: [PATCH 06/12] [MIPS] Add static linking option via --iree-mips-static-embedding Add a global llvm::cl::opt flag `--iree-mips-static-embedding` to iree-compile. When set, MIPSBufferizableOpInterface stamps the emitted @my_matmul_kernel declaration with {hal.import.static}, causing ConvertToLLVM to emit a direct linker-resolved call instead of a dynamic HAL import table entry. Without the flag the call goes through the HAL import table, allowing runtime resolution via --executable_plugin. 
--- compile_mips.zsh | 3 +- .../input/Torch/InputConversion/Passes.cpp | 2 + .../MIPS/IR/MIPSBufferizableOpInterface.cpp | 20 ++ iree-build.zsh | 44 --- run_mips.zsh | 7 +- .../iree/builtins/mips/rvv_matmul_kernel.c | 242 +++++++++++++++++ .../iree/builtins/mips/rvv_standalone_test.c | 252 ++++++++++++++++++ 7 files changed, 523 insertions(+), 47 deletions(-) delete mode 100755 iree-build.zsh create mode 100644 runtime/src/iree/builtins/mips/rvv_matmul_kernel.c create mode 100644 runtime/src/iree/builtins/mips/rvv_standalone_test.c diff --git a/compile_mips.zsh b/compile_mips.zsh index 104fc98332d5..c92271ce30f5 100755 --- a/compile_mips.zsh +++ b/compile_mips.zsh @@ -13,7 +13,8 @@ set -e # exit on first error -BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build +export LD_LIBRARY_PATH="$HOME/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" +BUILD=$HOME/MLIR_Work/mips/iree-build IREE_OPT=$BUILD/tools/iree-opt IREE_COMPILE=$BUILD/tools/iree-compile diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index 82fdf118b9d5..888d400db5ea 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -8,6 +8,8 @@ #include "iree/compiler/Dialect/Util/IR/UtilOps.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp index e5499939bcc3..b28f4ea3ce01 100644 --- a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp @@ -32,6 +32,7 @@ #include 
"iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" #include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -44,6 +45,19 @@ using namespace mlir; using namespace mlir::bufferization; +// When true, my_matmul_kernel is emitted as a direct linker-resolved call +// (hal.import.static) instead of a dynamic HAL import table entry. +// Pass --iree-mips-static-embedding to iree-compile to enable. +// Mutually exclusive with --executable_plugin at runtime. +static llvm::cl::opt<bool> clMIPSStaticEmbedding( + "iree-mips-static-embedding", + llvm::cl::desc( + "Emit my_matmul_kernel as a direct linker-resolved call " + "(hal.import.static) instead of a dynamic HAL import. " + "Requires the kernel .o to be appended by lld_wrapper at compile " + "time. Mutually exclusive with --executable_plugin at runtime."), + llvm::cl::init(false)); + namespace mlir::iree_compiler::IREE::MIPS { namespace { @@ -69,6 +83,12 @@ static func::FuncOp ensureKernelDeclaration(RewriterBase &rewriter, auto fnDecl = func::FuncOp::create(rewriter, loc, kKernelName, fnType); SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private); fnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true)); + // If --iree-mips-static-embedding was passed to iree-compile, emit a direct + // linker call instead of a dynamic HAL import table entry. + // Without this flag the call goes through the HAL import table, which lets + // the runtime resolve it from an --executable_plugin .so at run time.
+ if (clMIPSStaticEmbedding) + fnDecl->setAttr("hal.import.static", rewriter.getUnitAttr()); return fnDecl; } diff --git a/iree-build.zsh b/iree-build.zsh deleted file mode 100755 index 8738503ba098..000000000000 --- a/iree-build.zsh +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env zsh - -# Exit if any command fails -set -e - -# Paths (customize as needed) -SRC_DIR="$HOME/MLIR_Work/mips" -IREE_SRC_DIR="$SRC_DIR/iree" # Path to your cloned iree -BUILD_DIR="$SRC_DIR/iree-build" - -# Configuration -CMAKE_GENERATOR="Ninja" # Change to "Unix Makefiles" if preferred -BUILD_TYPE="RelWithDebInfo" # "Debug" or "RelWithDebInfo" are alternatives -NUM_JOBS=$(sysctl -n hw.logicalcpu) # Uses all CPU cores - -# Prepare directories -mkdir -p "${BUILD_DIR}" - -# CMake configuration -cmake -S "${IREE_SRC_DIR}" -B "${BUILD_DIR}" \ - -G "${CMAKE_GENERATOR}" \ - -DCMAKE_BUILD_TYPE="${BUILD_TYPE}" \ - -DIREE_ENABLE_ASSERTIONS=ON \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_C_COMPILER=/usr/bin/clang \ - -DCMAKE_CXX_COMPILER=/usr/bin/clang++ \ - -DIREE_ENABLE_SPLIT_DWARF=ON \ - -DIREE_ENABLE_LLD=ON \ - -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ - -DIREE_TARGET_BACKEND_LLVM_CPU=ON \ - -DIREE_HAL_DRIVER_DEFAULTS=OFF \ - -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ - -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ - -DIREE_BUILD_PYTHON_BINDINGS=ON \ - -DPython3_EXECUTABLE="$(which python)" - -# Build -ninja -C "${BUILD_DIR}" -j"${NUM_JOBS}" - -# Or combine all steps using a utility target -#cmake --build ../iree-build --target iree-run-tests - -echo "✅ IREE built and executed tests successfully!" 
diff --git a/run_mips.zsh b/run_mips.zsh index d215f1e8efdd..3bee80deb071 100755 --- a/run_mips.zsh +++ b/run_mips.zsh @@ -10,9 +10,12 @@ # # Test: A * I = A (multiply by 4x4 identity → expect the same matrix back) -BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build -KERNEL_LIB=$BUILD/runtime/src/iree/builtins/mips/libmy_matmul_kernel.dylib +BUILD=$HOME/MLIR_Work/mips/iree-build +KERNEL_LIB=$BUILD/runtime/src/iree/builtins/mips/libmy_matmul_kernel.so IREE_RUN=$BUILD/tools/iree-run-module + +# conda libstdc++ must be visible when iree-run-module dlopen()s the .so +export LD_LIBRARY_PATH="$HOME/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" VMFB=/tmp/mm_mips.vmfb if [[ ! -f $VMFB ]]; then diff --git a/runtime/src/iree/builtins/mips/rvv_matmul_kernel.c b/runtime/src/iree/builtins/mips/rvv_matmul_kernel.c new file mode 100644 index 000000000000..149073d416b5 --- /dev/null +++ b/runtime/src/iree/builtins/mips/rvv_matmul_kernel.c @@ -0,0 +1,242 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// RVV (RISC-V Vector 1.0) matmul kernel exposed as an IREE executable plugin. +// +// ABI contract (must match MIPSBufferizableOpInterface.cpp / decomposeMemref2D): +// Each 2-D memref is decomposed into (base_ptr, offset, stride0, stride1). +// {llvm.bareptr = true} means memref → float* in C. +// index → int64_t on RV64. 
+// +// C signature (15 args total): +// void my_matmul_kernel( +// float* A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] +// float* B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] +// float* C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] +// int64_t M, int64_t N, int64_t K +// ); +// +// Vectorization strategy (RVV LMUL=m4): +// - Outer loops: m (rows of A) and k (contraction axis) +// - Inner loop : n (columns of B), vectorized with vsetvl_e32m4 +// - Each vl-wide chunk of C[m, n:n+vl] is kept in a vfloat32m4_t accumulator +// and accumulated across the full k-dimension before storing. +// - Handles arbitrary strides via conditional vlse/vsse. + +#include +#include + +#include "iree/hal/local/executable_plugin.h" + +#ifdef __riscv_vector +#include +#endif + +//===----------------------------------------------------------------------===// +// Core RVV kernel +//===----------------------------------------------------------------------===// + +// The kernel is separated from the plugin wrapper so it can be unit-tested +// with a plain C harness (rvv_standalone_test.c) without IREE headers. +// +// For non-RV targets (e.g. when the host compiler processes this file) the +// function falls back to a scalar triple loop so the plugin still links. + +#ifdef __riscv_vector + +// RVV implementation — compiled only when targeting RV64GCV. +static void rvv_matmul_core( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + // Apply offsets once — identical to the scalar implementation. + A += A_off; + B += B_off; + C += C_off; + + // Detect unit-stride fast paths at runtime to choose vle/vlse. 
+ const int a_unit_col = (A_s1 == 1); + const int b_unit_col = (B_s1 == 1); + const int c_unit_col = (C_s1 == 1); + + for (int64_t m = 0; m < M; ++m) { + int64_t n = 0; + while (n < N) { + // Set vector length for this strip of N columns. + size_t vl = __riscv_vsetvl_e32m4((size_t)(N - n)); + + // Zero accumulator — one vfloat32m4_t per output strip C[m, n:n+vl]. + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t k = 0; k < K; ++k) { + // Scalar load: A[m, k] (row-major access — s0=N, s1=1 for row-major). + float a_val; + if (a_unit_col) { + a_val = A[m * A_s0 + k]; + } else { + a_val = A[m * A_s0 + k * A_s1]; + } + + // Vector load: B[k, n:n+vl]. + vfloat32m4_t b_vec; + if (b_unit_col) { + b_vec = __riscv_vle32_v_f32m4(&B[k * B_s0 + n], vl); + } else { + b_vec = __riscv_vlse32_v_f32m4( + &B[k * B_s0 + n * B_s1], + (ptrdiff_t)(B_s1 * (int64_t)sizeof(float)), vl); + } + + // acc += a_val * b_vec + acc = __riscv_vfmacc_vf_f32m4(acc, a_val, b_vec, vl); + } + + // Store accumulator to C[m, n:n+vl]. + if (c_unit_col) { + __riscv_vse32_v_f32m4(&C[m * C_s0 + n], acc, vl); + } else { + __riscv_vsse32_v_f32m4( + &C[m * C_s0 + n * C_s1], + (ptrdiff_t)(C_s1 * (int64_t)sizeof(float)), acc, vl); + } + n += (int64_t)vl; + } + } +} + +#else // !__riscv_vector + +// Scalar fallback — used when this file is compiled for the host (x86/arm) +// during CI or initial bring-up. 
+static void rvv_matmul_core( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; + B += B_off; + C += C_off; + for (int64_t m = 0; m < M; ++m) + for (int64_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (int64_t k = 0; k < K; ++k) + acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } +} + +#endif // __riscv_vector + +//===----------------------------------------------------------------------===// +// Direct-call entry point for static embedding +//===----------------------------------------------------------------------===// +// When the dispatch ELF is built with --iree-llvmcpu-link-embedded=true and +// the function declaration carries {hal.import.static}, LLVMCPU codegen emits +// a direct call to this symbol instead of going through the HAL import table. +// The lld wrapper appends this .o to the linker invocation so the symbol is +// resolved at link time. + +void my_matmul_kernel( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + rvv_matmul_core(A, A_off, A_s0, A_s1, + B, B_off, B_s0, B_s1, + C, C_off, C_s0, C_s1, + M, N, K); +} + +//===----------------------------------------------------------------------===// +// IREE Executable Plugin interface +//===----------------------------------------------------------------------===// + +// Packed argument struct — mirrors the func.call argument list emitted by +// MIPSBufferizableOpInterface::bufferize() → decomposeMemref2D(). 
+typedef struct { + float *A; + int64_t A_off, A_s0, A_s1; + float *B; + int64_t B_off, B_s0, B_s1; + float *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} rvv_matmul_kernel_args_t; + +static int rvv_matmul_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const rvv_matmul_kernel_args_t *a = + (const rvv_matmul_kernel_args_t *)params_ptr; + rvv_matmul_core(a->A, a->A_off, a->A_s0, a->A_s1, + a->B, a->B_off, a->B_s0, a->B_s1, + a->C, a->C_off, a->C_s0, a->C_s1, + a->M, a->N, a->K); + return 0; +} + +static iree_hal_executable_plugin_status_t plugin_load( + const iree_hal_executable_plugin_environment_v0_t *environment, + size_t param_count, + const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { + (void)environment; + (void)param_count; + (void)params; + *out_self = NULL; + return iree_hal_executable_plugin_ok_status(); +} + +static void plugin_unload(void *self) { (void)self; } + +static iree_hal_executable_plugin_status_t plugin_resolve( + void *self, const iree_hal_executable_plugin_resolve_params_v0_t *params, + iree_hal_executable_plugin_resolution_t *out_resolution) { + (void)self; + *out_resolution = 0; + bool any_required_not_found = false; + + for (size_t i = 0; i < params->count; ++i) { + if (params->out_fn_ptrs[i]) continue; + const char *name = params->symbol_names[i]; + bool optional = iree_hal_executable_plugin_import_is_optional(name); + if (optional) ++name; + + if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { + params->out_fn_ptrs[i] = rvv_matmul_kernel_import; + params->out_fn_contexts[i] = NULL; + } else { + if (!optional) any_required_not_found = true; + } + } + + return any_required_not_found + ? 
iree_hal_executable_plugin_status_from_code( + IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND) + : iree_hal_executable_plugin_ok_status(); +} + +IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t ** +iree_hal_executable_plugin_query( + iree_hal_executable_plugin_version_t max_version, void *reserved) { + static const iree_hal_executable_plugin_header_t header = { + .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, + .name = "rvv_matmul", + .description = "RISC-V RVV 1.0 matmul kernel plugin", + .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, + .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, + }; + static const iree_hal_executable_plugin_v0_t plugin = { + .header = &header, + .load = plugin_load, + .unload = plugin_unload, + .resolve = plugin_resolve, + }; + return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST + ? (const iree_hal_executable_plugin_header_t **)&plugin + : NULL; +} diff --git a/runtime/src/iree/builtins/mips/rvv_standalone_test.c b/runtime/src/iree/builtins/mips/rvv_standalone_test.c new file mode 100644 index 000000000000..6d765f6d5521 --- /dev/null +++ b/runtime/src/iree/builtins/mips/rvv_standalone_test.c @@ -0,0 +1,252 @@ +// Standalone QEMU test for rvv_matmul_core. +// +// Build for QEMU (no libc — RV-only syscall wrappers): +// clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ +// -O2 -static -nostdlib -ffreestanding \ +// rvv_standalone_test.c -o rvv_test +// qemu-riscv64 -cpu rv64,v=true,vlen=128,elen=64 ./rvv_test +// +// Build for host validation (x86, scalar fallback, with libc): +// clang rvv_standalone_test.c -O2 -o rvv_test_host && ./rvv_test_host + +#include +#include + +// ── I/O and exit ────────────────────────────────────────────────────────────── +// RISC-V target: raw ecall (no libc dependency for -nostdlib build). +// Host (x86) target: libc stdio — so the --host build also works. 
+ +#ifdef __riscv + +static long _rv_syscall1(long nr, long a0) { + register long _a7 __asm__("a7") = nr; + register long _a0 __asm__("a0") = a0; + __asm__ volatile("ecall" : "+r"(_a0) : "r"(_a7) : "memory"); + return _a0; +} + +static long _rv_syscall3(long nr, long a0, long a1, long a2) { + register long _a7 __asm__("a7") = nr; + register long _a0 __asm__("a0") = a0; + register long _a1 __asm__("a1") = a1; + register long _a2 __asm__("a2") = a2; + __asm__ volatile("ecall" : "+r"(_a0) : "r"(_a7), "r"(_a1), "r"(_a2) + : "memory"); + return _a0; +} + +static void sys_write(const char *buf, size_t len) { + _rv_syscall3(64 /*SYS_write*/, 1 /*stdout*/, (long)buf, (long)len); +} + +__attribute__((noreturn)) static void sys_exit(int code) { + _rv_syscall1(94 /*SYS_exit_group*/, (long)code); + __builtin_unreachable(); +} + +void _start(void); // forward-declare; entry point at bottom + +#else // host x86 build + +#include +#include + +static void sys_write(const char *buf, size_t len) { + fwrite(buf, 1, len, stdout); +} + +__attribute__((noreturn)) static void sys_exit(int code) { + exit(code); +} + +#endif // __riscv + +// ── Minimal print helpers (no sprintf / printf dependency) ──────────────────── + +static void print(const char *s) { + size_t n = 0; + while (s[n]) ++n; + sys_write(s, n); +} + +static void print_float(float v) { + if (v < 0.0f) { print("-"); v = -v; } + int whole = (int)v; + int frac = (int)((v - (float)whole) * 10000.0f + 0.5f); + char buf[20]; + int i = 19; + buf[i--] = '\0'; + for (int j = 0; j < 4; ++j) { buf[i--] = (char)('0' + frac % 10); frac /= 10; } + buf[i--] = '.'; + if (whole == 0) { buf[i--] = '0'; } + else { while (whole > 0) { buf[i--] = (char)('0' + whole % 10); whole /= 10; } } + print(&buf[i + 1]); +} + +// ── Inline RVV matmul core (no IREE headers required) ───────────────────────── +// Guards on __riscv_vector (set by clang when -march=rv64gcv is active) rather +// than __riscv so the RVV path only compiles when vector intrinsics 
are present. + +#ifdef __riscv_vector +#include + +static void rvv_matmul_core( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; B += B_off; C += C_off; + const int b_unit = (B_s1 == 1); + const int c_unit = (C_s1 == 1); + for (int64_t m = 0; m < M; ++m) { + int64_t n = 0; + while (n < N) { + size_t vl = __riscv_vsetvl_e32m4((size_t)(N - n)); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + for (int64_t k = 0; k < K; ++k) { + float a_val = A[m * A_s0 + k * A_s1]; + vfloat32m4_t b_vec = b_unit + ? __riscv_vle32_v_f32m4(&B[k * B_s0 + n], vl) + : __riscv_vlse32_v_f32m4(&B[k * B_s0 + n * B_s1], + B_s1 * (int64_t)sizeof(float), vl); + acc = __riscv_vfmacc_vf_f32m4(acc, a_val, b_vec, vl); + } + if (c_unit) + __riscv_vse32_v_f32m4(&C[m * C_s0 + n], acc, vl); + else + __riscv_vsse32_v_f32m4(&C[m * C_s0 + n * C_s1], + C_s1 * (int64_t)sizeof(float), acc, vl); + n += (int64_t)vl; + } + } +} + +#else // scalar fallback + +static void rvv_matmul_core( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; B += B_off; C += C_off; + for (int64_t m = 0; m < M; ++m) + for (int64_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (int64_t k = 0; k < K; ++k) + acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } +} + +#endif // __riscv_vector + +// ── Tests ───────────────────────────────────────────────────────────────────── + +static int tests_passed = 0; +static int tests_failed = 0; + +static float _fabsf(float x) { return x < 0.0f ? 
-x : x; } + +static void check(const char *name, float got, float expected) { + if (_fabsf(got - expected) < 1e-4f) { + ++tests_passed; + } else { + ++tests_failed; + print(" FAIL "); print(name); + print(" got="); print_float(got); + print(" expected="); print_float(expected); print("\n"); + } +} + +// Test 1: A * I = A (4×4, row-major) +static void test_identity(void) { + print("[1] A * I = A (4x4 row-major)\n"); + float A[16] = { 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16 }; + float I[16] = { 1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1 }; + float C[16] = {0}; + rvv_matmul_core(A, 0,4,1, I, 0,4,1, C, 0,4,1, 4,4,4); + for (int i = 0; i < 16; ++i) check("A*I", C[i], A[i]); +} + +// Test 2: 2x3 * 3x2 = 2x2 +// [[58,64],[139,154]] +static void test_2x3x2(void) { + print("[2] 2x3 * 3x2 = 2x2\n"); + float A[6] = {1,2,3, 4,5,6}; + float B[6] = {7,8, 9,10, 11,12}; + float C[4] = {0}; + rvv_matmul_core(A, 0,3,1, B, 0,2,1, C, 0,2,1, 2,2,3); + check("C[0,0]", C[0], 58.0f); + check("C[0,1]", C[1], 64.0f); + check("C[1,0]", C[2], 139.0f); + check("C[1,1]", C[3], 154.0f); +} + +// Test 3: column-major strides +static void test_col_major(void) { + print("[3] col-major strides\n"); + // A[2×3] col-major: stored [1,4, 2,5, 3,6], s0=1, s1=2 + float A[6] = {1,4, 2,5, 3,6}; + // B[3×2] col-major: stored [7,9,11, 8,10,12], s0=1, s1=3 + float B[6] = {7,9,11, 8,10,12}; + float C[4] = {0}; + rvv_matmul_core(A, 0,1,2, B, 0,1,3, C, 0,1,2, 2,2,3); + // C[m,n] = C[m + n*2] + check("C[0,0]", C[0], 58.0f); + check("C[1,0]", C[1], 139.0f); + check("C[0,1]", C[2], 64.0f); + check("C[1,1]", C[3], 154.0f); +} + +// Test 4: non-zero base offset +static void test_offset(void) { + print("[4] non-zero base offset\n"); + float A[8] = {99,99,99,99, 1,0, 0,1}; + float B[8] = {99,99,99,99, 3,0, 0,5}; + float C[8] = {0}; + rvv_matmul_core(A, 4,2,1, B, 4,2,1, C, 4,2,1, 2,2,2); + check("C[0,0]", C[4], 3.0f); + check("C[0,1]", C[5], 0.0f); + check("C[1,0]", C[6], 0.0f); + check("C[1,1]", C[7], 5.0f); +} + +// ── 
Entry point ─────────────────────────────────────────────────────────────── + +// On RISC-V nostdlib builds the linker expects _start. +// On host builds we emit main() so it links normally. + +#ifdef __riscv +void _start(void) { +#else +int main(void) { +#endif + print("=== rvv_matmul standalone test"); +#ifdef __riscv_vector + print(" [RVV]\n"); +#else + print(" [scalar]\n"); +#endif + + test_identity(); + test_2x3x2(); + test_col_major(); + test_offset(); + + print("\n"); + print(tests_failed == 0 ? "PASSED" : "FAILED"); + print(" ("); + char b[4] = {'0' + (char)tests_passed, ' ', '\0', '\0'}; + b[1] = '\0'; print(b); + print(" passed, "); + b[0] = '0' + (char)tests_failed; + print(b); + print(" failed)\n"); + +#ifdef __riscv + sys_exit(tests_failed == 0 ? 0 : 1); +#else + sys_exit(tests_failed == 0 ? 0 : 1); + return 0; +#endif +} From ee4e75dd51f56a54ae1054890c9095aef0a9ffa4 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 9 Mar 2026 13:17:56 -0700 Subject: [PATCH 07/12] [MIPS] Restructure kernel library and update dialect description MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit matmul_kernel.h — public API declaration (my_matmul_kernel) matmul_kernel.c — RVV + scalar compute; no IREE headers matmul_plugin.c — IREE HAL executable plugin interface only Also update MIPS_Dialect description in MIPSBase.td to reflect the semantic role of the dialect as an abstraction layer for dispatching to highly-optimized kernels, with op fusion as a key use case. 
Co-Authored-By: Claude Sonnet 4.6 --- .../iree/compiler/Dialect/MIPS/IR/MIPSBase.td | 22 +- .../compiler/Dialect/MIPS/IR/MIPSDialect.cpp | 1 - runtime/src/iree/builtins/mips/CMakeLists.txt | 18 +- .../src/iree/builtins/mips/matmul_kernel.c | 110 ++++++++ .../src/iree/builtins/mips/matmul_kernel.h | 44 ++++ .../src/iree/builtins/mips/matmul_plugin.c | 118 +++++++++ .../src/iree/builtins/mips/my_matmul_kernel.c | 141 ---------- .../src/iree/builtins/mips/my_matmul_kernel.h | 40 --- .../iree/builtins/mips/rvv_matmul_kernel.c | 242 ------------------ .../iree/builtins/mips/rvv_standalone_test.c | 113 ++------ 10 files changed, 324 insertions(+), 525 deletions(-) create mode 100644 runtime/src/iree/builtins/mips/matmul_kernel.c create mode 100644 runtime/src/iree/builtins/mips/matmul_kernel.h create mode 100644 runtime/src/iree/builtins/mips/matmul_plugin.c delete mode 100644 runtime/src/iree/builtins/mips/my_matmul_kernel.c delete mode 100644 runtime/src/iree/builtins/mips/my_matmul_kernel.h delete mode 100644 runtime/src/iree/builtins/mips/rvv_matmul_kernel.c diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td index fd006c3bfcd7..470c2d624549 100644 --- a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td @@ -16,11 +16,24 @@ include "mlir/IR/OpBase.td" def MIPS_Dialect : Dialect { let name = "mips"; let cppNamespace = "::mlir::iree_compiler::IREE::MIPS"; - let summary = "MIPS custom compute dialect for experimental matmul pipeline."; + let summary = "Semantic dialect layer for dispatching to highly-optimized MIPS kernels."; let description = [{ - The MIPS dialect defines a single `mips.matmul` operation that serves as - a custom intermediate representation between the Torch frontend and the - final `func.call @my_matmul_kernel` generated by the MIPS lowering pass. 
+ The MIPS dialect provides a semantic abstraction layer that bridges + high-level tensor operations and hand-tuned, target-specific kernel + implementations (e.g. RVV-vectorized routines). + + Ops in this dialect live entirely in the tensor domain — no memref forms + are ever produced. They are eliminated during One-Shot Bufferize: + the BufferizableOpInterface implementation allocates output buffers and + emits direct `func.call` instructions to the underlying C kernels. + + The semantic level allows the compiler to apply higher-level + transformations before lowering, such as: + - Fusing producer ops (e.g. transposes, type promotions) into a single + kernel call to avoid intermediate allocations. + - Tiling and packing decisions based on target vector width. + - Selecting between kernel variants (e.g. RVV LMUL=m4 vs. m8) based + on operand shapes or hardware configuration. Pipeline: torch.aten.mm → mips.matmul → func.call @my_matmul_kernel @@ -28,7 +41,6 @@ def MIPS_Dialect : Dialect { let dependentDialects = [ "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect", "::mlir::tensor::TensorDialect" ]; diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp index f136f11594fa..1f46719bd7e5 100644 --- a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp @@ -8,7 +8,6 @@ #include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" diff --git a/runtime/src/iree/builtins/mips/CMakeLists.txt b/runtime/src/iree/builtins/mips/CMakeLists.txt index 19574a5148f7..1dc3825b1975 100644 --- a/runtime/src/iree/builtins/mips/CMakeLists.txt +++ b/runtime/src/iree/builtins/mips/CMakeLists.txt @@ -6,18 +6,18 @@ # Shared library 
containing the MIPS custom matmul kernel. # -# The kernel is exposed as an IREE executable plugin (implements -# iree_hal_executable_plugin_query) so that iree-run-module can load it via: -# iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... +# matmul_kernel.c — RVV (or scalar fallback) compute kernel; no IREE headers. +# matmul_plugin.c — IREE HAL executable plugin interface. # -# The IREE runtime resolves the @my_matmul_kernel import from the dispatch -# executable through this plugin's resolve() function. +# The plugin is loaded at runtime via: +# iree-run-module --executable_plugin=libmips_matmul.so ... -add_library(my_matmul_kernel SHARED - my_matmul_kernel.c +add_library(mips_matmul SHARED + matmul_kernel.c + matmul_plugin.c ) -target_include_directories(my_matmul_kernel +target_include_directories(mips_matmul PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} PRIVATE @@ -25,7 +25,7 @@ target_include_directories(my_matmul_kernel ${PROJECT_SOURCE_DIR}/runtime/src ) -set_target_properties(my_matmul_kernel PROPERTIES +set_target_properties(mips_matmul PROPERTIES C_VISIBILITY_PRESET default POSITION_INDEPENDENT_CODE ON ) diff --git a/runtime/src/iree/builtins/mips/matmul_kernel.c b/runtime/src/iree/builtins/mips/matmul_kernel.c new file mode 100644 index 000000000000..02eccb06ccbf --- /dev/null +++ b/runtime/src/iree/builtins/mips/matmul_kernel.c @@ -0,0 +1,110 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// RVV (RISC-V Vector 1.0) f32 matmul kernel. +// +// This file is intentionally free of IREE headers so it can be: +// - Compiled to a .o and baked into the dispatch ELF (static embedding). +// - Linked into the IREE plugin .so alongside matmul_plugin.c. +// - Compiled standalone for unit tests (rvv_standalone_test.c). 
// RVV (RISC-V Vector 1.0) f32 matmul with a portable scalar fallback.
//
// Deliberately free of IREE headers so the same translation unit can be
// baked into a dispatch ELF, linked into the plugin .so, or compiled
// standalone for the QEMU unit test (rvv_standalone_test.c).
//
// Vectorization (RVV, LMUL=m4): m and k are scalar loops; n is vectorized.
// Each vl-wide strip C[m, n:n+vl] is accumulated across the whole k axis
// before a single store. Arbitrary strides are handled by choosing between
// contiguous (vle/vse) and strided (vlse/vsse) accesses at runtime.

#include <stddef.h>
#include <stdint.h>

#ifdef __riscv_vector
#include <riscv_vector.h>
#endif

//===----------------------------------------------------------------------===//
// Internal compute kernel
//===----------------------------------------------------------------------===//

#ifdef __riscv_vector

static void rvv_matmul_core(
    const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1,
    const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1,
    float *C, int64_t C_off, int64_t C_s0, int64_t C_s1,
    int64_t M, int64_t N, int64_t K) {
  // Fold each base offset in exactly once.
  A += A_off;
  B += B_off;
  C += C_off;

  // Unit column stride selects the contiguous load/store fast path.
  const int a_contig = (A_s1 == 1);
  const int b_contig = (B_s1 == 1);
  const int c_contig = (C_s1 == 1);

  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N;) {
      // vl covers the remainder of this row strip.
      size_t vl = __riscv_vsetvl_e32m4((size_t)(N - n));
      vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl);

      for (int64_t k = 0; k < K; ++k) {
        // Scalar A[m, k]; vector strip B[k, n:n+vl].
        const float a_val =
            a_contig ? A[m * A_s0 + k] : A[m * A_s0 + k * A_s1];
        vfloat32m4_t b_vec;
        if (b_contig) {
          b_vec = __riscv_vle32_v_f32m4(&B[k * B_s0 + n], vl);
        } else {
          b_vec = __riscv_vlse32_v_f32m4(
              &B[k * B_s0 + n * B_s1],
              (ptrdiff_t)(B_s1 * (int64_t)sizeof(float)), vl);
        }
        acc = __riscv_vfmacc_vf_f32m4(acc, a_val, b_vec, vl);  // acc += a*b
      }

      if (c_contig) {
        __riscv_vse32_v_f32m4(&C[m * C_s0 + n], acc, vl);
      } else {
        __riscv_vsse32_v_f32m4(&C[m * C_s0 + n * C_s1],
                               (ptrdiff_t)(C_s1 * (int64_t)sizeof(float)),
                               acc, vl);
      }
      n += (int64_t)vl;
    }
  }
}

#else  // !__riscv_vector — scalar fallback for host (x86/arm) builds.

static void rvv_matmul_core(
    const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1,
    const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1,
    float *C, int64_t C_off, int64_t C_s0, int64_t C_s1,
    int64_t M, int64_t N, int64_t K) {
  A += A_off;
  B += B_off;
  C += C_off;
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1];
      }
      C[m * C_s0 + n * C_s1] = acc;
    }
  }
}

#endif  // __riscv_vector

//===----------------------------------------------------------------------===//
// Public entry point
//===----------------------------------------------------------------------===//

// 2-D f32 matmul, destination-passing style: C = A * B.
// Each operand is (base pointer, offset, stride0, stride1); dims are M, N, K.
// ABI must stay in sync with MIPSBufferizableOpInterface.cpp
// (decomposeMemref2D) — see matmul_kernel.h for the full contract.
void my_matmul_kernel(
    const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1,
    const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1,
    float *C, int64_t C_off, int64_t C_s0, int64_t C_s1,
    int64_t M, int64_t N, int64_t K) {
  rvv_matmul_core(A, A_off, A_s0, A_s1,
                  B, B_off, B_s0, B_s1,
                  C, C_off, C_s0, C_s1,
                  M, N, K);
}
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Public API for the MIPS matmul kernel.
//
// ABI contract — must stay in sync with MIPSBufferizableOpInterface.cpp
// (decomposeMemref2D):
//   * Each 2-D memref is decomposed into (base_ptr, offset, stride0, stride1).
//   * {llvm.bareptr = true} maps a memref to a plain float* in C.
//   * `index` maps to int64_t on RV64.
//
// Resulting C signature (15 arguments):
//   void my_matmul_kernel(
//       float* A, int64_t A_off, int64_t A_s0, int64_t A_s1,  // lhs [MxK]
//       float* B, int64_t B_off, int64_t B_s0, int64_t B_s1,  // rhs [KxN]
//       float* C, int64_t C_off, int64_t C_s0, int64_t C_s1,  // out [MxN]
//       int64_t M, int64_t N, int64_t K);

#ifndef IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_
#define IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// 2-D f32 matmul in destination-passing style: C = A * B.
// Accepts arbitrary row/column strides and non-zero base offsets.
// RVV-vectorized when built for RISC-V with the V extension; otherwise a
// scalar fallback is used so the symbol links on any host.
void my_matmul_kernel(
    const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1,
    const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1,
    float *C, int64_t C_off, int64_t C_s0, int64_t C_s1,
    int64_t M, int64_t N, int64_t K);

#ifdef __cplusplus
}
#endif

#endif  // IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_
+// +// This file wires my_matmul_kernel into the IREE HAL plugin ABI so the +// function can be resolved at runtime via --executable_plugin. +// +// Build as a shared library alongside matmul_kernel.c: +// clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ +// -O2 -fPIC -shared -nostdinc ... \ +// matmul_kernel.c matmul_plugin.c -o librvv_matmul.so + +#include "matmul_kernel.h" + +#include "iree/hal/local/executable_plugin.h" + +//===----------------------------------------------------------------------===// +// Import wrapper +//===----------------------------------------------------------------------===// +// The HAL plugin dispatch table expects functions with the signature: +// int fn(void *params_ptr, void *context, void *reserved) +// where params_ptr points to a packed struct matching the func.call ABI +// emitted by MIPSBufferizableOpInterface::bufferize() → decomposeMemref2D(). + +typedef struct { + float *A; + int64_t A_off, A_s0, A_s1; + float *B; + int64_t B_off, B_s0, B_s1; + float *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} matmul_kernel_args_t; + +static int matmul_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const matmul_kernel_args_t *a = (const matmul_kernel_args_t *)params_ptr; + my_matmul_kernel(a->A, a->A_off, a->A_s0, a->A_s1, + a->B, a->B_off, a->B_s0, a->B_s1, + a->C, a->C_off, a->C_s0, a->C_s1, + a->M, a->N, a->K); + return 0; +} + +//===----------------------------------------------------------------------===// +// Plugin lifecycle +//===----------------------------------------------------------------------===// + +static iree_hal_executable_plugin_status_t plugin_load( + const iree_hal_executable_plugin_environment_v0_t *environment, + size_t param_count, + const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { + (void)environment; + (void)param_count; + (void)params; + *out_self = NULL; + return iree_hal_executable_plugin_ok_status(); +} 
+ +static void plugin_unload(void *self) { (void)self; } + +static iree_hal_executable_plugin_status_t plugin_resolve( + void *self, const iree_hal_executable_plugin_resolve_params_v0_t *params, + iree_hal_executable_plugin_resolution_t *out_resolution) { + (void)self; + *out_resolution = 0; + bool any_required_not_found = false; + + for (size_t i = 0; i < params->count; ++i) { + if (params->out_fn_ptrs[i]) continue; + const char *name = params->symbol_names[i]; + bool optional = iree_hal_executable_plugin_import_is_optional(name); + if (optional) ++name; + + if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { + params->out_fn_ptrs[i] = matmul_kernel_import; + params->out_fn_contexts[i] = NULL; + } else { + if (!optional) any_required_not_found = true; + } + } + + return any_required_not_found + ? iree_hal_executable_plugin_status_from_code( + IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND) + : iree_hal_executable_plugin_ok_status(); +} + +//===----------------------------------------------------------------------===// +// Plugin query entry point +//===----------------------------------------------------------------------===// + +IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t ** +iree_hal_executable_plugin_query( + iree_hal_executable_plugin_version_t max_version, void *reserved) { + static const iree_hal_executable_plugin_header_t header = { + .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, + .name = "mips_matmul", + .description = "RISC-V RVV 1.0 matmul kernel plugin", + .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, + .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, + }; + static const iree_hal_executable_plugin_v0_t plugin = { + .header = &header, + .load = plugin_load, + .unload = plugin_unload, + .resolve = plugin_resolve, + }; + return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST + ? 
(const iree_hal_executable_plugin_header_t **)&plugin + : NULL; +} diff --git a/runtime/src/iree/builtins/mips/my_matmul_kernel.c b/runtime/src/iree/builtins/mips/my_matmul_kernel.c deleted file mode 100644 index e952e72b22e0..000000000000 --- a/runtime/src/iree/builtins/mips/my_matmul_kernel.c +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// Naive triple-loop matmul kernel exposed as an IREE executable plugin. -// -// The IREE-compiled dispatch calls @my_matmul_kernel through the IREE import -// mechanism (not direct dynamic linking). At runtime the plugin is registered -// via: -// iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... -// -// IREE import calling convention (for all imports): -// int fn(void* params_ptr, void* context, void* reserved) -// -// The params_ptr points to a packed struct whose fields mirror the -// func.call arguments emitted by LowerMIPSToFuncCall, in order: -// float* A, int64_t A_off, A_s0, A_s1, -// float* B, int64_t B_off, B_s0, B_s1, -// float* C, int64_t C_off, C_s0, C_s1, -// int64_t M, int64_t N, int64_t K - -#include -#include - -// IREE standalone plugin header — only requires C99 standard headers. -#include "iree/hal/local/executable_plugin.h" - -//===----------------------------------------------------------------------===// -// Kernel implementation -//===----------------------------------------------------------------------===// - -// Packed argument struct mirroring the func.call arguments from -// LowerMIPSToFuncCall.cpp. -typedef struct { - float *A; - int64_t A_off, A_s0, A_s1; - float *B; - int64_t B_off, B_s0, B_s1; - float *C; - int64_t C_off, C_s0, C_s1; - int64_t M, N, K; -} my_matmul_kernel_args_t; - -// Import thunk wrapper — called by the IREE runtime with the packed args. 
-static int my_matmul_kernel_import(void *params_ptr, void *context, - void *reserved) { - (void)context; - (void)reserved; - const my_matmul_kernel_args_t *a = (const my_matmul_kernel_args_t *)params_ptr; - - float *A = a->A + a->A_off; - float *B = a->B + a->B_off; - float *C = a->C + a->C_off; - int64_t M = a->M, N = a->N, K = a->K; - int64_t A_s0 = a->A_s0, A_s1 = a->A_s1; - int64_t B_s0 = a->B_s0, B_s1 = a->B_s1; - int64_t C_s0 = a->C_s0, C_s1 = a->C_s1; - - for (int64_t m = 0; m < M; ++m) { - for (int64_t n = 0; n < N; ++n) { - float acc = 0.0f; - for (int64_t k = 0; k < K; ++k) - acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; - C[m * C_s0 + n * C_s1] = acc; - } - } - return 0; -} - -//===----------------------------------------------------------------------===// -// IREE Executable Plugin interface -//===----------------------------------------------------------------------===// - -static iree_hal_executable_plugin_status_t plugin_load( - const iree_hal_executable_plugin_environment_v0_t *environment, - size_t param_count, - const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { - (void)environment; - (void)param_count; - (void)params; - *out_self = NULL; // stateless plugin - return iree_hal_executable_plugin_ok_status(); -} - -static void plugin_unload(void *self) { (void)self; } - -static iree_hal_executable_plugin_status_t plugin_resolve( - void *self, const iree_hal_executable_plugin_resolve_params_v0_t *params, - iree_hal_executable_plugin_resolution_t *out_resolution) { - (void)self; - *out_resolution = 0; - bool any_required_not_found = false; - - for (size_t i = 0; i < params->count; ++i) { - if (params->out_fn_ptrs[i]) continue; // already resolved - const char *name = params->symbol_names[i]; - bool optional = iree_hal_executable_plugin_import_is_optional(name); - if (optional) ++name; // skip the leading '?' 
- - if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { - params->out_fn_ptrs[i] = my_matmul_kernel_import; - params->out_fn_contexts[i] = NULL; - } else { - if (optional) { - *out_resolution |= - IREE_HAL_EXECUTABLE_PLUGIN_RESOLUTION_MISSING_OPTIONAL; - } else { - any_required_not_found = true; - } - } - } - - return any_required_not_found - ? iree_hal_executable_plugin_status_from_code( - IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND) - : iree_hal_executable_plugin_ok_status(); -} - -// Exported entry point queried by the IREE runtime (via dlsym). -IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t ** -iree_hal_executable_plugin_query( - iree_hal_executable_plugin_version_t max_version, void *reserved) { - static const iree_hal_executable_plugin_header_t header = { - .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, - .name = "mips_matmul", - .description = "MIPS custom matmul kernel plugin", - .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, - .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, - }; - static const iree_hal_executable_plugin_v0_t plugin = { - .header = &header, - .load = plugin_load, - .unload = plugin_unload, - .resolve = plugin_resolve, - }; - return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST - ? (const iree_hal_executable_plugin_header_t **)&plugin - : NULL; -} diff --git a/runtime/src/iree/builtins/mips/my_matmul_kernel.h b/runtime/src/iree/builtins/mips/my_matmul_kernel.h deleted file mode 100644 index 8f8319d6dad9..000000000000 --- a/runtime/src/iree/builtins/mips/my_matmul_kernel.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// C ABI for the MIPS custom matmul kernel. 
-// -// The kernel computes: -// C[m, n] = sum_k A[m, k] * B[k, n] -// -// Each matrix is passed as a base pointer plus explicit strided-layout -// parameters that match what MLIR's memref.extract_strided_metadata produces. -// This matches the calling convention emitted by LowerMIPSToFuncCallPass. - -#ifndef IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_ -#define IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// void my_matmul_kernel( -// float *A, int64_t A_offset, int64_t A_stride0, int64_t A_stride1, -// float *B, int64_t B_offset, int64_t B_stride0, int64_t B_stride1, -// float *C, int64_t C_offset, int64_t C_stride0, int64_t C_stride1, -// int64_t M, int64_t N, int64_t K); -void my_matmul_kernel(float *A, int64_t A_offset, int64_t A_stride0, - int64_t A_stride1, float *B, int64_t B_offset, - int64_t B_stride0, int64_t B_stride1, float *C, - int64_t C_offset, int64_t C_stride0, int64_t C_stride1, - int64_t M, int64_t N, int64_t K); - -#ifdef __cplusplus -} -#endif - -#endif // IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_ diff --git a/runtime/src/iree/builtins/mips/rvv_matmul_kernel.c b/runtime/src/iree/builtins/mips/rvv_matmul_kernel.c deleted file mode 100644 index 149073d416b5..000000000000 --- a/runtime/src/iree/builtins/mips/rvv_matmul_kernel.c +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// RVV (RISC-V Vector 1.0) matmul kernel exposed as an IREE executable plugin. -// -// ABI contract (must match MIPSBufferizableOpInterface.cpp / decomposeMemref2D): -// Each 2-D memref is decomposed into (base_ptr, offset, stride0, stride1). -// {llvm.bareptr = true} means memref → float* in C. -// index → int64_t on RV64. 
-// -// C signature (15 args total): -// void my_matmul_kernel( -// float* A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] -// float* B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] -// float* C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] -// int64_t M, int64_t N, int64_t K -// ); -// -// Vectorization strategy (RVV LMUL=m4): -// - Outer loops: m (rows of A) and k (contraction axis) -// - Inner loop : n (columns of B), vectorized with vsetvl_e32m4 -// - Each vl-wide chunk of C[m, n:n+vl] is kept in a vfloat32m4_t accumulator -// and accumulated across the full k-dimension before storing. -// - Handles arbitrary strides via conditional vlse/vsse. - -#include -#include - -#include "iree/hal/local/executable_plugin.h" - -#ifdef __riscv_vector -#include -#endif - -//===----------------------------------------------------------------------===// -// Core RVV kernel -//===----------------------------------------------------------------------===// - -// The kernel is separated from the plugin wrapper so it can be unit-tested -// with a plain C harness (rvv_standalone_test.c) without IREE headers. -// -// For non-RV targets (e.g. when the host compiler processes this file) the -// function falls back to a scalar triple loop so the plugin still links. - -#ifdef __riscv_vector - -// RVV implementation — compiled only when targeting RV64GCV. -static void rvv_matmul_core( - const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, - const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, - float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, - int64_t M, int64_t N, int64_t K) { - // Apply offsets once — identical to the scalar implementation. - A += A_off; - B += B_off; - C += C_off; - - // Detect unit-stride fast paths at runtime to choose vle/vlse. 
- const int a_unit_col = (A_s1 == 1); - const int b_unit_col = (B_s1 == 1); - const int c_unit_col = (C_s1 == 1); - - for (int64_t m = 0; m < M; ++m) { - int64_t n = 0; - while (n < N) { - // Set vector length for this strip of N columns. - size_t vl = __riscv_vsetvl_e32m4((size_t)(N - n)); - - // Zero accumulator — one vfloat32m4_t per output strip C[m, n:n+vl]. - vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); - - for (int64_t k = 0; k < K; ++k) { - // Scalar load: A[m, k] (row-major access — s0=N, s1=1 for row-major). - float a_val; - if (a_unit_col) { - a_val = A[m * A_s0 + k]; - } else { - a_val = A[m * A_s0 + k * A_s1]; - } - - // Vector load: B[k, n:n+vl]. - vfloat32m4_t b_vec; - if (b_unit_col) { - b_vec = __riscv_vle32_v_f32m4(&B[k * B_s0 + n], vl); - } else { - b_vec = __riscv_vlse32_v_f32m4( - &B[k * B_s0 + n * B_s1], - (ptrdiff_t)(B_s1 * (int64_t)sizeof(float)), vl); - } - - // acc += a_val * b_vec - acc = __riscv_vfmacc_vf_f32m4(acc, a_val, b_vec, vl); - } - - // Store accumulator to C[m, n:n+vl]. - if (c_unit_col) { - __riscv_vse32_v_f32m4(&C[m * C_s0 + n], acc, vl); - } else { - __riscv_vsse32_v_f32m4( - &C[m * C_s0 + n * C_s1], - (ptrdiff_t)(C_s1 * (int64_t)sizeof(float)), acc, vl); - } - n += (int64_t)vl; - } - } -} - -#else // !__riscv_vector - -// Scalar fallback — used when this file is compiled for the host (x86/arm) -// during CI or initial bring-up. 
-static void rvv_matmul_core( - const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, - const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, - float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, - int64_t M, int64_t N, int64_t K) { - A += A_off; - B += B_off; - C += C_off; - for (int64_t m = 0; m < M; ++m) - for (int64_t n = 0; n < N; ++n) { - float acc = 0.0f; - for (int64_t k = 0; k < K; ++k) - acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; - C[m * C_s0 + n * C_s1] = acc; - } -} - -#endif // __riscv_vector - -//===----------------------------------------------------------------------===// -// Direct-call entry point for static embedding -//===----------------------------------------------------------------------===// -// When the dispatch ELF is built with --iree-llvmcpu-link-embedded=true and -// the function declaration carries {hal.import.static}, LLVMCPU codegen emits -// a direct call to this symbol instead of going through the HAL import table. -// The lld wrapper appends this .o to the linker invocation so the symbol is -// resolved at link time. - -void my_matmul_kernel( - const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, - const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, - float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, - int64_t M, int64_t N, int64_t K) { - rvv_matmul_core(A, A_off, A_s0, A_s1, - B, B_off, B_s0, B_s1, - C, C_off, C_s0, C_s1, - M, N, K); -} - -//===----------------------------------------------------------------------===// -// IREE Executable Plugin interface -//===----------------------------------------------------------------------===// - -// Packed argument struct — mirrors the func.call argument list emitted by -// MIPSBufferizableOpInterface::bufferize() → decomposeMemref2D(). 
-typedef struct { - float *A; - int64_t A_off, A_s0, A_s1; - float *B; - int64_t B_off, B_s0, B_s1; - float *C; - int64_t C_off, C_s0, C_s1; - int64_t M, N, K; -} rvv_matmul_kernel_args_t; - -static int rvv_matmul_kernel_import(void *params_ptr, void *context, - void *reserved) { - (void)context; - (void)reserved; - const rvv_matmul_kernel_args_t *a = - (const rvv_matmul_kernel_args_t *)params_ptr; - rvv_matmul_core(a->A, a->A_off, a->A_s0, a->A_s1, - a->B, a->B_off, a->B_s0, a->B_s1, - a->C, a->C_off, a->C_s0, a->C_s1, - a->M, a->N, a->K); - return 0; -} - -static iree_hal_executable_plugin_status_t plugin_load( - const iree_hal_executable_plugin_environment_v0_t *environment, - size_t param_count, - const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { - (void)environment; - (void)param_count; - (void)params; - *out_self = NULL; - return iree_hal_executable_plugin_ok_status(); -} - -static void plugin_unload(void *self) { (void)self; } - -static iree_hal_executable_plugin_status_t plugin_resolve( - void *self, const iree_hal_executable_plugin_resolve_params_v0_t *params, - iree_hal_executable_plugin_resolution_t *out_resolution) { - (void)self; - *out_resolution = 0; - bool any_required_not_found = false; - - for (size_t i = 0; i < params->count; ++i) { - if (params->out_fn_ptrs[i]) continue; - const char *name = params->symbol_names[i]; - bool optional = iree_hal_executable_plugin_import_is_optional(name); - if (optional) ++name; - - if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { - params->out_fn_ptrs[i] = rvv_matmul_kernel_import; - params->out_fn_contexts[i] = NULL; - } else { - if (!optional) any_required_not_found = true; - } - } - - return any_required_not_found - ? 
iree_hal_executable_plugin_status_from_code( - IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND) - : iree_hal_executable_plugin_ok_status(); -} - -IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t ** -iree_hal_executable_plugin_query( - iree_hal_executable_plugin_version_t max_version, void *reserved) { - static const iree_hal_executable_plugin_header_t header = { - .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, - .name = "rvv_matmul", - .description = "RISC-V RVV 1.0 matmul kernel plugin", - .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, - .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, - }; - static const iree_hal_executable_plugin_v0_t plugin = { - .header = &header, - .load = plugin_load, - .unload = plugin_unload, - .resolve = plugin_resolve, - }; - return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST - ? (const iree_hal_executable_plugin_header_t **)&plugin - : NULL; -} diff --git a/runtime/src/iree/builtins/mips/rvv_standalone_test.c b/runtime/src/iree/builtins/mips/rvv_standalone_test.c index 6d765f6d5521..1b68de181991 100644 --- a/runtime/src/iree/builtins/mips/rvv_standalone_test.c +++ b/runtime/src/iree/builtins/mips/rvv_standalone_test.c @@ -1,20 +1,23 @@ -// Standalone QEMU test for rvv_matmul_core. +// Standalone QEMU smoke-test for my_matmul_kernel. 
// -// Build for QEMU (no libc — RV-only syscall wrappers): +// Build for QEMU (no libc): // clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ // -O2 -static -nostdlib -ffreestanding \ -// rvv_standalone_test.c -o rvv_test +// matmul_kernel.c rvv_standalone_test.c -o rvv_test // qemu-riscv64 -cpu rv64,v=true,vlen=128,elen=64 ./rvv_test // -// Build for host validation (x86, scalar fallback, with libc): -// clang rvv_standalone_test.c -O2 -o rvv_test_host && ./rvv_test_host +// Build for host validation (x86, scalar fallback): +// clang matmul_kernel.c rvv_standalone_test.c -O2 -o rvv_test_host +// ./rvv_test_host + +#include "matmul_kernel.h" #include #include // ── I/O and exit ────────────────────────────────────────────────────────────── // RISC-V target: raw ecall (no libc dependency for -nostdlib build). -// Host (x86) target: libc stdio — so the --host build also works. +// Host (x86) target: libc stdio. #ifdef __riscv @@ -44,9 +47,9 @@ __attribute__((noreturn)) static void sys_exit(int code) { __builtin_unreachable(); } -void _start(void); // forward-declare; entry point at bottom +void _start(void); // forward-declare; entry point at bottom -#else // host x86 build +#else // host x86 #include #include @@ -55,13 +58,11 @@ static void sys_write(const char *buf, size_t len) { fwrite(buf, 1, len, stdout); } -__attribute__((noreturn)) static void sys_exit(int code) { - exit(code); -} +__attribute__((noreturn)) static void sys_exit(int code) { exit(code); } #endif // __riscv -// ── Minimal print helpers (no sprintf / printf dependency) ──────────────────── +// ── Minimal print helpers ───────────────────────────────────────────────────── static void print(const char *s) { size_t n = 0; @@ -83,64 +84,7 @@ static void print_float(float v) { print(&buf[i + 1]); } -// ── Inline RVV matmul core (no IREE headers required) ───────────────────────── -// Guards on __riscv_vector (set by clang when -march=rv64gcv is active) rather -// than __riscv so the RVV 
path only compiles when vector intrinsics are present. - -#ifdef __riscv_vector -#include - -static void rvv_matmul_core( - const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, - const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, - float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, - int64_t M, int64_t N, int64_t K) { - A += A_off; B += B_off; C += C_off; - const int b_unit = (B_s1 == 1); - const int c_unit = (C_s1 == 1); - for (int64_t m = 0; m < M; ++m) { - int64_t n = 0; - while (n < N) { - size_t vl = __riscv_vsetvl_e32m4((size_t)(N - n)); - vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); - for (int64_t k = 0; k < K; ++k) { - float a_val = A[m * A_s0 + k * A_s1]; - vfloat32m4_t b_vec = b_unit - ? __riscv_vle32_v_f32m4(&B[k * B_s0 + n], vl) - : __riscv_vlse32_v_f32m4(&B[k * B_s0 + n * B_s1], - B_s1 * (int64_t)sizeof(float), vl); - acc = __riscv_vfmacc_vf_f32m4(acc, a_val, b_vec, vl); - } - if (c_unit) - __riscv_vse32_v_f32m4(&C[m * C_s0 + n], acc, vl); - else - __riscv_vsse32_v_f32m4(&C[m * C_s0 + n * C_s1], - C_s1 * (int64_t)sizeof(float), acc, vl); - n += (int64_t)vl; - } - } -} - -#else // scalar fallback - -static void rvv_matmul_core( - const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, - const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, - float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, - int64_t M, int64_t N, int64_t K) { - A += A_off; B += B_off; C += C_off; - for (int64_t m = 0; m < M; ++m) - for (int64_t n = 0; n < N; ++n) { - float acc = 0.0f; - for (int64_t k = 0; k < K; ++k) - acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; - C[m * C_s0 + n * C_s1] = acc; - } -} - -#endif // __riscv_vector - -// ── Tests ───────────────────────────────────────────────────────────────────── +// ── Test harness ────────────────────────────────────────────────────────────── static int tests_passed = 0; static int tests_failed = 0; @@ -158,24 +102,25 @@ static void check(const char *name, float got, float expected) { } } 
+// ── Tests ───────────────────────────────────────────────────────────────────── + // Test 1: A * I = A (4×4, row-major) static void test_identity(void) { print("[1] A * I = A (4x4 row-major)\n"); float A[16] = { 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,16 }; float I[16] = { 1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1 }; float C[16] = {0}; - rvv_matmul_core(A, 0,4,1, I, 0,4,1, C, 0,4,1, 4,4,4); + my_matmul_kernel(A, 0,4,1, I, 0,4,1, C, 0,4,1, 4,4,4); for (int i = 0; i < 16; ++i) check("A*I", C[i], A[i]); } -// Test 2: 2x3 * 3x2 = 2x2 -// [[58,64],[139,154]] +// Test 2: 2x3 * 3x2 = 2x2 → [[58,64],[139,154]] static void test_2x3x2(void) { print("[2] 2x3 * 3x2 = 2x2\n"); float A[6] = {1,2,3, 4,5,6}; float B[6] = {7,8, 9,10, 11,12}; float C[4] = {0}; - rvv_matmul_core(A, 0,3,1, B, 0,2,1, C, 0,2,1, 2,2,3); + my_matmul_kernel(A, 0,3,1, B, 0,2,1, C, 0,2,1, 2,2,3); check("C[0,0]", C[0], 58.0f); check("C[0,1]", C[1], 64.0f); check("C[1,0]", C[2], 139.0f); @@ -190,8 +135,8 @@ static void test_col_major(void) { // B[3×2] col-major: stored [7,9,11, 8,10,12], s0=1, s1=3 float B[6] = {7,9,11, 8,10,12}; float C[4] = {0}; - rvv_matmul_core(A, 0,1,2, B, 0,1,3, C, 0,1,2, 2,2,3); - // C[m,n] = C[m + n*2] + my_matmul_kernel(A, 0,1,2, B, 0,1,3, C, 0,1,2, 2,2,3); + // C[m,n] stored at C[m + n*2] check("C[0,0]", C[0], 58.0f); check("C[1,0]", C[1], 139.0f); check("C[0,1]", C[2], 64.0f); @@ -204,7 +149,7 @@ static void test_offset(void) { float A[8] = {99,99,99,99, 1,0, 0,1}; float B[8] = {99,99,99,99, 3,0, 0,5}; float C[8] = {0}; - rvv_matmul_core(A, 4,2,1, B, 4,2,1, C, 4,2,1, 2,2,2); + my_matmul_kernel(A, 4,2,1, B, 4,2,1, C, 4,2,1, 2,2,2); check("C[0,0]", C[4], 3.0f); check("C[0,1]", C[5], 0.0f); check("C[1,0]", C[6], 0.0f); @@ -213,9 +158,6 @@ static void test_offset(void) { // ── Entry point ─────────────────────────────────────────────────────────────── -// On RISC-V nostdlib builds the linker expects _start. -// On host builds we emit main() so it links normally. 
- #ifdef __riscv void _start(void) { #else @@ -236,17 +178,14 @@ int main(void) { print("\n"); print(tests_failed == 0 ? "PASSED" : "FAILED"); print(" ("); - char b[4] = {'0' + (char)tests_passed, ' ', '\0', '\0'}; - b[1] = '\0'; print(b); + char b[4]; b[1] = '\0'; + b[0] = '0' + (char)tests_passed; print(b); print(" passed, "); - b[0] = '0' + (char)tests_failed; - print(b); + b[0] = '0' + (char)tests_failed; print(b); print(" failed)\n"); -#ifdef __riscv - sys_exit(tests_failed == 0 ? 0 : 1); -#else sys_exit(tests_failed == 0 ? 0 : 1); +#ifndef __riscv return 0; #endif } From 123b7a80bd5c3049683a3d980d8f333d56718339 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 9 Mar 2026 13:50:34 -0700 Subject: [PATCH 08/12] [MIPS] Remove LowerMIPSToFuncCallPass and obsolete zsh scripts mips.matmul is eliminated entirely during One-Shot Bufferize: MIPSBufferizableOpInterface emits func.call @my_matmul_kernel directly, so LowerMIPSToFuncCallPass was already a no-op placeholder. Also remove compile_mips.zsh and run_mips.zsh. 
Co-Authored-By: Claude Sonnet 4.6 --- compile_mips.zsh | 63 ------------------- .../InputConversion/ConvertTorchToMIPS.cpp | 1 - .../input/Torch/InputConversion/Passes.td | 5 +- .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 - .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 5 -- .../Dialect/MIPS/Transforms/CMakeLists.txt | 56 ----------------- .../MIPS/Transforms/LowerMIPSToFuncCall.cpp | 36 ----------- .../Dialect/MIPS/Transforms/Passes.cpp | 18 ------ .../compiler/Dialect/MIPS/Transforms/Passes.h | 33 ---------- .../Dialect/MIPS/Transforms/Passes.td | 49 --------------- .../MIPS/Transforms/test/CMakeLists.txt | 15 ----- .../Transforms/test/lower_to_func_call.mlir | 49 --------------- .../iree/compiler/Tools/init_iree_passes.h | 2 - run_mips.zsh | 45 ------------- 14 files changed, 3 insertions(+), 375 deletions(-) delete mode 100755 compile_mips.zsh delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt delete mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir delete mode 100755 run_mips.zsh diff --git a/compile_mips.zsh b/compile_mips.zsh deleted file mode 100755 index c92271ce30f5..000000000000 --- a/compile_mips.zsh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/zsh -# compile_mips.zsh -# -# Compiles a simple torch.aten.mm through the MIPS custom matmul path: -# torch.aten.mm -# -> mips.matmul (ConvertTorchToMIPSPass) -# -> flow.dispatch(...) 
(IREE dispatch formation) -# -> func.call @my_matmul_kernel (bufferize via BufferizableOpInterface) -# -> LLVM / vmfb (iree-compile LLVMCPU backend) -# -# Output: /tmp/mm_mips.vmfb -# IR dump: /tmp/mm_mips_ir_dump.mlir (--mlir-print-ir-after-all) - -set -e # exit on first error - -export LD_LIBRARY_PATH="$HOME/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" -BUILD=$HOME/MLIR_Work/mips/iree-build -IREE_OPT=$BUILD/tools/iree-opt -IREE_COMPILE=$BUILD/tools/iree-compile - -# ── Input: 4x4 f32 matrix multiply ──────────────────────────────────────────── -cat > /tmp/mm_torch.mlir << 'EOF' -module { - func.func @mm(%A: !torch.vtensor<[4,4],f32>, - %B: !torch.vtensor<[4,4],f32>) - -> !torch.vtensor<[4,4],f32> { - %0 = torch.aten.mm %A, %B - : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> - -> !torch.vtensor<[4,4],f32> - return %0 : !torch.vtensor<[4,4],f32> - } -} -EOF - -# ── Step 1: Verify torch.aten.mm → mips.matmul ──────────────────────────────── -echo "==> Step 1: verifying torch.aten.mm → mips.matmul" -$IREE_OPT \ - --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \ - /tmp/mm_torch.mlir \ - | grep -q "mips.matmul" && echo " [OK] mips.matmul found in IR" - -# ── Step 2: Full torch → IREE input IR (with MIPS path enabled) ─────────────── -echo "==> Step 2: torch → IREE input IR (use-mips-matmul=true)" -$IREE_OPT \ - --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \ - /tmp/mm_torch.mlir \ - -o /tmp/mm_iree.mlir - -# ── Step 3: IREE input IR → vmfb (dispatch + bufferize + LLVM) ─────────────── -IR_DUMP=/tmp/mm_mips_ir_dump.mlir -echo "==> Step 3: IREE input IR → vmfb (IR dump → $IR_DUMP)" -$IREE_COMPILE \ - --iree-hal-target-backends=llvm-cpu \ - --iree-llvmcpu-link-embedded=false \ - --mlir-print-ir-after-all \ - /tmp/mm_iree.mlir \ - -o /tmp/mm_mips.vmfb \ - 2>"$IR_DUMP" - -echo "" -echo "==> Compiled successfully: /tmp/mm_mips.vmfb" -echo " IR dump written to: $IR_DUMP" -echo " Run with: ./run_mips.zsh" diff 
--git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp index 1999f8695dfd..16273f810684 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp @@ -15,7 +15,6 @@ // 3. Emits mips.matmul on builtin tensors. // 4. Casts the result back to ValueTensorType via FromBuiltinTensorOp. // -// This mirrors the approach in ConvertTorchUnstructuredToLinalgExt.cpp. #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td index fb55ba7d189f..e030ff6ba882 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/InputConversion/Passes.td @@ -34,8 +34,9 @@ def ConvertTorchToMIPSPass : let summary = "Convert torch.aten.mm to mips.matmul"; let description = [{ Intercepts torch.aten.mm before the standard torch->linalg conversion and - replaces it with mips.matmul. The mips.matmul op is later lowered to a - func.call @my_matmul_kernel by LowerMIPSToFuncCallPass after bufferization. + replaces it with mips.matmul. The mips.matmul op is eliminated entirely + during One-Shot Bufferize: MIPSBufferizableOpInterface emits a direct + func.call @my_matmul_kernel with decomposed memref arguments. 
}]; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 173e4cdc06db..ebd8d7d93755 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -166,7 +166,6 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils - iree::compiler::Dialect::MIPS::Transforms iree::compiler::Dialect::TensorExt::IR iree::compiler::Dialect::Util::IR iree::compiler::Dialect::Util::Transforms diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index c0d9ed9981e4..771839bd3400 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -14,7 +14,6 @@ #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "iree/compiler/Codegen/Utils/CodegenOptions.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "llvm/ADT/TypeSwitch.h" @@ -519,10 +518,6 @@ static void addLowerToLLVMPasses(OpPassManager &modulePassManager, FunctionLikeNest(modulePassManager) .addPass(createEraseHALDescriptorTypeFromMemRefPass); - // mips.matmul is eliminated during One-Shot Bufferize (func.call emitted - // directly by MIPSBufferizableOpInterface). This pass is now a no-op. 
- modulePassManager.addPass(IREE::MIPS::createLowerMIPSToFuncCallPass()); - // Lower `ukernel.*` ops to function calls modulePassManager.addPass(createLowerUKernelOpsToCallsPass()); diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt deleted file mode 100644 index b37cc85e3b29..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -iree_add_all_subdirs() - -# ─── Tablegen: generate Passes.h.inc from Passes.td ────────────────────────── - -iree_tablegen_library( - NAME - PassesIncGen - TD_FILE - "Passes.td" - OUTS - --gen-pass-decls Passes.h.inc -) - -# ─── Header-only pass declarations ─────────────────────────────────────────── - -iree_cc_library( - NAME - PassHeaders - HDRS - "Passes.h" - "Passes.h.inc" - DEPS - ::PassesIncGen - MLIRPass - MLIRTransforms - PUBLIC -) - -# ─── Full transforms library ───────────────────────────────────────────────── - -iree_cc_library( - NAME - Transforms - HDRS - "Passes.h" - SRCS - "LowerMIPSToFuncCall.cpp" - "Passes.cpp" - DEPS - ::PassHeaders - ::PassesIncGen - iree::compiler::Dialect::MIPS::IR - MLIRFuncDialect - MLIRIR - MLIRMemRefDialect - MLIRPass - MLIRTransformUtils - MLIRTransforms - PUBLIC -) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp deleted file mode 100644 index df9caa9608f1..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -// mips.matmul is a tensor-only op that is eliminated entirely during -// One-Shot Bufferize: the BufferizableOpInterface implementation in -// MIPSBufferizableOpInterface.cpp emits func.call @my_matmul_kernel directly. -// -// This pass is therefore a no-op and exists only for registration purposes -// (so that --iree-mips-lower-to-func-call can be specified on the command line -// without error, and so that any pipeline that references it still compiles). - -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep -#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler::IREE::MIPS { - -#define GEN_PASS_DEF_LOWERMIPSTOFUNCCALLPASS -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" - -namespace { - -struct LowerMIPSToFuncCallPass - : impl::LowerMIPSToFuncCallPassBase { - void runOnOperation() override { - // mips.matmul is eliminated during One-Shot Bufferize (see - // MIPSBufferizableOpInterface.cpp). No work to do here. - } -}; - -} // namespace -} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp deleted file mode 100644 index 751ea73789d8..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" - -namespace mlir::iree_compiler::IREE::MIPS { - -namespace { -#define GEN_PASS_REGISTRATION -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" -} // namespace - -void registerMIPSPasses() { registerPasses(); } - -} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h deleted file mode 100644 index 8b4390167407..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ -#define IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir::iree_compiler::IREE::MIPS { - -//===----------------------------------------------------------------------===// -// Pass factory functions (generated by tablegen + implemented in Passes.cpp) -//===----------------------------------------------------------------------===// - -/// Creates the pass that lowers memref-form mips.matmul to -/// func.call @my_matmul_kernel. -std::unique_ptr createLowerMIPSToFuncCallPass(); - -/// Registers all MIPS passes with the global pass registry so they can be -/// invoked from the command line (e.g. `iree-opt --iree-mips-lower-to-func-call`). 
-void registerMIPSPasses(); - -} // namespace mlir::iree_compiler::IREE::MIPS - -// clang-format off -#define GEN_PASS_DECL -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" // IWYU pragma: keep -// clang-format on - -#endif // IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td deleted file mode 100644 index b081e3a58a7c..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES -#define IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES - -include "mlir/Pass/PassBase.td" - -//===----------------------------------------------------------------------===// -// LowerMIPSToFuncCallPass -// -// Converts memref-form mips.matmul operations (produced after one-shot -// bufferization in the LLVMCPU codegen pipeline) to: -// -// func.func private @my_matmul_kernel(...) attributes {llvm.bareptr = true} -// func.call @my_matmul_kernel(A_base, A_off, A_s0, A_s1, -// B_base, B_off, B_s0, B_s1, -// C_base, C_off, C_s0, C_s1, -// M, N, K) -// -// The pass uses memref.extract_strided_metadata to decompose each memref into -// a (base_pointer, offset, strides...) tuple matching the C ABI of the kernel. -//===----------------------------------------------------------------------===// - -def LowerMIPSToFuncCallPass : - Pass<"iree-mips-lower-to-func-call", "ModuleOp"> { - let summary = "Lower mips.matmul (memref form) to func.call @my_matmul_kernel"; - let description = [{ - Walks all mips.matmul ops in the module and replaces each one with a call - to the external C kernel `my_matmul_kernel`. 
The memref operands are - decomposed via `memref.extract_strided_metadata` into base-pointer + offset - + stride arguments, matching the ABI declared in my_matmul_kernel.h. - - The pass creates a `func.func private @my_matmul_kernel` declaration with - `{llvm.bareptr = true}` so that the LLVM backend passes bare float* pointers - rather than MLIR memref descriptor structs. - - This pass runs after one-shot bufferization in the LLVMCPU codegen pipeline. - }]; - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::memref::MemRefDialect" - ]; -} - -#endif // IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt deleted file mode 100644 index 725370463654..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -iree_lit_test_suite( - NAME - lit - SRCS - "lower_to_func_call.mlir" - TOOLS - FileCheck - iree-opt -) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir deleted file mode 100644 index 582419acf589..000000000000 --- a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mips-lower-to-func-call %s \ -// RUN: | FileCheck %s - -// ───────────────────────────────────────────────────────────────────────────── -// Basic static-shape: memref-form mips.matmul → func.call @my_matmul_kernel -// ───────────────────────────────────────────────────────────────────────────── - -// CHECK: func.func private @my_matmul_kernel -// CHECK-SAME: {llvm.bareptr = true} -// -// CHECK-LABEL: func.func @lower_mips_matmul -// CHECK-NOT: mips.matmul -// CHECK: memref.extract_strided_metadata -// CHECK: call @my_matmul_kernel -module { - func.func @lower_mips_matmul(%A: memref<4x8xf32>, - %B: memref<8x4xf32>, - %C: memref<4x4xf32>) { - mips.matmul %A, %B, %C - : memref<4x8xf32>, memref<8x4xf32>, memref<4x4xf32> - return - } -} - -// ───────────────────────────────────────────────────────────────────────────── -// Multiple matmuls reuse the same @my_matmul_kernel declaration. -// ───────────────────────────────────────────────────────────────────────────── - -// CHECK: func.func private @my_matmul_kernel -// Check that there is exactly one declaration (not two). 
-// CHECK-NOT: func.func private @my_matmul_kernel -// -// CHECK-LABEL: func.func @two_matmuls -// CHECK: call @my_matmul_kernel -// CHECK: call @my_matmul_kernel -module { - func.func @two_matmuls(%A: memref<2x4xf32>, - %B: memref<4x2xf32>, - %C: memref<2x2xf32>, - %D: memref<2x4xf32>, - %E: memref<4x2xf32>, - %F: memref<2x2xf32>) { - mips.matmul %A, %B, %C - : memref<2x4xf32>, memref<4x2xf32>, memref<2x2xf32> - mips.matmul %D, %E, %F - : memref<2x4xf32>, memref<4x2xf32>, memref<2x2xf32> - return - } -} diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h index d113b8d61d3d..6f7de0752f45 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_passes.h +++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h @@ -20,7 +20,6 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" -#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" #include "iree/compiler/Dialect/Stream/Transforms/Passes.h" #include "iree/compiler/Dialect/TensorExt/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" @@ -63,7 +62,6 @@ inline void registerAllIreePasses() { IREE::HAL::Loader::registerHALLoaderPasses(); IREE::IO::Parameters::registerParametersPasses(); IREE::LinalgExt::registerPasses(); - IREE::MIPS::registerMIPSPasses(); IREE::Stream::registerStreamPasses(); IREE::TensorExt::registerPasses(); IREE::Util::registerUtilPasses(); diff --git a/run_mips.zsh b/run_mips.zsh deleted file mode 100755 index 3bee80deb071..000000000000 --- a/run_mips.zsh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/zsh -# run_mips.zsh -# -# Runs the vmfb produced by compile_mips.zsh. -# -# The vmfb's dispatch executable calls @my_matmul_kernel through the IREE -# import mechanism (not direct dynamic linking). 
The kernel is provided by -# libmy_matmul_kernel.dylib which implements the IREE executable plugin API -# (exports iree_hal_executable_plugin_query). -# -# Test: A * I = A (multiply by 4x4 identity → expect the same matrix back) - -BUILD=$HOME/MLIR_Work/mips/iree-build -KERNEL_LIB=$BUILD/runtime/src/iree/builtins/mips/libmy_matmul_kernel.so -IREE_RUN=$BUILD/tools/iree-run-module - -# conda libstdc++ must be visible when iree-run-module dlopen()s the .so -export LD_LIBRARY_PATH="$HOME/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" -VMFB=/tmp/mm_mips.vmfb - -if [[ ! -f $VMFB ]]; then - echo "ERROR: $VMFB not found. Run compile_mips.zsh first." - exit 1 -fi - -if [[ ! -f $KERNEL_LIB ]]; then - echo "ERROR: $KERNEL_LIB not found. Build my_matmul_kernel target first." - exit 1 -fi - -# A = 1..16 (row-major 4x4), B = 4x4 identity -A="4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16" -B="4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" - -echo "==> Running mm(A, I) via MIPS kernel" -echo " Kernel plugin : $KERNEL_LIB" -echo " Expected : A * I = A (rows: [1 2 3 4], [5 6 7 8], ...)" -echo "" - -$IREE_RUN \ - --executable_plugin=$KERNEL_LIB \ - --module=$VMFB \ - --function=mm \ - --input="$A" \ - --input="$B" From 2685d5711f5fe2ff1318482f3dd5df19e1c52af8 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 9 Mar 2026 14:38:42 -0700 Subject: [PATCH 09/12] [MIPS] Add RVV/QEMU workflow scripts and kernel library docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - build_tools/riscv/setup_qemu_workflow.sh: One-time setup for toolchain (conda clang-18/lld-18), RISC-V sysroot, QEMU 8.2.2 (riscv64-linux-user from source), and both IREE builds (host x86 + RISC-V cross). Supports --step=N to run individual steps. 
- build_tools/riscv/rvv_qemu_workflow_static.sh: Clean 6-step static embedding pipeline — torch→flow IR, compile .o, create lld_wrapper, iree-compile with --iree-mips-static-embedding, ELF verification, and QEMU run with VLEN sweep. Supports --host and --vlen N. - build_tools/riscv/rvv_qemu_workflow_dynamic.sh: Parallel dynamic plugin pipeline — compile matmul_kernel.c + matmul_plugin.c into librvv_matmul.so, iree-compile without static embedding flag, and QEMU run with --executable_plugin. Supports --host and --vlen N. - runtime/src/iree/builtins/mips/README.md: Documents the kernel library layout, ABI, build commands for standalone/static/dynamic targets, and integration with the IREE MIPS dialect. Co-Authored-By: Claude Sonnet 4.6 --- .../riscv/rvv_qemu_workflow_dynamic.sh | 207 +++++++++++++ build_tools/riscv/rvv_qemu_workflow_static.sh | 219 ++++++++++++++ build_tools/riscv/setup_qemu_workflow.sh | 272 ++++++++++++++++++ runtime/src/iree/builtins/mips/README.md | 138 +++++++++ 4 files changed, 836 insertions(+) create mode 100755 build_tools/riscv/rvv_qemu_workflow_dynamic.sh create mode 100644 build_tools/riscv/rvv_qemu_workflow_static.sh create mode 100644 build_tools/riscv/setup_qemu_workflow.sh create mode 100644 runtime/src/iree/builtins/mips/README.md diff --git a/build_tools/riscv/rvv_qemu_workflow_dynamic.sh b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh new file mode 100755 index 000000000000..580617fe5bbf --- /dev/null +++ b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh @@ -0,0 +1,207 @@ +#!/usr/bin/env bash +# rvv_qemu_workflow_dynamic.sh +# +# End-to-end MIPS matmul pipeline — DYNAMIC plugin loading. +# +# The RVV kernel is compiled into a shared library (.so) that is loaded at +# runtime via --executable_plugin. No custom linker wrapper is needed at +# iree-compile time. 
+# +# Pipeline: +# mips_matmul_test.mlir +# ─[iree-opt torch-to-iree{use-mips-matmul=true}]─► flow.mlir +# ─[clang --target=riscv64 -shared]──────────────► librvv_matmul.so +# ─[iree-compile --iree-llvmcpu-link-embedded=false]► matmul.vmfb +# ─[qemu-riscv64 iree-run-module --executable_plugin]► result +# +# Usage: +# bash rvv_qemu_workflow_dynamic.sh # RISC-V QEMU, vlen=512 +# bash rvv_qemu_workflow_dynamic.sh --host # x86 host (scalar fallback) +# bash rvv_qemu_workflow_dynamic.sh --vlen 256 # QEMU with vlen=256 + +set -euo pipefail + +# ───────────────────────────────────────────────────────────────────────────── +# Configuration +# ───────────────────────────────────────────────────────────────────────────── +IREE_SRC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +WORK_DIR="${HOME}/MLIR_Work/mips" +HOST_BUILD="${WORK_DIR}/iree-build" +RISCV_BUILD="${WORK_DIR}/iree-build-riscv" +OUT_DIR="${WORK_DIR}/out/dynamic" + +IREE_OPT="${HOST_BUILD}/tools/iree-opt" +IREE_COMPILE="${HOST_BUILD}/tools/iree-compile" +HOST_RUN="${HOST_BUILD}/tools/iree-run-module" +RISCV_RUN="${RISCV_BUILD}/install/bin/iree-run-module" +QEMU="${HOME}/local/bin/qemu-riscv64" +SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" + +CLANG="${HOME}/miniforge3/bin/clang" +LLD="${HOME}/miniforge3/bin/ld.lld" +CLANG_INC="${HOME}/miniforge3/lib/clang/18/include" + +KERNEL_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_kernel.c" +PLUGIN_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_plugin.c" +TEST_MLIR="${WORK_DIR}/mips_matmul_test.mlir" + +# Rocky 8's libstdc++ is too old; conda has GLIBCXX 3.4.29+. 
+export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── +HOST_MODE=0 +VLEN=512 +while [[ $# -gt 0 ]]; do + case "$1" in + --host) HOST_MODE=1 ;; + --vlen) shift; VLEN="$1" ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac + shift +done + +mkdir -p "${OUT_DIR}" + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── +section() { echo ""; echo "══[ $* ]══════════════════════════════════════════════════"; } +ok() { echo " [ok] $*"; } +run_qemu() { + local vlen="$1"; shift + "${QEMU}" -cpu "rv64,v=true,vlen=${vlen},elen=64,vext_spec=v1.0" \ + -L "${SYSROOT}" "${RISCV_RUN}" "$@" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 1: torch → IREE flow IR +# ───────────────────────────────────────────────────────────────────────────── +section "Step 1: torch → IREE flow IR" + +"${IREE_OPT}" \ + --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \ + "${TEST_MLIR}" -o "${OUT_DIR}/flow.mlir" +ok "${OUT_DIR}/flow.mlir" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 2: Cross-compile kernel + plugin → shared library +# +# matmul_kernel.c — compute logic (no IREE headers) +# matmul_plugin.c — IREE HAL executable plugin interface +# Both compiled together into a single -fPIC -shared .so. 
+# ───────────────────────────────────────────────────────────────────────────── +section "Step 2: Compile matmul_kernel.c + matmul_plugin.c → .so" + +PLUGIN_SO="${OUT_DIR}/librvv_matmul.so" + +if [[ "${HOST_MODE}" == "1" ]]; then + "${CLANG}" --target=x86_64-linux-gnu \ + -O2 -fPIC -shared \ + -I "${IREE_SRC}/runtime/src" \ + "${KERNEL_SRC}" "${PLUGIN_SRC}" -o "${PLUGIN_SO}" + ok "x86 scalar plugin: ${PLUGIN_SO}" +else + "${CLANG}" --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -fPIC -shared -nostdinc -nostdlib \ + -isystem "${CLANG_INC}" \ + -fuse-ld="${LLD}" \ + -I "${IREE_SRC}/runtime/src" \ + "${KERNEL_SRC}" "${PLUGIN_SRC}" -o "${PLUGIN_SO}" + ok "RISC-V RVV plugin: ${PLUGIN_SO} ($(file -b "${PLUGIN_SO}" | cut -d, -f1))" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 3: iree-compile → .vmfb (kernel resolved at runtime via plugin) +# +# --iree-llvmcpu-link-embedded=false — host-ABI shared object, not embedded ELF +# No --iree-mips-static-embedding — my_matmul_kernel is a HAL import entry +# ───────────────────────────────────────────────────────────────────────────── +section "Step 3: iree-compile → .vmfb (dynamic)" + +if [[ "${HOST_MODE}" == "1" ]]; then + RISCV_FLAGS=() +else + RISCV_FLAGS=( + "--iree-llvmcpu-target-triple=riscv64-linux-gnu" + "--iree-llvmcpu-target-abi=lp64d" + "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+c,+zvl512b,+v" + "--riscv-v-fixed-length-vector-lmul-max=8" + ) +fi + +VMFB="${OUT_DIR}/matmul_dynamic.vmfb" +"${IREE_COMPILE}" \ + --iree-hal-target-backends=llvm-cpu \ + --iree-llvmcpu-link-embedded=false \ + "${RISCV_FLAGS[@]}" \ + "${OUT_DIR}/flow.mlir" -o "${VMFB}" +ok "${VMFB} ($(du -sh "${VMFB}" | cut -f1))" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 4: Verify kernel appears as an import (unresolved symbol) in the vmfb +# ───────────────────────────────────────────────────────────────────────────── +section "Step 4: 
Verify dynamic import in vmfb" + +ELF_OFFSET=$(grep -boa $'\x7fELF' "${VMFB}" 2>/dev/null | head -1 | cut -d: -f1 || true) +if [[ -n "${ELF_OFFSET}" ]]; then + dd if="${VMFB}" bs=1 skip="${ELF_OFFSET}" 2>/dev/null > "${OUT_DIR}/dispatch.elf" + python3 - "${OUT_DIR}/dispatch.elf" << 'PYEOF' +import sys +data = open(sys.argv[1], 'rb').read() +idx = data.find(b'my_matmul_kernel') +rvv = sum(1 for i in range(0, len(data)-3, 4) if data[i] & 0x7f == 0x57) +if idx != -1: + tag = "[ok]" if rvv == 0 else "[note]" + print(f" {tag} 'my_matmul_kernel' at offset {idx} (import entry, no RVV here)") + print(f" RVV opcode count in vmfb: {rvv} (expected 0 — kernel is in .so)") +else: + print(" [warn] 'my_matmul_kernel' not found in dispatch ELF") +PYEOF +else + echo " [warn] No ELF found in vmfb" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 5: Run (kernel loaded from .so via --executable_plugin) +# ───────────────────────────────────────────────────────────────────────────── +section "Step 5: Run (--executable_plugin=${PLUGIN_SO})" + +MATMUL_ARGS=( + --module="${VMFB}" + --executable_plugin="${PLUGIN_SO}" + --function="matmul_4x4" + "--input=4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" + "--input=4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16" +) + +if [[ "${HOST_MODE}" == "1" ]]; then + echo " Running on x86 host (scalar fallback)..." + "${HOST_RUN}" "${MATMUL_ARGS[@]}" +else + echo " Running under QEMU vlen=${VLEN}..." 
+ run_qemu "${VLEN}" "${MATMUL_ARGS[@]}" + + echo "" + echo " VLEN sweep:" + for V in 128 256 512; do + printf " vlen=%-4s " "${V}:" + run_qemu "${V}" "${MATMUL_ARGS[@]}" 2>&1 | grep "4x4xf32" || echo "(no output)" + done + echo " Note: vlen=128 may produce zeros — vmfb compiled with +zvl512b" +fi + +echo "" +echo " Expected: 4x4xf32=[1 2 3 4][5 6 7 8][9 10 11 12][13 14 15 16]" + +# ───────────────────────────────────────────────────────────────────────────── +# Summary +# ───────────────────────────────────────────────────────────────────────────── +echo "" +echo "════════════════════════════════════════════════════════════" +echo " DONE — Dynamic plugin verified." +echo " Artifacts in ${OUT_DIR}/" +echo " librvv_matmul.so — plugin loaded at runtime" +echo " matmul_dynamic.vmfb — vmfb with HAL import (needs plugin)" +echo "════════════════════════════════════════════════════════════" diff --git a/build_tools/riscv/rvv_qemu_workflow_static.sh b/build_tools/riscv/rvv_qemu_workflow_static.sh new file mode 100644 index 000000000000..19bcb755a002 --- /dev/null +++ b/build_tools/riscv/rvv_qemu_workflow_static.sh @@ -0,0 +1,219 @@ +#!/usr/bin/env bash +# rvv_qemu_workflow_static.sh +# +# End-to-end MIPS matmul pipeline — STATIC kernel embedding. +# +# The RVV kernel (.o) is baked into the dispatch ELF inside the .vmfb at +# iree-compile time via a custom lld wrapper. No plugin .so is needed at +# runtime. 
+# +# Pipeline: +# mips_matmul_test.mlir +# ─[iree-opt torch-to-iree{use-mips-matmul=true}]─► flow.mlir +# ─[clang --target=riscv64]──────────────────────► matmul_kernel_riscv.o +# ─[lld_wrapper.sh] (appends .o to every dispatch link) +# ─[iree-compile --iree-mips-static-embedding]────► matmul.vmfb +# ─[qemu-riscv64 iree-run-module]────────────────► result (no --executable_plugin) +# +# Usage: +# bash rvv_qemu_workflow_static.sh # RISC-V QEMU, vlen=512 +# bash rvv_qemu_workflow_static.sh --host # x86 host (scalar fallback) +# bash rvv_qemu_workflow_static.sh --vlen 256 # QEMU with vlen=256 + +set -euo pipefail + +# ───────────────────────────────────────────────────────────────────────────── +# Configuration +# ───────────────────────────────────────────────────────────────────────────── +IREE_SRC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +WORK_DIR="${HOME}/MLIR_Work/mips" +HOST_BUILD="${WORK_DIR}/iree-build" +RISCV_BUILD="${WORK_DIR}/iree-build-riscv" +OUT_DIR="${WORK_DIR}/out/static" + +IREE_OPT="${HOST_BUILD}/tools/iree-opt" +IREE_COMPILE="${HOST_BUILD}/tools/iree-compile" +HOST_RUN="${HOST_BUILD}/tools/iree-run-module" +RISCV_RUN="${RISCV_BUILD}/install/bin/iree-run-module" +QEMU="${HOME}/local/bin/qemu-riscv64" +SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" + +CLANG="${HOME}/miniforge3/bin/clang" +LLD="${HOME}/miniforge3/bin/ld.lld" +CLANG_INC="${HOME}/miniforge3/lib/clang/18/include" + +KERNEL_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_kernel.c" +TEST_MLIR="${WORK_DIR}/mips_matmul_test.mlir" + +# Rocky 8's libstdc++ is too old; conda has GLIBCXX 3.4.29+. 
+export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + +# ───────────────────────────────────────────────────────────────────────────── +# Argument parsing +# ───────────────────────────────────────────────────────────────────────────── +HOST_MODE=0 +VLEN=512 +while [[ $# -gt 0 ]]; do + case "$1" in + --host) HOST_MODE=1 ;; + --vlen) shift; VLEN="$1" ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac + shift +done + +mkdir -p "${OUT_DIR}" + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── +section() { echo ""; echo "══[ $* ]══════════════════════════════════════════════════"; } +ok() { echo " [ok] $*"; } +run_qemu() { + local vlen="$1"; shift + "${QEMU}" -cpu "rv64,v=true,vlen=${vlen},elen=64,vext_spec=v1.0" \ + -L "${SYSROOT}" "${RISCV_RUN}" "$@" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 1: torch → IREE flow IR +# ───────────────────────────────────────────────────────────────────────────── +section "Step 1: torch → IREE flow IR" + +"${IREE_OPT}" \ + --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \ + "${TEST_MLIR}" -o "${OUT_DIR}/flow.mlir" +ok "${OUT_DIR}/flow.mlir" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 2: Cross-compile RVV kernel → RISC-V relocatable object +# ───────────────────────────────────────────────────────────────────────────── +section "Step 2: Compile matmul_kernel.c → .o" + +KERNEL_O="${OUT_DIR}/matmul_kernel_riscv.o" + +if [[ "${HOST_MODE}" == "1" ]]; then + "${CLANG}" --target=x86_64-linux-gnu \ + -O2 -c -I "${IREE_SRC}/runtime/src" \ + "${KERNEL_SRC}" -o "${KERNEL_O}" + ok "x86 scalar kernel: ${KERNEL_O}" +else + "${CLANG}" --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -c -nostdinc -isystem "${CLANG_INC}" \ + -I "${IREE_SRC}/runtime/src" \ + 
"${KERNEL_SRC}" -o "${KERNEL_O}" + ok "RISC-V RVV kernel: ${KERNEL_O} ($(file -b "${KERNEL_O}" | cut -d, -f1))" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 3: Create lld wrapper +# +# IREE calls its embedded linker as: +# lld -flavor gnu --no-undefined -nostdlib -static -shared ... dispatch.o +# We append the kernel .o so my_matmul_kernel resolves at link time. +# -Bsymbolic: bind all same-ELF symbols locally to avoid R_RISCV_JUMP_SLOT +# entries in .rela.plt — IREE's embedded ELF loader ignores DT_JMPREL +# (.rela.plt) and only processes DT_RELA (.rela.dyn). Without this flag, +# the PLT GOT slot is never patched, causing a segfault on the first call. +# ───────────────────────────────────────────────────────────────────────────── +section "Step 3: Create lld_wrapper.sh" + +LLD_WRAPPER="${OUT_DIR}/lld_wrapper.sh" +cat > "${LLD_WRAPPER}" << WRAPPER +#!/usr/bin/env bash +exec "${LLD}" "\$@" "${KERNEL_O}" -Bsymbolic +WRAPPER +chmod +x "${LLD_WRAPPER}" +ok "${LLD_WRAPPER} (appends ${KERNEL_O})" + +# ───────────────────────────────────────────────────────────────────────────── +# Step 4: iree-compile → .vmfb (kernel statically linked) +# ───────────────────────────────────────────────────────────────────────────── +section "Step 4: iree-compile → .vmfb (static)" + +if [[ "${HOST_MODE}" == "1" ]]; then + RISCV_FLAGS=() +else + RISCV_FLAGS=( + "--iree-llvmcpu-target-triple=riscv64-linux-gnu" + "--iree-llvmcpu-target-abi=lp64d" + "--iree-llvmcpu-target-cpu-features=+m,+a,+f,+d,+c,+zvl512b,+v" + "--riscv-v-fixed-length-vector-lmul-max=8" + ) +fi + +VMFB="${OUT_DIR}/matmul_static.vmfb" +"${IREE_COMPILE}" \ + --iree-hal-target-backends=llvm-cpu \ + --iree-llvmcpu-link-embedded=true \ + --iree-llvmcpu-embedded-linker-path="${LLD_WRAPPER}" \ + --iree-mips-static-embedding \ + "${RISCV_FLAGS[@]}" \ + "${OUT_DIR}/flow.mlir" -o "${VMFB}" +ok "${VMFB} ($(du -sh "${VMFB}" | cut -f1))" + +# 
───────────────────────────────────────────────────────────────────────────── +# Step 5: Verify kernel is embedded in the dispatch ELF +# ───────────────────────────────────────────────────────────────────────────── +section "Step 5: Verify static embedding" + +ELF_OFFSET=$(grep -boa $'\x7fELF' "${VMFB}" 2>/dev/null | head -1 | cut -d: -f1 || true) +if [[ -n "${ELF_OFFSET}" ]]; then + dd if="${VMFB}" bs=1 skip="${ELF_OFFSET}" 2>/dev/null > "${OUT_DIR}/dispatch.elf" + python3 - "${OUT_DIR}/dispatch.elf" << 'PYEOF' +import sys +data = open(sys.argv[1], 'rb').read() +idx = data.find(b'my_matmul_kernel') +rvv = sum(1 for i in range(0, len(data)-3, 4) if data[i] & 0x7f == 0x57) +if idx != -1: + tag = "[ok]" if rvv > 0 else "[warn]" + print(f" {tag} 'my_matmul_kernel' at offset {idx}, RVV instructions: {rvv}") +else: + print(" [warn] 'my_matmul_kernel' not found in dispatch ELF") +PYEOF +else + echo " [warn] No ELF found in vmfb" +fi + +# ───────────────────────────────────────────────────────────────────────────── +# Step 6: Run +# ───────────────────────────────────────────────────────────────────────────── +section "Step 6: Run (no --executable_plugin)" + +MATMUL_ARGS=( + --module="${VMFB}" + --function="matmul_4x4" + "--input=4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1" + "--input=4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16" +) + +if [[ "${HOST_MODE}" == "1" ]]; then + echo " Running on x86 host (scalar fallback)..." + "${HOST_RUN}" "${MATMUL_ARGS[@]}" +else + echo " Running under QEMU vlen=${VLEN}..." 
+ run_qemu "${VLEN}" "${MATMUL_ARGS[@]}" + + echo "" + echo " VLEN sweep:" + for V in 128 256 512; do + printf " vlen=%-4s " "${V}:" + run_qemu "${V}" "${MATMUL_ARGS[@]}" 2>&1 | grep "4x4xf32" || echo "(no output)" + done + echo " Note: vlen=128 may produce zeros — vmfb compiled with +zvl512b" +fi + +echo "" +echo " Expected: 4x4xf32=[1 2 3 4][5 6 7 8][9 10 11 12][13 14 15 16]" + +# ───────────────────────────────────────────────────────────────────────────── +# Summary +# ───────────────────────────────────────────────────────────────────────────── +echo "" +echo "════════════════════════════════════════════════════════════" +echo " DONE — Static embedding verified." +echo " Artifacts in ${OUT_DIR}/" +echo " matmul_kernel_riscv.o — kernel object (baked into vmfb)" +echo " lld_wrapper.sh — linker interceptor" +echo " matmul_static.vmfb — self-contained vmfb (no plugin at runtime)" +echo "════════════════════════════════════════════════════════════" diff --git a/build_tools/riscv/setup_qemu_workflow.sh b/build_tools/riscv/setup_qemu_workflow.sh new file mode 100644 index 000000000000..6f4b985ca17f --- /dev/null +++ b/build_tools/riscv/setup_qemu_workflow.sh @@ -0,0 +1,272 @@ +#!/usr/bin/env bash +# setup_qemu_workflow.sh +# +# One-time setup for the MIPS/RVV QEMU workflow on a Linux host (Rocky 8). +# Installs toolchain, QEMU, and builds both IREE host and RISC-V targets. +# +# Steps (run all by default; pass --step=N to run one): +# 1. Install toolchain — ninja, clang-18, lld-18 via conda-forge +# 2. Install sysroot — RISC-V prebuilt sysroot + iree-run-module (RISC-V) +# 3. Build QEMU — qemu-riscv64 user-mode from source +# 4. Build IREE (host) — iree-opt, iree-compile, iree-run-module for x86 +# 5. 
Build IREE (riscv) — iree-run-module cross-compiled for RISC-V +# +# Usage: +# bash setup_qemu_workflow.sh # run all steps +# bash setup_qemu_workflow.sh --step=3 # run only QEMU build +# bash setup_qemu_workflow.sh --step=4 # run only host IREE build + +set -euo pipefail + +# ───────────────────────────────────────────────────────────────────────────── +# Configuration — edit to match your environment +# ───────────────────────────────────────────────────────────────────────────── +WORK_DIR="${HOME}/MLIR_Work/mips" +IREE_SRC="${WORK_DIR}/iree" + +HOST_BUILD="${WORK_DIR}/iree-build" # iree-opt, iree-compile (x86) +RISCV_BUILD="${WORK_DIR}/iree-build-riscv" # iree-run-module (RISC-V) +QEMU_VER="8.2.2" +INSTALL_PREFIX="${HOME}/local" # qemu-riscv64 installed here + +CONDA="${HOME}/miniforge3/bin/conda" +CLANG="${HOME}/miniforge3/bin/clang" +CLANGXX="${HOME}/miniforge3/bin/clang++" +NINJA="${INSTALL_PREFIX}/bin/ninja" + +# RISC-V prebuilt sysroot (downloaded by riscv_bootstrap.sh) +SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" +RISCV_TOOLCHAIN="${HOME}/riscv/toolchain/clang/linux/RISCV" + +# Rocky 8's system libstdc++ is too old; conda's copy has GLIBCXX 3.4.29+. 
+export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── +STEP_ONLY=0 +for arg in "$@"; do + case "${arg}" in + --step=*) STEP_ONLY="${arg#--step=}" ;; + *) echo "Unknown arg: ${arg}"; exit 1 ;; + esac +done + +should_run() { [[ "${STEP_ONLY}" == "0" || "${STEP_ONLY}" == "$1" ]]; } + +log() { echo ""; echo "════════════════════════════════════════════════════════════"; echo " $*"; echo "════════════════════════════════════════════════════════════"; } +ok() { echo " [ok] $*"; } +skip() { echo " [skip] $*"; } +die() { echo " [FAIL] $*" >&2; exit 1; } + +# ───────────────────────────────────────────────────────────────────────────── +# Step 1: Install toolchain (ninja, clang-18, lld-18 via conda-forge) +# ───────────────────────────────────────────────────────────────────────────── +step1_toolchain() { + log "STEP 1: Install toolchain" + + # Miniforge (conda base) + if [[ -x "${CONDA}" ]]; then + skip "conda already at ${CONDA}" + else + local tmp; tmp="$(mktemp /tmp/miniforge_XXXXX.sh)" + echo " Downloading Miniforge..." 
+ curl -fsSL "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \ + -o "${tmp}" + bash "${tmp}" -b -p "${HOME}/miniforge3" + rm -f "${tmp}" + ok "conda installed" + fi + + # ninja + if [[ -x "${NINJA}" ]]; then + skip "ninja already at ${NINJA}" + else + local tmp; tmp="$(mktemp /tmp/ninja_XXXXX.zip)" + curl -fsSL "https://github.com/ninja-build/ninja/releases/download/v1.12.1/ninja-linux.zip" \ + -o "${tmp}" + mkdir -p "${INSTALL_PREFIX}/bin" + unzip -qo "${tmp}" -d "${INSTALL_PREFIX}/bin" + chmod +x "${NINJA}" + rm -f "${tmp}" + ok "ninja installed" + fi + + # clang-18 + lld-18 + if [[ -x "${CLANG}" ]]; then + skip "clang already at ${CLANG} ($(${CLANG} --version | head -1))" + else + echo " Installing clang-18 + lld-18 (this may take a few minutes)..." + "${CONDA}" install -y -c conda-forge "clang=18" "clangxx=18" "lld=18" --no-update-deps + ok "clang-18 + lld-18 installed" + fi + + ok "Toolchain ready" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 2: Install RISC-V sysroot (IREE prebuilt) +# ───────────────────────────────────────────────────────────────────────────── +step2_sysroot() { + log "STEP 2: Install RISC-V sysroot" + + if [[ -d "${SYSROOT}" ]]; then + skip "Sysroot already at ${SYSROOT}" + return + fi + + echo " Running riscv_bootstrap.sh (interactive — prompts for download paths)..." 
+ bash "$(dirname "${BASH_SOURCE[0]}")/riscv_bootstrap.sh" + ok "Sysroot installed at ${SYSROOT}" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 3: Build QEMU riscv64-linux-user from source +# ───────────────────────────────────────────────────────────────────────────── +step3_qemu() { + log "STEP 3: Build QEMU ${QEMU_VER} (riscv64-linux-user)" + + local qemu_bin="${INSTALL_PREFIX}/bin/qemu-riscv64" + if [[ -x "${qemu_bin}" ]]; then + skip "qemu-riscv64 already at ${qemu_bin} ($(${qemu_bin} --version | head -1))" + return + fi + + "${CONDA}" install -y -c conda-forge glib pkg-config 2>&1 | tail -3 + + local tarball="${WORK_DIR}/qemu-${QEMU_VER}.tar.xz" + if [[ ! -f "${tarball}" ]]; then + echo " Downloading QEMU ${QEMU_VER}..." + curl -fsSL --progress-bar "https://download.qemu.org/qemu-${QEMU_VER}.tar.xz" -o "${tarball}" + else + skip "Tarball already downloaded" + fi + + local src="${WORK_DIR}/qemu-${QEMU_VER}" + if [[ ! -d "${src}" ]]; then + echo " Extracting QEMU source..." + tar -xf "${tarball}" -C "${WORK_DIR}" + fi + + export PKG_CONFIG_PATH="${HOME}/miniforge3/lib/pkgconfig:${HOME}/miniforge3/share/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}}" + export PKG_CONFIG="${HOME}/miniforge3/bin/pkg-config" + export LDFLAGS="-Wl,-rpath,${HOME}/miniforge3/lib" + + echo " Configuring QEMU..." + cd "${src}" + ./configure \ + --prefix="${INSTALL_PREFIX}" \ + --target-list="riscv64-linux-user" \ + --disable-system \ + --enable-linux-user \ + --disable-werror \ + --disable-docs \ + --disable-gtk \ + --disable-sdl \ + --disable-vnc \ + --disable-curl \ + --disable-capstone \ + --disable-kvm \ + --without-default-features \ + --enable-user + + echo " Building QEMU ($(nproc) jobs)..." 
+ if [[ -f "${src}/build/build.ninja" ]]; then + "${NINJA}" -C "${src}/build" -j"$(nproc)" + "${NINJA}" -C "${src}/build" install + else + make -j"$(nproc)" + make install + fi + + ok "qemu-riscv64 installed at ${qemu_bin}" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 4: Build IREE host (iree-opt, iree-compile, iree-run-module for x86) +# ───────────────────────────────────────────────────────────────────────────── +step4_iree_host() { + log "STEP 4: Build IREE (host — iree-opt, iree-compile, iree-run-module)" + + mkdir -p "${HOST_BUILD}" + + cmake -S "${IREE_SRC}" -B "${HOST_BUILD}" \ + -G Ninja \ + -DCMAKE_MAKE_PROGRAM="${NINJA}" \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_C_COMPILER="${CLANG}" \ + -DCMAKE_CXX_COMPILER="${CLANGXX}" \ + -DCMAKE_ASM_COMPILER="${CLANG}" \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DIREE_ENABLE_ASSERTIONS=ON \ + -DIREE_ENABLE_SPLIT_DWARF=ON \ + -DIREE_ENABLE_LLD=ON \ + -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ + -DIREE_TARGET_BACKEND_LLVM_CPU=ON \ + -DIREE_HAL_DRIVER_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ + -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ + -DIREE_BUILD_PYTHON_BINDINGS=OFF \ + -DBENCHMARK_ENABLE_TESTING=OFF \ + -DHAVE_STD_REGEX=ON \ + -DHAVE_POSIX_REGEX=OFF + + echo " Building ($(nproc) jobs)..." 
+ "${NINJA}" -C "${HOST_BUILD}" -j"$(nproc)" iree-opt iree-compile iree-run-module + + ok "iree-opt: ${HOST_BUILD}/tools/iree-opt" + ok "iree-compile: ${HOST_BUILD}/tools/iree-compile" + ok "iree-run-module (host): ${HOST_BUILD}/tools/iree-run-module" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Step 5: Build IREE RISC-V (iree-run-module cross-compiled for riscv64) +# ───────────────────────────────────────────────────────────────────────────── +step5_iree_riscv() { + log "STEP 5: Build IREE (RISC-V cross — iree-run-module for riscv64)" + + [[ -f "${HOST_BUILD}/bin/iree-tblgen" ]] || \ + die "Host tools not found at ${HOST_BUILD}/bin — run step 4 first." + + mkdir -p "${RISCV_BUILD}" + + cmake -S "${IREE_SRC}" -B "${RISCV_BUILD}" \ + -G Ninja \ + -DCMAKE_MAKE_PROGRAM="${NINJA}" \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_TOOLCHAIN_FILE="${IREE_SRC}/build_tools/cmake/riscv.toolchain.cmake" \ + -DIREE_HOST_BIN_DIR="${HOST_BUILD}/bin" \ + -DRISCV_TOOLCHAIN_ROOT="${RISCV_TOOLCHAIN}" \ + -DIREE_BUILD_COMPILER=OFF \ + -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_DEFAULTS=OFF \ + -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ + -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ + -DIREE_BUILD_PYTHON_BINDINGS=OFF \ + -DBENCHMARK_ENABLE_TESTING=OFF \ + -DCMAKE_INSTALL_PREFIX="${RISCV_BUILD}/install" + + echo " Building ($(nproc) jobs)..." 
+ "${NINJA}" -C "${RISCV_BUILD}" -j"$(nproc)" iree-run-module + "${NINJA}" -C "${RISCV_BUILD}" install/fast + + ok "iree-run-module (riscv64): ${RISCV_BUILD}/install/bin/iree-run-module" +} + +# ───────────────────────────────────────────────────────────────────────────── +# Main +# ───────────────────────────────────────────────────────────────────────────── +should_run 1 && step1_toolchain +should_run 2 && step2_sysroot +should_run 3 && step3_qemu +should_run 4 && step4_iree_host +should_run 5 && step5_iree_riscv + +echo "" +echo "════════════════════════════════════════════════════════════" +echo " Setup complete. Run the end-to-end workflows:" +echo "" +echo " bash build_tools/riscv/rvv_qemu_workflow_static.sh" +echo " bash build_tools/riscv/rvv_qemu_workflow_dynamic.sh" +echo "════════════════════════════════════════════════════════════" diff --git a/runtime/src/iree/builtins/mips/README.md b/runtime/src/iree/builtins/mips/README.md new file mode 100644 index 000000000000..fff6c4004a57 --- /dev/null +++ b/runtime/src/iree/builtins/mips/README.md @@ -0,0 +1,138 @@ +# MIPS Kernel Library + +This directory contains the hand-tuned kernel implementations for the +`mips` IREE dialect — a semantic dispatch layer that maps high-level tensor +operations to target-specific, optimized C kernels (currently RVV-vectorized +RISC-V matmul). 
+ +## Source Files + +| File | Purpose | +|------|---------| +| `matmul_kernel.h` | Public API declaration for `my_matmul_kernel` | +| `matmul_kernel.c` | RVV-vectorized (or scalar fallback) compute kernel; **no IREE headers** | +| `matmul_plugin.c` | IREE HAL executable plugin interface (wraps `matmul_kernel.c`) | +| `rvv_standalone_test.c` | Standalone QEMU smoke-test (no IREE dependency) | + +### Design Principle + +`matmul_kernel.c` is intentionally free of IREE headers, making it usable +for three different build targets without modification: + +``` +matmul_kernel.c ──┬── (.o) baked into dispatch ELF at compile time (static) + ├── (.so) IREE plugin loaded via --executable_plugin (dynamic) + └── linked with rvv_standalone_test.c (QEMU unit test) +``` + +## Kernel ABI + +Matches the `func.call` emitted by `MIPSBufferizableOpInterface` after +decomposing 2-D memrefs with `memref.extract_strided_metadata` +(`{llvm.bareptr = true}`, so `memref` → `float*`): + +```c +void my_matmul_kernel( + const float *A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] + const float *B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] + float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] + int64_t M, int64_t N, int64_t K); +``` + +Each 2-D matrix is passed as `(base_ptr, offset, row_stride, col_stride)`, +supporting arbitrary memory layouts (row-major, column-major, non-contiguous). + +## Building + +Assumes `IREE_SRC` = path to this repo, `CLANG` = `~/miniforge3/bin/clang`. 
+ +### Standalone test binary (no IREE, no libc) + +```bash +# RISC-V RVV (run under QEMU) +clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -static -nostdlib -ffreestanding -nostdinc \ + -isystem ~/miniforge3/lib/clang/18/include \ + matmul_kernel.c rvv_standalone_test.c -o rvv_test +qemu-riscv64 -cpu rv64,v=true,vlen=512,elen=64,vext_spec=v1.0 ./rvv_test + +# x86 host (scalar fallback, with libc) +clang matmul_kernel.c rvv_standalone_test.c -O2 -o rvv_test_host && ./rvv_test_host +``` + +Expected output: +``` +=== rvv_matmul standalone test [RVV] +[1] A * I = A (4x4 row-major) +[2] 2x3 * 3x2 = 2x2 +[3] col-major strides +[4] non-zero base offset +PASSED (20 passed, 0 failed) +``` + +### Static object (.o) — baked into dispatch ELF + +```bash +clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -c -nostdinc -isystem ~/miniforge3/lib/clang/18/include \ + -I "${IREE_SRC}/runtime/src" \ + matmul_kernel.c -o matmul_kernel_riscv.o +``` + +### Dynamic plugin (.so) — loaded at runtime + +```bash +clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ + -O2 -fPIC -shared -nostdinc -nostdlib \ + -isystem ~/miniforge3/lib/clang/18/include \ + -fuse-ld=~/miniforge3/bin/ld.lld \ + -I "${IREE_SRC}/runtime/src" \ + matmul_kernel.c matmul_plugin.c -o librvv_matmul.so +``` + +## Integration with IREE + +``` +torch.aten.mm + ─[ConvertTorchToMIPSPass]──► mips.matmul (flow IR, tensor domain) + ─[One-Shot Bufferize]──────► func.call @my_matmul_kernel (buffers decomposed) + ─[iree-compile LLVMCPU]────► dispatch ELF inside .vmfb +``` + +`MIPSBufferizableOpInterface` handles bufferization by decomposing each 2-D +memref into `(base_ptr, offset, stride0, stride1)` via +`memref.extract_strided_metadata` and emitting the `func.call` directly. +No memref form of `mips.matmul` is ever produced in the IR. + +### Static Embedding (`--iree-mips-static-embedding`) + +Pass `--iree-mips-static-embedding` to `iree-compile`. 
The bufferizer tags +`my_matmul_kernel` with `{hal.import.static}`, causing the LLVMCPU backend +to emit a direct linker-resolved call. A custom `lld_wrapper.sh` appends +`matmul_kernel_riscv.o` to every dispatch link at compile time. + +- No `--executable_plugin` at runtime — kernel is inside the `.vmfb`. +- Requires `-Bsymbolic` in the lld invocation. Without it, lld generates + `R_RISCV_JUMP_SLOT` in `.rela.plt`; IREE's embedded ELF loader ignores + `.rela.plt` (only processes `.rela.dyn`), causing a segfault on first call. + +### Dynamic Loading (`--executable_plugin`) + +Without the flag, `my_matmul_kernel` is a HAL import table entry resolved at +runtime from the plugin `.so` via `iree_hal_executable_plugin_query`. + +```bash +iree-run-module --module=matmul.vmfb \ + --executable_plugin=librvv_matmul.so \ + --function=matmul_4x4 ... +``` + +## End-to-End Workflow Scripts + +See [`build_tools/riscv/`](../../../../../build_tools/riscv/) in the repo root: + +| Script | Description | +|--------|-------------| +| `setup_qemu_workflow.sh` | One-time setup: toolchain, QEMU, IREE host + RISC-V builds | +| `rvv_qemu_workflow_static.sh` | Static-embedding pipeline + QEMU run | +| `rvv_qemu_workflow_dynamic.sh` | Dynamic-plugin pipeline + QEMU run | From 62c7bc8876c2cf73983d8ea1784bd1cca52090eb Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 9 Mar 2026 14:45:41 -0700 Subject: [PATCH 10/12] [MIPS] Add install step to IREE host build and use installed paths - setup_qemu_workflow.sh step 4: add CMAKE_INSTALL_PREFIX, build iree-tblgen alongside the other tools, and run install/fast so all binaries land in iree-build/install/bin/. - step 5: point IREE_HOST_BIN_DIR to the install tree (iree-build/install/bin) instead of the raw build dir. - rvv_qemu_workflow_{static,dynamic}.sh: align IREE_OPT, IREE_COMPILE, and HOST_RUN to iree-build/install/bin/ accordingly. 
Co-Authored-By: Claude Sonnet 4.6 --- .../riscv/rvv_qemu_workflow_dynamic.sh | 7 ++++--- build_tools/riscv/rvv_qemu_workflow_static.sh | 7 ++++--- build_tools/riscv/setup_qemu_workflow.sh | 20 ++++++++++++------- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/build_tools/riscv/rvv_qemu_workflow_dynamic.sh b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh index 580617fe5bbf..d0b8dde1743b 100755 --- a/build_tools/riscv/rvv_qemu_workflow_dynamic.sh +++ b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh @@ -27,12 +27,13 @@ set -euo pipefail IREE_SRC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" WORK_DIR="${HOME}/MLIR_Work/mips" HOST_BUILD="${WORK_DIR}/iree-build" +HOST_INSTALL="${HOST_BUILD}/install" RISCV_BUILD="${WORK_DIR}/iree-build-riscv" OUT_DIR="${WORK_DIR}/out/dynamic" -IREE_OPT="${HOST_BUILD}/tools/iree-opt" -IREE_COMPILE="${HOST_BUILD}/tools/iree-compile" -HOST_RUN="${HOST_BUILD}/tools/iree-run-module" +IREE_OPT="${HOST_INSTALL}/bin/iree-opt" +IREE_COMPILE="${HOST_INSTALL}/bin/iree-compile" +HOST_RUN="${HOST_INSTALL}/bin/iree-run-module" RISCV_RUN="${RISCV_BUILD}/install/bin/iree-run-module" QEMU="${HOME}/local/bin/qemu-riscv64" SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" diff --git a/build_tools/riscv/rvv_qemu_workflow_static.sh b/build_tools/riscv/rvv_qemu_workflow_static.sh index 19bcb755a002..d8faa2851418 100644 --- a/build_tools/riscv/rvv_qemu_workflow_static.sh +++ b/build_tools/riscv/rvv_qemu_workflow_static.sh @@ -28,12 +28,13 @@ set -euo pipefail IREE_SRC="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" WORK_DIR="${HOME}/MLIR_Work/mips" HOST_BUILD="${WORK_DIR}/iree-build" +HOST_INSTALL="${HOST_BUILD}/install" RISCV_BUILD="${WORK_DIR}/iree-build-riscv" OUT_DIR="${WORK_DIR}/out/static" -IREE_OPT="${HOST_BUILD}/tools/iree-opt" -IREE_COMPILE="${HOST_BUILD}/tools/iree-compile" -HOST_RUN="${HOST_BUILD}/tools/iree-run-module" +IREE_OPT="${HOST_INSTALL}/bin/iree-opt" +IREE_COMPILE="${HOST_INSTALL}/bin/iree-compile" +HOST_RUN="${HOST_INSTALL}/bin/iree-run-module" RISCV_RUN="${RISCV_BUILD}/install/bin/iree-run-module" QEMU="${HOME}/local/bin/qemu-riscv64" SYSROOT="${HOME}/riscv/toolchain/clang/linux/RISCV/sysroot" diff --git a/build_tools/riscv/setup_qemu_workflow.sh b/build_tools/riscv/setup_qemu_workflow.sh index 6f4b985ca17f..792cf1353f87 100644 --- a/build_tools/riscv/setup_qemu_workflow.sh +++ b/build_tools/riscv/setup_qemu_workflow.sh @@ -25,6 +25,7 @@ WORK_DIR="${HOME}/MLIR_Work/mips" IREE_SRC="${WORK_DIR}/iree" HOST_BUILD="${WORK_DIR}/iree-build" # iree-opt, iree-compile (x86) +HOST_INSTALL="${HOST_BUILD}/install" # installed host tools RISCV_BUILD="${WORK_DIR}/iree-build-riscv" # iree-run-module (RISC-V) QEMU_VER="8.2.2" INSTALL_PREFIX="${HOME}/local" # qemu-riscv64 installed here @@ -199,6 +200,7 @@ step4_iree_host() { -DCMAKE_ASM_COMPILER="${CLANG}" \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_INSTALL_PREFIX="${HOST_INSTALL}" \ -DIREE_ENABLE_ASSERTIONS=ON \ -DIREE_ENABLE_SPLIT_DWARF=ON \ -DIREE_ENABLE_LLD=ON \ @@ -213,11 +215,15 @@ step4_iree_host() { -DHAVE_POSIX_REGEX=OFF echo " Building ($(nproc) jobs)..." - "${NINJA}" -C "${HOST_BUILD}" -j"$(nproc)" iree-opt iree-compile iree-run-module + "${NINJA}" -C "${HOST_BUILD}" -j"$(nproc)" iree-opt iree-compile iree-run-module iree-tblgen - ok "iree-opt: ${HOST_BUILD}/tools/iree-opt" - ok "iree-compile: ${HOST_BUILD}/tools/iree-compile" - ok "iree-run-module (host): ${HOST_BUILD}/tools/iree-run-module" + echo " Installing host tools to ${HOST_INSTALL}..." 
+ "${NINJA}" -C "${HOST_BUILD}" install/fast + + ok "iree-opt: ${HOST_INSTALL}/bin/iree-opt" + ok "iree-compile: ${HOST_INSTALL}/bin/iree-compile" + ok "iree-run-module: ${HOST_INSTALL}/bin/iree-run-module" + ok "iree-tblgen: ${HOST_INSTALL}/bin/iree-tblgen" } # ───────────────────────────────────────────────────────────────────────────── @@ -226,8 +232,8 @@ step4_iree_host() { step5_iree_riscv() { log "STEP 5: Build IREE (RISC-V cross — iree-run-module for riscv64)" - [[ -f "${HOST_BUILD}/bin/iree-tblgen" ]] || \ - die "Host tools not found at ${HOST_BUILD}/bin — run step 4 first." + [[ -f "${HOST_INSTALL}/bin/iree-tblgen" ]] || \ + die "Host install not found at ${HOST_INSTALL}/bin — run step 4 first." mkdir -p "${RISCV_BUILD}" @@ -236,7 +242,7 @@ step5_iree_riscv() { -DCMAKE_MAKE_PROGRAM="${NINJA}" \ -DCMAKE_BUILD_TYPE=RelWithDebInfo \ -DCMAKE_TOOLCHAIN_FILE="${IREE_SRC}/build_tools/cmake/riscv.toolchain.cmake" \ - -DIREE_HOST_BIN_DIR="${HOST_BUILD}/bin" \ + -DIREE_HOST_BIN_DIR="${HOST_INSTALL}/bin" \ -DRISCV_TOOLCHAIN_ROOT="${RISCV_TOOLCHAIN}" \ -DIREE_BUILD_COMPILER=OFF \ -DIREE_TARGET_BACKEND_DEFAULTS=OFF \ From 31153e271448db0195115ec0153c29280ba3bc49 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 9 Mar 2026 15:33:08 -0700 Subject: [PATCH 11/12] [MIPS] Add test MLIR and fix dynamic workflow message Co-Authored-By: Claude Sonnet 4.6 --- build_tools/riscv/mips_matmul_test.mlir | 101 ++++++++++++++++++ .../riscv/rvv_qemu_workflow_dynamic.sh | 6 +- build_tools/riscv/rvv_qemu_workflow_static.sh | 2 +- 3 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 build_tools/riscv/mips_matmul_test.mlir diff --git a/build_tools/riscv/mips_matmul_test.mlir b/build_tools/riscv/mips_matmul_test.mlir new file mode 100644 index 000000000000..f3e4d09c83ca --- /dev/null +++ b/build_tools/riscv/mips_matmul_test.mlir @@ -0,0 +1,101 @@ +// mips_matmul_test.mlir +// +// End-to-end test inputs for the MIPS matmul kernel pipeline. 
+// +// Each function exercises torch.aten.mm, which is intercepted by +// ConvertTorchToMIPSPass and rewritten as mips.matmul. The op is then +// eliminated during One-Shot Bufferize: MIPSBufferizableOpInterface +// decomposes the 2-D memrefs and emits a direct func.call to the +// hand-tuned C kernel: +// +// torch.aten.mm +// → mips.matmul (ConvertTorchToMIPSPass) +// → flow.dispatch(...) (IREE dispatch formation) +// → func.call @my_matmul_kernel (MIPSBufferizableOpInterface) +// → ELF inside .vmfb (iree-compile LLVMCPU backend) +// +// Usage: +// bash build_tools/riscv/rvv_qemu_workflow_static.sh -- static (.o baked into vmfb) +// bash build_tools/riscv/rvv_qemu_workflow_dynamic.sh -- dynamic (.so plugin at runtime) + +module { + // ── Test 1: 4×4 identity × data → passthrough ──────────────────────────── + // Verifies that A=I leaves B unchanged; a simple correctness smoke-test. + // + // A = identity(4×4), B = [[1..4],[5..8],[9..12],[13..16]] + // Expected: result = B + func.func @matmul_4x4( + %A : !torch.vtensor<[4,4],f32>, + %B : !torch.vtensor<[4,4],f32>) + -> !torch.vtensor<[4,4],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> + -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> + } + + // ── Test 2: 2×3 × 3×2 → 2×2 (non-square, reduced K dimension) ────────── + // Verifies M≠N≠K path through the kernel (inner loop trip-count < vlen). + // + // A = [[1,2,3],[4,5,6]], B = [[1,0],[0,1],[1,0]] + // Expected: [[1+0+3, 0+2+0],[4+0+6, 0+5+0]] = [[4,2],[10,5]] + func.func @matmul_2x3x2( + %A : !torch.vtensor<[2,3],f32>, + %B : !torch.vtensor<[3,2],f32>) + -> !torch.vtensor<[2,2],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,2],f32> + -> !torch.vtensor<[2,2],f32> + return %0 : !torch.vtensor<[2,2],f32> + } + + // ── Test 3: 8×8 × 8×8 → 8×8 (exercises multi-vector-register tiling) ─── + // With vlen=512 and LMUL=m4, N=8 fits in a single VL group. 
This test + // stresses the vectorized inner loop and accumulation across K=8 steps. + // + // A = upper-triangular ones (row i has ones in columns 0..i). + // B = identity(8×8). + // Expected: A*I = A — result is upper-triangular ones. + // + // A row layout (8×8): + // row 0: [1,0,0,0,0,0,0,0] + // row 1: [1,1,0,0,0,0,0,0] + // row 2: [1,1,1,0,0,0,0,0] + // ... + // row 7: [1,1,1,1,1,1,1,1] + func.func @matmul_8x8( + %A : !torch.vtensor<[8,8],f32>, + %B : !torch.vtensor<[8,8],f32>) + -> !torch.vtensor<[8,8],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[8,8],f32>, !torch.vtensor<[8,8],f32> + -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Expected outputs (iree-run-module) +// ───────────────────────────────────────────────────────────────────────────── +// +// matmul_4x4 A=identity(4x4), B=[1..16 row-major]: +// result[0]: 4x4xf32=[1 2 3 4][5 6 7 8][9 10 11 12][13 14 15 16] +// +// matmul_2x3x2 A=[[1,2,3],[4,5,6]], B=[[1,0],[0,1],[1,0]]: +// result[0]: 2x2xf32=[4 2][10 5] +// +// matmul_8x8 A=upper-triangular-ones(8x8), B=identity(8x8): +// result[0]: 8x8xf32= +// [1 0 0 0 0 0 0 0] +// [1 1 0 0 0 0 0 0] +// [1 1 1 0 0 0 0 0] +// [1 1 1 1 0 0 0 0] +// [1 1 1 1 1 0 0 0] +// [1 1 1 1 1 1 0 0] +// [1 1 1 1 1 1 1 0] +// [1 1 1 1 1 1 1 1] +// +// iree-run-module invocation for matmul_8x8: +// --function=matmul_8x8 +// "--input=8x8xf32=1,0,0,0,0,0,0,0, 1,1,0,0,0,0,0,0, 1,1,1,0,0,0,0,0, 1,1,1,1,0,0,0,0, 1,1,1,1,1,0,0,0, 1,1,1,1,1,1,0,0, 1,1,1,1,1,1,1,0, 1,1,1,1,1,1,1,1" +// "--input=8x8xf32=1,0,0,0,0,0,0,0, 0,1,0,0,0,0,0,0, 0,0,1,0,0,0,0,0, 0,0,0,1,0,0,0,0, 0,0,0,0,1,0,0,0, 0,0,0,0,0,1,0,0, 0,0,0,0,0,0,1,0, 0,0,0,0,0,0,0,1" diff --git a/build_tools/riscv/rvv_qemu_workflow_dynamic.sh b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh index d0b8dde1743b..0d906813071c 100755 --- a/build_tools/riscv/rvv_qemu_workflow_dynamic.sh +++ 
b/build_tools/riscv/rvv_qemu_workflow_dynamic.sh @@ -44,7 +44,7 @@ CLANG_INC="${HOME}/miniforge3/lib/clang/18/include" KERNEL_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_kernel.c" PLUGIN_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_plugin.c" -TEST_MLIR="${WORK_DIR}/mips_matmul_test.mlir" +TEST_MLIR="${IREE_SRC}/build_tools/riscv/mips_matmul_test.mlir" # Rocky 8's libstdc++ is too old; conda has GLIBCXX 3.4.29+. export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" @@ -154,9 +154,7 @@ data = open(sys.argv[1], 'rb').read() idx = data.find(b'my_matmul_kernel') rvv = sum(1 for i in range(0, len(data)-3, 4) if data[i] & 0x7f == 0x57) if idx != -1: - tag = "[ok]" if rvv == 0 else "[note]" - print(f" {tag} 'my_matmul_kernel' at offset {idx} (import entry, no RVV here)") - print(f" RVV opcode count in vmfb: {rvv} (expected 0 — kernel is in .so)") + print(f" [ok] 'my_matmul_kernel' at offset {idx} (import table entry — kernel lives in .so)") else: print(" [warn] 'my_matmul_kernel' not found in dispatch ELF") PYEOF diff --git a/build_tools/riscv/rvv_qemu_workflow_static.sh b/build_tools/riscv/rvv_qemu_workflow_static.sh index d8faa2851418..c40d531b450b 100644 --- a/build_tools/riscv/rvv_qemu_workflow_static.sh +++ b/build_tools/riscv/rvv_qemu_workflow_static.sh @@ -44,7 +44,7 @@ LLD="${HOME}/miniforge3/bin/ld.lld" CLANG_INC="${HOME}/miniforge3/lib/clang/18/include" KERNEL_SRC="${IREE_SRC}/runtime/src/iree/builtins/mips/matmul_kernel.c" -TEST_MLIR="${WORK_DIR}/mips_matmul_test.mlir" +TEST_MLIR="${IREE_SRC}/build_tools/riscv/mips_matmul_test.mlir" # Rocky 8's libstdc++ is too old; conda has GLIBCXX 3.4.29+. 
export LD_LIBRARY_PATH="${HOME}/miniforge3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" From c8a34bcf387cbfb049767b3da5b67150c94028ee Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Thu, 19 Mar 2026 11:46:10 -0700 Subject: [PATCH 12/12] =?UTF-8?q?[MIPS]=20Add=20INT8=20support=20for=20mip?= =?UTF-8?q?s.matmul=20(i8=C3=97i8=E2=86=92i32)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extend ConvertTorchToMIPS to lower aten._int_mm to mips.matmul with signless i8/i32 types - Update MIPSOps verifier to allow i8×i8→i32 widening in addition to the existing f32×f32→f32 combination - Add matmul_kernel_i8.c: RVV widening kernel using vwmacc.vx (i8m2→i16m4→i32m8) with scalar fallback for non-RISC-V builds - Update CMakeLists.txt and matmul_kernel.h / matmul_plugin.c to include and expose the new INT8 kernel entry point --- .../InputConversion/ConvertTorchToMIPS.cpp | 174 ++++++++++++++++-- .../MIPS/IR/MIPSBufferizableOpInterface.cpp | 76 +++++--- .../iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp | 23 ++- runtime/src/iree/builtins/mips/CMakeLists.txt | 8 +- .../src/iree/builtins/mips/matmul_kernel.h | 18 +- .../src/iree/builtins/mips/matmul_kernel_i8.c | 146 +++++++++++++++ .../src/iree/builtins/mips/matmul_plugin.c | 40 +++- 7 files changed, 428 insertions(+), 57 deletions(-) create mode 100644 runtime/src/iree/builtins/mips/matmul_kernel_i8.c diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp index 16273f810684..3c377e93e668 100644 --- a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp @@ -4,17 +4,25 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Converts torch.aten.mm → mips.matmul. +// Converts torch matmul ops → mips.matmul. 
// -// The pattern runs inside the Torch input-conversion pipeline, BEFORE -// createConvertTorchToLinalgPass(), so it intercepts aten.mm first. +// Patterns handled: +// ConvertAtenMmToMIPSMatmul — torch.aten.mm (f32 or i8 inputs) +// ConvertAtenIntMmToMIPSMatmul — torch.aten._int_mm (i8 × i8 → i32) // -// Since torch ops carry ValueTensorType (torch's tensor type), the pattern: +// Both patterns run inside the Torch input-conversion pipeline, BEFORE +// createConvertTorchToLinalgPass(), so they intercept the ops first. +// +// Since torch ops carry ValueTensorType (torch's tensor type), each pattern: // 1. Casts operands to builtin RankedTensorType via ToBuiltinTensorOp. // 2. Creates a zero-initialised init tensor (Destination Passing Style). // 3. Emits mips.matmul on builtin tensors. // 4. Casts the result back to ValueTensorType via FromBuiltinTensorOp. // +// The mips.matmul op is eliminated during One-Shot Bufferize: +// MIPSBufferizableOpInterface detects the LHS element type and calls either +// my_matmul_kernel (f32) or my_matmul_kernel_i8 (i8→i32). +// #include "compiler/plugins/input/Torch/InputConversion/Passes.h" #include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" @@ -34,6 +42,28 @@ namespace mlir::iree_compiler::TorchInput { namespace { +//===----------------------------------------------------------------------===// +// Helper: normalize signed/unsigned integer types to signless. +// +// arith.constant (and most MLIR arithmetic ops) require signless integers. +// Torch's dtype mapping produces signed types (e.g. si8, si32) which must be +// converted to their signless equivalents (i8, i32) before entering arith/ +// linalg/tensor dialects. 
+//===----------------------------------------------------------------------===// + +static Type toSignlessElemType(MLIRContext *ctx, Type ty) { + if (auto intTy = dyn_cast(ty)) + if (!intTy.isSignless()) + return IntegerType::get(ctx, intTy.getWidth()); + return ty; +} + +static RankedTensorType toSignlessTensorType(RankedTensorType ty) { + Type elem = toSignlessElemType(ty.getContext(), ty.getElementType()); + if (elem == ty.getElementType()) return ty; + return RankedTensorType::get(ty.getShape(), elem, ty.getEncoding()); +} + //===----------------------------------------------------------------------===// // Helper: create a zero-filled tensor of a given shape and element type. // Accepts (M, N) as dynamic Value dimensions. @@ -41,14 +71,29 @@ namespace { static Value createZeroTensor(PatternRewriter &rewriter, Location loc, RankedTensorType ty, ValueRange dynSizes) { - Value empty = tensor::EmptyOp::create(rewriter, loc, ty, dynSizes); - Attribute zeroAttr = rewriter.getZeroAttr(ty.getElementType()); + // Use signless element types — arith.constant rejects signed integers. + RankedTensorType signlessTy = toSignlessTensorType(ty); + Value empty = tensor::EmptyOp::create(rewriter, loc, signlessTy, dynSizes); + Attribute zeroAttr = rewriter.getZeroAttr(signlessTy.getElementType()); Value zero = arith::ConstantOp::create(rewriter, loc, cast(zeroAttr)); return linalg::FillOp::create(rewriter, loc, zero, empty).result(); } +//===----------------------------------------------------------------------===// +// Helper: check whether a torch dtype is one we handle in mips.matmul. +// Returns true for f32 and si8/i8 inputs. 
+//===----------------------------------------------------------------------===// + +static bool isSupportedMmDtype(Type torchDtype) { + return torchDtype.isF32() || torchDtype.isSignedInteger(8) || + torchDtype.isInteger(8); +} + //===----------------------------------------------------------------------===// // Pattern: torch.aten.mm → mips.matmul +// +// Handles f32 × f32 → f32 (original path) +// and i8 × i8 → i8 (rare, but supported via same mips.matmul) //===----------------------------------------------------------------------===// struct ConvertAtenMmToMIPSMatmul @@ -72,24 +117,28 @@ struct ConvertAtenMmToMIPSMatmul if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy) return rewriter.notifyMatchFailure(op, "expected ValueTensorType"); - // Only handle f32 for now (extensible). - if (!lhsTorchTy.getDtype().isF32()) - return rewriter.notifyMatchFailure(op, "only f32 supported"); + if (!isSupportedMmDtype(lhsTorchTy.getDtype())) + return rewriter.notifyMatchFailure(op, "unsupported dtype (f32 or i8 only)"); // ---------------------------------------------------------------- // 2. Cast operands from torch ValueTensorType → builtin RankedTensorType. // ---------------------------------------------------------------- - auto lhsBuiltinTy = - dyn_cast_or_null(lhsTorchTy.toBuiltinTensor()); - auto rhsBuiltinTy = - dyn_cast_or_null(rhsTorchTy.toBuiltinTensor()); - auto resultBuiltinTy = - dyn_cast_or_null(resultTorchTy.toBuiltinTensor()); + auto lhsBuiltinTy = dyn_cast_or_null( + lhsTorchTy.toBuiltinTensor()); + auto rhsBuiltinTy = dyn_cast_or_null( + rhsTorchTy.toBuiltinTensor()); + auto resultBuiltinTy = dyn_cast_or_null( + resultTorchTy.toBuiltinTensor()); if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy || lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors"); + // Normalize signed integer element types to signless (arith requires it). 
+ lhsBuiltinTy = toSignlessTensorType(lhsBuiltinTy); + rhsBuiltinTy = toSignlessTensorType(rhsBuiltinTy); + resultBuiltinTy = toSignlessTensorType(resultBuiltinTy); + Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create( rewriter, loc, lhsBuiltinTy, op.getSelf()); Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create( @@ -129,6 +178,100 @@ struct ConvertAtenMmToMIPSMatmul } }; +//===----------------------------------------------------------------------===// +// Pattern: torch.aten._int_mm → mips.matmul +// +// torch.aten._int_mm: i8 × i8 → i32 (integer matrix multiply). +// This is the primary op produced by INT8 quantization pipelines (e.g. +// torch.ao.quantization, torchao). +// +// mips.matmul carries i8 LHS/RHS and i32 output; MIPSBufferizableOpInterface +// detects the i8 LHS element type and emits func.call @my_matmul_kernel_i8. +//===----------------------------------------------------------------------===// + +struct ConvertAtenIntMmToMIPSMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(torch::Torch::Aten_IntMmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // ---------------------------------------------------------------- + // 1. Verify operand and result types. + // ---------------------------------------------------------------- + auto lhsTorchTy = + dyn_cast(op.getSelf().getType()); + auto rhsTorchTy = + dyn_cast(op.getMat2().getType()); + auto resultTorchTy = + dyn_cast(op.getType()); + + if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy) + return rewriter.notifyMatchFailure(op, "expected ValueTensorType"); + + // _int_mm expects i8 inputs and i32 output. + if (!lhsTorchTy.getDtype().isSignedInteger(8) && + !lhsTorchTy.getDtype().isInteger(8)) + return rewriter.notifyMatchFailure(op, "expected i8 lhs"); + + // ---------------------------------------------------------------- + // 2. 
Cast to builtin tensor types (signless integers — arith requires it). + // ---------------------------------------------------------------- + auto lhsBuiltinTy = dyn_cast_or_null( + lhsTorchTy.toBuiltinTensor()); + auto rhsBuiltinTy = dyn_cast_or_null( + rhsTorchTy.toBuiltinTensor()); + auto resultBuiltinTy = dyn_cast_or_null( + resultTorchTy.toBuiltinTensor()); + + if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy || + lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors"); + + lhsBuiltinTy = toSignlessTensorType(lhsBuiltinTy); + rhsBuiltinTy = toSignlessTensorType(rhsBuiltinTy); + resultBuiltinTy = toSignlessTensorType(resultBuiltinTy); + + Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, lhsBuiltinTy, op.getSelf()); + Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, rhsBuiltinTy, op.getMat2()); + + // ---------------------------------------------------------------- + // 3. Dynamic dims for the i32 result tensor (M from lhs, N from rhs). + // ---------------------------------------------------------------- + SmallVector dynSizes; + if (resultBuiltinTy.isDynamicDim(0)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, lhs, 0)); + if (resultBuiltinTy.isDynamicDim(1)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, rhs, 1)); + + // ---------------------------------------------------------------- + // 4. Zero-initialised i32 init tensor. + // ---------------------------------------------------------------- + Value init = createZeroTensor(rewriter, loc, resultBuiltinTy, dynSizes); + + // ---------------------------------------------------------------- + // 5. Emit mips.matmul — LHS is i8, result is i32. + // MIPSBufferizableOpInterface dispatches to my_matmul_kernel_i8. 
+ // ---------------------------------------------------------------- + Value result = + IREE::MIPS::MatmulOp::create(rewriter, loc, TypeRange{resultBuiltinTy}, + lhs, rhs, init) + .getResult(); + + // ---------------------------------------------------------------- + // 6. Cast result back to torch ValueTensorType. + // ---------------------------------------------------------------- + Value torchResult = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, resultTorchTy, result); + + rewriter.replaceOp(op, torchResult); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// @@ -145,6 +288,7 @@ struct ConvertTorchToMIPSPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); + patterns.add(context); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp index b28f4ea3ce01..551192ab492b 100644 --- a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp @@ -10,10 +10,15 @@ // eliminated *entirely* during One-Shot Bufferize: bufferize() obtains memref // buffers for all three operands, decomposes each 2-D memref into // (base_ptr, offset, stride0, stride1) via memref.extract_strided_metadata, -// and emits a func.call @my_matmul_kernel directly. No memref form of -// mips.matmul ever exists in the IR. +// and emits a func.call to the appropriate kernel directly. 
// -// Before bufferization: +// The kernel is selected based on the LHS element type: +// f32 → func.call @my_matmul_kernel (f32 × f32 → f32) +// i8 → func.call @my_matmul_kernel_i8 (i8 × i8 → i32) +// +// No memref form of mips.matmul ever exists in the IR. +// +// Before bufferization (f32 example): // %C = mips.matmul %A, %B, %init // : tensor, tensor, tensor -> tensor // @@ -29,6 +34,9 @@ // %C_base, %C_off, %C_s0, %C_s1, // %M, %N, %K) // -- tensor result replaced by %C_buf via replaceOpWithBufferizedValues -- +// +// The INT8 path is identical but uses memref / memref base pointers +// and calls @my_matmul_kernel_i8 instead. #include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" #include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" @@ -45,15 +53,15 @@ using namespace mlir; using namespace mlir::bufferization; -// When true, my_matmul_kernel is emitted as a direct linker-resolved call -// (hal.import.static) instead of a dynamic HAL import table entry. +// When true, matmul kernels are emitted as direct linker-resolved calls +// (hal.import.static) instead of dynamic HAL import table entries. // Pass --iree-mips-static-embedding to iree-compile to enable. // Mutually exclusive with --executable_plugin at runtime. static llvm::cl::opt clMIPSStaticEmbedding( "iree-mips-static-embedding", llvm::cl::desc( - "Emit my_matmul_kernel as a direct linker-resolved call " - "(hal.import.static) instead of a dynamic HAL import. " + "Emit mips matmul kernels as direct linker-resolved calls " + "(hal.import.static) instead of dynamic HAL imports. " "Requires the kernel .o to be appended by lld_wrapper at compile " "time. 
Mutually exclusive with --executable_plugin at runtime."), llvm::cl::init(false)); @@ -61,32 +69,32 @@ static llvm::cl::opt clMIPSStaticEmbedding( namespace mlir::iree_compiler::IREE::MIPS { namespace { -static constexpr StringLiteral kKernelName = "my_matmul_kernel"; +static constexpr StringLiteral kKernelF32 = "my_matmul_kernel"; +static constexpr StringLiteral kKernelI8 = "my_matmul_kernel_i8"; //===----------------------------------------------------------------------===// -// Helper: ensure func.func private @my_matmul_kernel exists at module scope. +// Helper: ensure func.func private @ exists at module scope. // // The declaration carries {llvm.bareptr = true} so the LLVM backend passes -// bare float* arguments instead of MLIR memref descriptor structs, matching +// bare pointer arguments instead of MLIR memref descriptor structs, matching // the C kernel ABI. //===----------------------------------------------------------------------===// static func::FuncOp ensureKernelDeclaration(RewriterBase &rewriter, Operation *moduleOp, + StringRef kernelName, FunctionType fnType, Location loc) { if (auto existing = dyn_cast_if_present( - SymbolTable::lookupSymbolIn(moduleOp, kKernelName))) + SymbolTable::lookupSymbolIn(moduleOp, kernelName))) return existing; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front()); - auto fnDecl = func::FuncOp::create(rewriter, loc, kKernelName, fnType); + auto fnDecl = func::FuncOp::create(rewriter, loc, kernelName, fnType); SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private); fnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true)); - // If --iree-mips-static-embedding was passed to iree-compile, emit a direct - // linker call instead of a dynamic HAL import table entry. - // Without this flag the call goes through the HAL import table, which lets - // the runtime resolve it from an --executable_plugin .so at run time. 
+ // If --iree-mips-static-embedding was passed, emit a direct linker call + // instead of a dynamic HAL import table entry. if (clMIPSStaticEmbedding) fnDecl->setAttr("hal.import.static", rewriter.getUnitAttr()); return fnDecl; @@ -96,19 +104,19 @@ static func::FuncOp ensureKernelDeclaration(RewriterBase &rewriter, // Helper: decompose a 2-D memref into (base_ptr, offset, stride0, stride1). // // Uses memref.extract_strided_metadata. The base_ptr is always a rank-0 -// memref with DEFAULT address space (memref), regardless of the source -// memref's address space. Any IREE-specific memory space (e.g. +// memref with DEFAULT address space (memref), regardless of the +// source memref's address space. Any IREE-specific memory space (e.g. // #hal.descriptor_type) is stripped via // memref.memory_space_cast so that: // -// 1. The function declaration uses plain memref, which is stable across -// all pipeline stages. +// 1. The function declaration uses plain memref, which is stable +// across all pipeline stages. // 2. eraseHALDescriptorTypeFromMemRefPass (which runs after bufferization and // does NOT update external function declarations) cannot introduce a // type mismatch between the call operands and the declaration. // // Combined with the {llvm.bareptr = true} attribute on the callee, the -// rank-0 memref lowers to a bare float* matching the C ABI. +// rank-0 memref lowers to a bare pointer matching the C ABI. //===----------------------------------------------------------------------===// static void decomposeMemref2D(RewriterBase &rewriter, Location loc, @@ -149,7 +157,8 @@ static void decomposeMemref2D(RewriterBase &rewriter, Location loc, // Inherits from DstBufferizableOpInterfaceExternalModel which automatically // handles the DPS aliasing (init ↔ result) and write detection for the init // operand. We override bufferizesToMemoryRead to mark lhs and rhs as read, -// and provide a custom bufferize() that emits func.call @my_matmul_kernel. 
+// and provide a custom bufferize() that selects the right kernel based on the +// LHS element type. //===----------------------------------------------------------------------===// struct MIPSMatmulBufferizableOpInterface @@ -182,14 +191,27 @@ struct MIPSMatmulBufferizableOpInterface if (failed(rhsBuf)) return failure(); // init aliases with result — one-shot bufferize allocates the output buffer - // (via bufferization.alloc_tensor or in-place analysis) and gives it to us - // here as initBuf. + // (via bufferization.alloc_tensor or in-place analysis) and gives it here. FailureOr initBuf = getBuffer(rewriter, matmulOp.getInit(), options, state); if (failed(initBuf)) return failure(); - // Build the flattened argument list for func.call @my_matmul_kernel. + // Select the kernel based on LHS element type. + Type lhsElemTy = cast(lhsBuf->getType()).getElementType(); + StringRef kernelName; + if (lhsElemTy.isF32()) { + kernelName = kKernelF32; + } else if (lhsElemTy.isInteger(8)) { + kernelName = kKernelI8; + } else { + return matmulOp.emitOpError( + "MIPSBufferizableOpInterface: unsupported LHS element type '") + << lhsElemTy + << "'; supported types are f32 and i8"; + } + + // Build the flattened argument list for the kernel call. // For each 2-D memref: (base_ptr, offset, stride0, stride1) // Then: M, N, K as index scalars. SmallVector callOperands; @@ -209,10 +231,10 @@ struct MIPSMatmulBufferizableOpInterface // Declare the kernel function in the enclosing module (idempotent). Operation *moduleOp = SymbolTable::getNearestSymbolTable(matmulOp); FunctionType fnType = rewriter.getFunctionType(callArgTypes, TypeRange{}); - ensureKernelDeclaration(rewriter, moduleOp, fnType, loc); + ensureKernelDeclaration(rewriter, moduleOp, kernelName, fnType, loc); // Emit the call — the kernel writes into *initBuf in place. 
- func::CallOp::create(rewriter, loc, kKernelName, TypeRange{}, callOperands); + func::CallOp::create(rewriter, loc, kernelName, TypeRange{}, callOperands); // Replace the tensor result with the init buffer (DPS aliasing). replaceOpWithBufferizedValues(rewriter, op, *initBuf); diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp index c60c0b1d6c6f..a79f629f959e 100644 --- a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp @@ -79,11 +79,24 @@ LogicalResult MatmulOp::verify() { if (!compat(shape(getRhs())[1], shape(getInit())[1])) return emitOpError("rhs dim 1 (N) must match init dim 1 (N)"); - // All element types must match. - if (elemTy(getLhs()) != elemTy(getRhs()) || elemTy(getLhs()) != elemTy(getInit())) - return emitOpError("element types of all operands must match"); - - // Result type must match init type (both tensor). + // LHS and RHS element types must match. + if (elemTy(getLhs()) != elemTy(getRhs())) + return emitOpError("lhs and rhs element types must match"); + + // Supported element type combinations: + // f32 × f32 → f32 (standard float matmul) + // i8 × i8 → i32 (INT8 widening matmul) + Type lhsElem = elemTy(getLhs()); + Type outElem = elemTy(getInit()); + bool valid = (lhsElem == outElem) || + (lhsElem.isInteger(8) && outElem.isInteger(32)); + if (!valid) + return emitOpError( + "unsupported element type combination: lhs=") + << lhsElem << ", output=" << outElem + << "; supported: f32×f32→f32, i8×i8→i32"; + + // Result type must match init type. 
if (getResult().getType() != getInit().getType()) return emitOpError("result type must match init type"); diff --git a/runtime/src/iree/builtins/mips/CMakeLists.txt b/runtime/src/iree/builtins/mips/CMakeLists.txt index 1dc3825b1975..ac1fbaea22b4 100644 --- a/runtime/src/iree/builtins/mips/CMakeLists.txt +++ b/runtime/src/iree/builtins/mips/CMakeLists.txt @@ -4,16 +4,18 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Shared library containing the MIPS custom matmul kernel. +# Shared library containing the MIPS custom matmul kernels. # -# matmul_kernel.c — RVV (or scalar fallback) compute kernel; no IREE headers. -# matmul_plugin.c — IREE HAL executable plugin interface. +# matmul_kernel.c — RVV f32 (or scalar fallback) compute kernel; no IREE headers. +# matmul_kernel_i8.c — RVV INT8 widening compute kernel; no IREE headers. +# matmul_plugin.c — IREE HAL executable plugin interface for both kernels. # # The plugin is loaded at runtime via: # iree-run-module --executable_plugin=libmips_matmul.so ... add_library(mips_matmul SHARED matmul_kernel.c + matmul_kernel_i8.c matmul_plugin.c ) diff --git a/runtime/src/iree/builtins/mips/matmul_kernel.h b/runtime/src/iree/builtins/mips/matmul_kernel.h index 6f46a5ced9fc..11cf87ff1726 100644 --- a/runtime/src/iree/builtins/mips/matmul_kernel.h +++ b/runtime/src/iree/builtins/mips/matmul_kernel.h @@ -11,13 +11,19 @@ // {llvm.bareptr = true} means memref → float* in C. // index → int64_t on RV64. 
// -// C signature (15 args total): +// C signatures (15 args each): // void my_matmul_kernel( // float* A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] // float* B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] // float* C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] // int64_t M, int64_t N, int64_t K // ); +// void my_matmul_kernel_i8( +// int8_t* A, int64_t A_off, int64_t A_s0, int64_t A_s1, // lhs [M×K] i8 +// int8_t* B, int64_t B_off, int64_t B_s0, int64_t B_s1, // rhs [K×N] i8 +// int32_t* C, int64_t C_off, int64_t C_s0, int64_t C_s1, // out [M×N] i32 +// int64_t M, int64_t N, int64_t K +// ); #ifndef IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_ #define IREE_BUILTINS_MIPS_MATMUL_KERNEL_H_ @@ -37,6 +43,16 @@ void my_matmul_kernel( float *C, int64_t C_off, int64_t C_s0, int64_t C_s1, int64_t M, int64_t N, int64_t K); +// 2-D INT8 matmul: C[M×N] (i32) = A[M×K] (i8) * B[K×N] (i8). +// C is zero-initialized by the caller before the call. +// Supports arbitrary row/col strides and non-zero base offsets. +// RVV-vectorized (widening i8→i16→i32) on RISC-V; scalar fallback elsewhere. +void my_matmul_kernel_i8( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K); + #ifdef __cplusplus } #endif diff --git a/runtime/src/iree/builtins/mips/matmul_kernel_i8.c b/runtime/src/iree/builtins/mips/matmul_kernel_i8.c new file mode 100644 index 000000000000..0ed667401dd7 --- /dev/null +++ b/runtime/src/iree/builtins/mips/matmul_kernel_i8.c @@ -0,0 +1,146 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// RVV (RISC-V Vector 1.0) INT8 matmul kernel: C[i32] = A[i8] * B[i8]. 
+// +// This file is intentionally free of IREE headers so it can be: +// - Compiled to a .o and baked into the dispatch ELF (static embedding). +// - Linked into the IREE plugin .so alongside matmul_plugin.c. +// - Compiled standalone for unit tests. +// +// Vectorization strategy (RVV widening multiply): +// Outer loops: m (rows of A) and k (contraction axis). +// Inner loop : n (columns of B), vectorized. +// +// LMUL selection to match widening chain: +// i8 LMUL=m2 → VLMAX = (VLEN/8) * 2 = VLEN/4 +// i16 LMUL=m4 → VLMAX = (VLEN/16) * 4 = VLEN/4 (after first widening) +// i32 LMUL=m8 → VLMAX = (VLEN/32) * 8 = VLEN/4 (accumulator) +// All three LMULs yield the same VLMAX, so the same application vl is valid +// for all three element widths — no secondary vsetvl is needed per intrinsic. +// +// Per n-strip: +// acc[vl] (i32m8) = 0 +// for k in 0..K: +// a_val (i8 scalar) = A[m, k] +// b_i8 (i8m2 vec) = B[k, n:n+vl] +// b_i16 (i16m4 vec) = sign_extend(b_i8) +// acc += widen_macc(a_val, b_i16) // i16 * i16 → i32 +// C[m, n:n+vl] = acc +// +// The widening chain avoids intermediate overflow: +// i8 × i8 → i16 intermediate → accumulated into i32. 
+
+#include "matmul_kernel.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __riscv_vector
+#include <riscv_vector.h>
+#endif
+
+//===----------------------------------------------------------------------===//
+// Internal compute kernel (RVV path)
+//===----------------------------------------------------------------------===//
+
+#ifdef __riscv_vector
+
+static void rvv_matmul_i8_core(
+    const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1,
+    const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1,
+    int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1,
+    int64_t M, int64_t N, int64_t K) {
+  A += A_off;
+  B += B_off;
+  C += C_off;
+
+  const int a_unit_col = (A_s1 == 1);
+  const int b_unit_col = (B_s1 == 1);
+  const int c_unit_col = (C_s1 == 1);
+
+  for (int64_t m = 0; m < M; ++m) {
+    int64_t n = 0;
+    while (n < N) {
+      // vl: element count for this n-strip.
+      // __riscv_vsetvl_e32m8 caps vl at VLMAX_i32m8 = VLEN/4, which equals
+      // VLMAX_i8m2 and VLMAX_i16m4 — all intrinsics below share this vl.
+      size_t vl = __riscv_vsetvl_e32m8((size_t)(N - n));
+
+      // Initialize i32 accumulator to zero.
+      vint32m8_t acc = __riscv_vmv_v_x_i32m8(0, vl);
+
+      for (int64_t k = 0; k < K; ++k) {
+        // Scalar load from A[m, k].
+        int8_t a_val = a_unit_col ? A[m * A_s0 + k]
+                                  : A[m * A_s0 + k * A_s1];
+
+        // Vector load B[k, n:n+vl] as i8.
+        vint8m2_t b_i8 =
+            b_unit_col
+                ? __riscv_vle8_v_i8m2((const int8_t *)&B[k * B_s0 + n], vl)
+                : __riscv_vlse8_v_i8m2(
+                      (const int8_t *)&B[k * B_s0 + n * B_s1],
+                      (ptrdiff_t)(B_s1 * (int64_t)sizeof(int8_t)), vl);
+
+        // Sign-extend i8m2 → i16m4 (same element count, double the width).
+        vint16m4_t b_i16 = __riscv_vsext_vf2_i16m4(b_i8, vl);
+
+        // Widening signed multiply-accumulate:
+        //   acc[i] += sign_ext_32(a_val) * sign_ext_32(b_i16[i])
+        //   vwmacc.vx vd[i32m8], rs1[i16], vs2[i16m4]
+        acc = __riscv_vwmacc_vx_i32m8(acc, (int16_t)a_val, b_i16, vl);
+      }
+
+      // Store accumulator to C[m, n:n+vl].
+ if (c_unit_col) + __riscv_vse32_v_i32m8((int32_t *)&C[m * C_s0 + n], acc, vl); + else + __riscv_vsse32_v_i32m8( + (int32_t *)&C[m * C_s0 + n * C_s1], + (ptrdiff_t)(C_s1 * (int64_t)sizeof(int32_t)), acc, vl); + + n += (int64_t)vl; + } + } +} + +#else // scalar fallback + +static void rvv_matmul_i8_core( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + A += A_off; + B += B_off; + C += C_off; + for (int64_t m = 0; m < M; ++m) + for (int64_t n = 0; n < N; ++n) { + int32_t acc = 0; + for (int64_t k = 0; k < K; ++k) + acc += (int32_t)A[m * A_s0 + k * A_s1] * + (int32_t)B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } +} + +#endif // __riscv_vector + +//===----------------------------------------------------------------------===// +// Public entry point +//===----------------------------------------------------------------------===// + +void my_matmul_kernel_i8( + const int8_t *A, int64_t A_off, int64_t A_s0, int64_t A_s1, + const int8_t *B, int64_t B_off, int64_t B_s0, int64_t B_s1, + int32_t *C, int64_t C_off, int64_t C_s0, int64_t C_s1, + int64_t M, int64_t N, int64_t K) { + rvv_matmul_i8_core(A, A_off, A_s0, A_s1, + B, B_off, B_s0, B_s1, + C, C_off, C_s0, C_s1, + M, N, K); +} diff --git a/runtime/src/iree/builtins/mips/matmul_plugin.c b/runtime/src/iree/builtins/mips/matmul_plugin.c index 6f42338934aa..5d5765351bf5 100644 --- a/runtime/src/iree/builtins/mips/matmul_plugin.c +++ b/runtime/src/iree/builtins/mips/matmul_plugin.c @@ -4,15 +4,16 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// IREE Executable Plugin interface for the MIPS matmul kernel. +// IREE Executable Plugin interface for the MIPS matmul kernels. 
// -// This file wires my_matmul_kernel into the IREE HAL plugin ABI so the -// function can be resolved at runtime via --executable_plugin. +// Registers both the f32 and INT8 kernels: +// my_matmul_kernel — f32 matmul +// my_matmul_kernel_i8 — INT8 (i8 inputs, i32 accumulator) matmul // -// Build as a shared library alongside matmul_kernel.c: +// Build as a shared library alongside matmul_kernel.c and matmul_kernel_i8.c: // clang --target=riscv64-linux-gnu -march=rv64gcv -mabi=lp64d \ // -O2 -fPIC -shared -nostdinc ... \ -// matmul_kernel.c matmul_plugin.c -o librvv_matmul.so +// matmul_kernel.c matmul_kernel_i8.c matmul_plugin.c -o librvv_matmul.so #include "matmul_kernel.h" @@ -26,6 +27,7 @@ // where params_ptr points to a packed struct matching the func.call ABI // emitted by MIPSBufferizableOpInterface::bufferize() → decomposeMemref2D(). +// ── f32 kernel args ────────────────────────────────────────────────────────── typedef struct { float *A; int64_t A_off, A_s0, A_s1; @@ -48,6 +50,29 @@ static int matmul_kernel_import(void *params_ptr, void *context, return 0; } +// ── INT8 kernel args ────────────────────────────────────────────────────────── +typedef struct { + int8_t *A; + int64_t A_off, A_s0, A_s1; + int8_t *B; + int64_t B_off, B_s0, B_s1; + int32_t *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} matmul_i8_kernel_args_t; + +static int matmul_i8_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const matmul_i8_kernel_args_t *a = (const matmul_i8_kernel_args_t *)params_ptr; + my_matmul_kernel_i8(a->A, a->A_off, a->A_s0, a->A_s1, + a->B, a->B_off, a->B_s0, a->B_s1, + a->C, a->C_off, a->C_s0, a->C_s1, + a->M, a->N, a->K); + return 0; +} + //===----------------------------------------------------------------------===// // Plugin lifecycle //===----------------------------------------------------------------------===// @@ -81,6 +106,9 @@ static iree_hal_executable_plugin_status_t plugin_resolve( if 
(iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) { params->out_fn_ptrs[i] = matmul_kernel_import; params->out_fn_contexts[i] = NULL; + } else if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel_i8") == 0) { + params->out_fn_ptrs[i] = matmul_i8_kernel_import; + params->out_fn_contexts[i] = NULL; } else { if (!optional) any_required_not_found = true; } @@ -102,7 +130,7 @@ iree_hal_executable_plugin_query( static const iree_hal_executable_plugin_header_t header = { .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST, .name = "mips_matmul", - .description = "RISC-V RVV 1.0 matmul kernel plugin", + .description = "RISC-V RVV 1.0 matmul kernel plugin (f32 + INT8)", .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE, .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND, };