GitOrigin-RevId: eb6bcadf54
tags/v1.0.0-rc1
@@ -77,7 +77,7 @@ if (MGE_USE_SYSTEM_LIB)
     endif()
 endfunction(find_mlir_llvm_lib)
 
-set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms)
+set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms)
 
 foreach(c ${MLIR_COMPONENTS})
     find_mlir_llvm_lib(${c})
@@ -120,4 +120,4 @@ set(MLIR_LLVM_INCLUDE_DIR
 )
 set(MLIR_TABLEGEN_EXE mlir-tblgen)
-set(MLIR_LLVM_LIBS LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms)
+set(MLIR_LLVM_LIBS LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms)
@@ -64,7 +64,6 @@ mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location,
 void add_cpu_lowering_pass(mlir::PassManager& manager) {
     {
         mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
-        opt_pm.addPass(create_shape_inference_pass());
         opt_pm.addPass(mlir::createCanonicalizerPass());
         opt_pm.addPass(mlir::createCSEPass());
     }
@@ -84,7 +83,6 @@ void add_cpu_lowering_pass(mlir::PassManager& manager) {
 void add_cuda_lowering_pass(mlir::PassManager& manager, CompNode cn) {
     {
         mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
-        opt_pm.addPass(create_shape_inference_pass());
        opt_pm.addPass(mlir::createCanonicalizerPass());
         opt_pm.addPass(mlir::createCSEPass());
     }
@@ -14,9 +14,10 @@
 #if MGB_JIT && MGB_JIT_MLIR
 
 #include "./executable_cpu.h"
-#include "./utils.h"
+#include "megbrain/jit/mlir/ir/utils.h"
 
 #include <mlir/ExecutionEngine/OptUtils.h>
+#include <mlir/ExecutionEngine/CRunnerUtils.h>
 
 using namespace mgb;
 using namespace jit;
@@ -113,7 +114,7 @@ void MLIRCPUExecutable::execute(JITExecutor* fusion_opr) {
         idx++;
     }
-    args_array_pointer[idx++] = &nr_elements;
+    args_array_pointer.push_back(&nr_elements);
 
     std::string adapter_name = std::string("_mlir_ciface_") + m_kernel_name;
     auto err = m_engine->invoke(
             adapter_name, llvm::MutableArrayRef<void*>(args_array_pointer));
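For context, the CPU executable calls the `_mlir_ciface_<kernel>` wrapper that MLIR emits for functions carrying the C-interface attribute: every argument is passed as a `void*` to the real storage, and the trailing element count is appended after the tensor descriptors. A minimal sketch of that calling pattern, with illustrative names that are not from this patch, assuming the LLVM/MLIR 10-era ExecutionEngine API:

    // Hypothetical sketch: assembling the argument array for an MLIR
    // ExecutionEngine C-interface wrapper. push_back keeps the vector's
    // size consistent, unlike writing one slot past the tensor arguments.
    std::vector<void*> args;
    for (auto& desc : memref_descriptors)  // one descriptor per tensor operand
        args.push_back(&desc);
    int64_t nr_elements = 1024;            // trailing scalar argument
    args.push_back(&nr_elements);
    auto err = engine->invoke("_mlir_ciface_kernel",
                              llvm::MutableArrayRef<void*>(args));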
@@ -17,13 +17,15 @@
 #if MGB_CUDA
 
 #include "./executable_cuda.h"
-#include "./utils.h"
-#include "megbrain/utils/timer.h"
-#include "megbrain/utils/persistent_cache.h"
 #include "megbrain/comp_node_env.h"
+#include "megbrain/jit/mlir/ir/utils.h"
+#include "megbrain/utils/persistent_cache.h"
+#include "megbrain/utils/timer.h"
 
-#include <mlir/ExecutionEngine/OptUtils.h>
 #include <mlir/Dialect/GPU/GPUDialect.h>
+#include <mlir/ExecutionEngine/CRunnerUtils.h>
+#include <mlir/ExecutionEngine/OptUtils.h>
 #include <mlir/IR/OpDefinition.h>
 
 using namespace mgb;
@@ -8,12 +8,12 @@ external_tablegen_library(
     TBLGEN
     MLIR
     SRCS
-        "shape_inference_interface.td"
+        "interfaces.td"
     INCLUDES
         ${MGB_MLIR_TABLEGEN_INC} ${MLIR_LLVM_INCLUDE_DIR}
     OUTS
-        -gen-op-interface-decls include/megbrain/jit/mlir/ir/shape_inference_interface.h.inc
-        -gen-op-interface-defs include/megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc
+        -gen-op-interface-decls include/megbrain/jit/mlir/ir/interfaces.h.inc
+        -gen-op-interface-defs include/megbrain/jit/mlir/ir/interfaces.cpp.inc
 )
 
 external_tablegen_library(
@@ -13,29 +13,88 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 
-#include "common.h"
+#include "./common.h"
+
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 using namespace mgb;
 using namespace jit;
 
-mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type,
-                                          mlir::Location loc,
-                                          mlir::PatternRewriter& rewriter) {
-    auto alloc = rewriter.create<mlir::AllocOp>(loc, type);
+#define cb(name, op)                                                         \
+    mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
+        return m_builder.create<mlir::op>(m_location, lhs, rhs);             \
+    }
+
+cb(add, AddFOp);
+cb(sub, SubFOp);
+cb(mul, MulFOp);
+cb(div, DivFOp);
+cb(mod, RemFOp);
+#undef cb
 
-    // Make sure to allocate at the beginning of the block.
-    auto* parent_block = alloc.getOperation()->getBlock();
-    alloc.getOperation()->moveBefore(&parent_block->front());
+#define cb(name, mode)                                                       \
+    mlir::Value ValueBuilderHelper::name(mlir::Value lhs, mlir::Value rhs) { \
+        return m_builder.create<mlir::CmpFOp>(                               \
+                m_location, mlir::CmpFPredicate::mode, lhs, rhs);            \
+    }
+
+cb(gt, OGT);
+cb(ge, OGE);
+cb(lt, OLT);
+cb(le, OLE);
+cb(eq, OEQ);
+#undef cb
 
-    // Make sure to deallocate this alloc at the end of the block. This is fine
-    // as toy functions have no control flow.
-    auto dealloc = rewriter.create<mlir::DeallocOp>(loc, alloc);
-    dealloc.getOperation()->moveBefore(&parent_block->back());
-    return alloc;
+mlir::Value ValueBuilderHelper::min(mlir::Value lhs, mlir::Value rhs) {
+    mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
+            m_location, mlir::CmpFPredicate::OLT, lhs, rhs);
+    return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
+}
+
+mlir::Value ValueBuilderHelper::max(mlir::Value lhs, mlir::Value rhs) {
+    mlir::Value cmp = m_builder.create<mlir::CmpFOp>(
+            m_location, mlir::CmpFPredicate::OGT, lhs, rhs);
+    return m_builder.create<mlir::SelectOp>(m_location, cmp, lhs, rhs);
+}
+
+mlir::Value ValueBuilderHelper::const_val(float val) {
+    return m_builder.create<mlir::ConstantOp>(m_location,
+                                              m_builder.getF32FloatAttr(val));
+}
+
+#define cb(name, op)                                            \
+    mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {     \
+        return m_builder.create<mlir::op>(m_location, lhs);     \
+    }
+
+cb(neg, NegFOp);
+cb(abs, AbsFOp);
+cb(ceil, CeilFOp);
+cb(cos, CosOp);
+cb(exp, ExpOp);
+cb(exp2, Exp2Op);
+cb(log10, Log10Op);
+cb(log2, Log2Op);
+cb(rsqrt, RsqrtOp);
+cb(sin, SinOp);
+cb(sqrt, SqrtOp);
+cb(tanh, TanhOp);
+#undef cb
+
+mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) {
+    //! FIXME: use the standard floor op once llvm is upgraded; until then,
+    //! rely on the identity floor(x) = -ceil(-x).
+    return neg(ceil(neg(lhs)));
+}
+
+mlir::Value ValueBuilderHelper::log(mlir::Value lhs) {
+    // log(x) = log10(x) / log10(e), with log10(e) = 0.4342944819032518f
+    return div(log10(lhs), const_val(0.4342944819032518f));
+}
+
+mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
+                                       mlir::Value false_val) {
+    return m_builder.create<mlir::SelectOp>(m_location, cond, true_val,
+                                            false_val);
 }
 
 #endif  // MGB_JIT && MGB_JIT_MLIR
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+// vim: syntax=cpp.doxygen
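As a usage sketch, the helper collapses verbose builder calls into readable expressions; for example, emitting sigmoid(x) = 1 / (exp(-x) + 1) at a fixed location (hypothetical call site, mirroring the SigmoidOp functor added later in this change):

    // Hypothetical helper usage; `builder` and `loc` would come from the
    // enclosing conversion pattern.
    mlir::Value emit_sigmoid(mlir::OpBuilder& builder, mlir::Location loc,
                             mlir::Value x) {
        ValueBuilderHelper helper(builder, loc);
        return helper.div(helper.const_val(1.f),
                          helper.add(helper.exp(helper.neg(x)),
                                     helper.const_val(1.f)));
    }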
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 
 #pragma once
@@ -14,19 +15,71 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 
-#include <mlir/IR/PatternMatch.h>
-#include <mlir/IR/StandardTypes.h>
+#include <mlir/Dialect/StandardOps/IR/Ops.h>
+#include <mlir/IR/OperationSupport.h>
 #include <mlir/IR/Value.h>
 
 namespace mgb {
 namespace jit {
 
-mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc,
-                                     mlir::PatternRewriter& rewriter);
+/**
+ * \brief Helper for building common scalar values at a fixed location.
+ */
+class ValueBuilderHelper {
+public:
+    ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location)
+            : m_builder{b}, m_location{location} {};
+
+#define cb(name)                                       \
+    mlir::Value name(mlir::ValueRange operands) {      \
+        return name(operands[0], operands[1]);         \
+    }                                                  \
+    mlir::Value name(mlir::Value lhs, mlir::Value rhs)
+
+    cb(add);
+    cb(sub);
+    cb(mul);
+    cb(div);
+    cb(max);
+    cb(min);
+    cb(mod);
+    cb(gt);
+    cb(ge);
+    cb(lt);
+    cb(le);
+    cb(eq);
+#undef cb
+
+    mlir::Value const_val(float val);
+
+#define cb(name)                                                              \
+    mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
+    mlir::Value name(mlir::Value lhs)
+
+    cb(neg);
+    cb(abs);
+    cb(ceil);
+    cb(floor);
+    cb(cos);
+    cb(exp);
+    cb(exp2);
+    cb(log10);
+    cb(log2);
+    cb(log);
+    cb(rsqrt);
+    cb(sin);
+    cb(sqrt);
+    cb(tanh);
+#undef cb
+
+    mlir::Value select(mlir::Value cond, mlir::Value true_val,
+                       mlir::Value false_val);
+
+private:
+    mlir::OpBuilder& m_builder;
+    mlir::Location m_location;
+};
+
 }  // namespace jit
 }  // namespace mgb
 
 #endif  // MGB_JIT && MGB_JIT_MLIR
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+// vim: syntax=cpp.doxygen
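Each `cb(name)` above declares two overloads: the real one- or two-argument builder, plus a forwarding overload taking a `mlir::ValueRange`, which is what lets the StandardOp functors introduced later pass operand packs straight through. Expanded by hand, `cb(add);` yields roughly:

    // Hand expansion of `cb(add);` inside the class body:
    mlir::Value add(mlir::ValueRange operands) {
        return add(operands[0], operands[1]);  // forward to the binary overload
    }
    mlir::Value add(mlir::Value lhs, mlir::Value rhs);  // defined in common.cpp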
@@ -15,77 +15,26 @@
 #include "megbrain/jit/mlir/ir/dialect.h"
 
-#include <mlir/Support/LogicalResult.h>
 #include <mlir/IR/Builders.h>
 #include <mlir/IR/OpImplementation.h>
 #include <mlir/IR/StandardTypes.h>
+#include <mlir/Support/LogicalResult.h>
 
 using namespace mgb;
 using namespace jit;
 
-MgbDialect::MgbDialect(mlir::MLIRContext *ctx) : mlir::Dialect("mgb", ctx) {
-    addOperations<
+MgbDialect::MgbDialect(mlir::MLIRContext* ctx) : mlir::Dialect("mgb", ctx) {
+    addOperations<
 #define GET_OP_LIST
 #include "megbrain/jit/mlir/ir/ops.cpp.inc"
-    >();
-}
-
-static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
-                                       mlir::OperationState &result) {
-    SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
-    llvm::SMLoc operandsLoc = parser.getCurrentLocation();
-    Type type;
-    if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
-        parser.parseOptionalAttrDict(result.attributes) ||
-        parser.parseColonType(type))
-        return mlir::failure();
-
-    // If the type is a function type, it contains the input and result types of
-    // this operation.
-    if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
-        if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
-                                   result.operands))
-            return mlir::failure();
-        result.addTypes(funcType.getResults());
-        return mlir::success();
-    }
-
-    // Otherwise, the parsed type is the type of both operands and results.
-    if (parser.resolveOperands(operands, type, result.operands))
-        return mlir::failure();
-    result.addTypes(type);
-    return mlir::success();
-}
-
-static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
-    printer << op->getName() << " " << op->getOperands();
-    printer.printOptionalAttrDict(op->getAttrs());
-    printer << " : ";
-
-    // If all of the types are the same, print the type directly.
-    Type resultType = *op->result_type_begin();
-    if (llvm::all_of(op->getOperandTypes(),
-                     [=](Type type) { return type == resultType; })) {
-        printer << resultType;
-        return;
-    }
-
-    // Otherwise, print a functional type.
-    printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
+    >();
 }
 
-///////////////////////// ElemwiseOp /////////////////////////////////////////////
-void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
-                  mlir::Value lhs, mlir::Value rhs) {
-    state.addTypes(lhs.getType());
-    state.addOperands({lhs, rhs});
-}
-
-void AddOp::infer_shapes() { getResult().setType(getOperand(0).getType()); }
-
 #define GET_OP_CLASSES
 #include "megbrain/jit/mlir/ir/ops.cpp.inc"
+#include "megbrain/jit/mlir/ir/interfaces.cpp.inc"
 
 #endif  // MGB_JIT && MGB_JIT_MLIR
 // vim: syntax=cpp.doxygen
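GET_OP_LIST makes the generated ops.cpp.inc expand to a comma-separated list of every op class declared in ops.td, so the constructor is equivalent to registering each op explicitly; roughly (illustrative, not the generated code verbatim):

    // Approximate expansion of addOperations< GET_OP_LIST ... >():
    addOperations<jit::AddOp, jit::ReluOp, jit::TanhOp,
                  /* ...every op declared in ops.td... */
                  jit::AssignOp, jit::ReturnOp>();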
@@ -0,0 +1,412 @@
+/**
+ * \file src/jit/impl/mlir/ir/each_mode.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include "megbrain/jit/mlir/ir/dialect.h"
+
+#include "./common.h"
+
+#include <mlir/Dialect/StandardOps/IR/Ops.h>
+#include <mlir/IR/Builders.h>
+#include <mlir/IR/Value.h>
+
+// clang-format off
+#define MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) \
+    cb(ReluOp, RELU) \
+    cb(AbsOp, ABS) \
+    cb(NegOp, NEGATE) \
+    cb(CeilOp, CEIL) \
+    cb(CosOp, COS) \
+    cb(ExpOp, EXP) \
+    cb(FloorOp, FLOOR) \
+    cb(LogOp, LOG) \
+    cb(Log1POp, LOG1P) \
+    cb(SigmoidOp, SIGMOID) \
+    cb(SinOp, SIN) \
+    cb(TanhOp, TANH) \
+    cb(FastTanhOp, FAST_TANH) \
+    cb(HswishOp, H_SWISH) \
+    cb(ExpM1Op, EXPM1) \
+    cb(RoundOp, ROUND)
+
+#define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \
+    cb(AbsGradOp, ABS_GRAD) \
+    cb(AddOp, ADD) \
+    cb(FloorDivOp, FLOOR_DIV) \
+    cb(MaxOp, MAX) \
+    cb(MinOp, MIN) \
+    cb(ModOp, MOD) \
+    cb(SubOp, SUB) \
+    cb(MulOp, MUL) \
+    cb(TrueDivOp, TRUE_DIV) \
+    cb(SigmoidGradOp, SIGMOID_GRAD) \
+    cb(SwishGt0Op, SWITCH_GT0) \
+    cb(TanhGradOp, TANH_GRAD) \
+    cb(LtOp, LT) \
+    cb(LeqOp, LEQ) \
+    cb(EqOp, EQ) \
+    cb(FuseAddReluOp, FUSE_ADD_RELU) \
+    cb(LogSumExpOp, LOG_SUM_EXP) \
+    cb(FuseAddTanhOp, FUSE_ADD_TANH) \
+    cb(FastTanhGradOp, FAST_TANH_GRAD) \
+    cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
+    cb(HswishGradOp, H_SWISH_GRAD) \
+    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH)
+
+#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
+    cb(CondLeqMovOp, COND_LEQ_MOV) \
+    cb(FuseMulAdd3Op, FUSE_MUL_ADD3)
+// clang-format on
+
+namespace mgb {
+namespace jit {
+
+template <typename mgb_op>
+struct StandardOp;
+
+#define cb(mgb_op, fun)                                                      \
+    template <>                                                              \
+    struct StandardOp<jit::mgb_op> {                                         \
+        mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \
+                               ValueRange operands) {                        \
+            ValueBuilderHelper helper(builder, loc);                         \
+            return helper.fun(operands);                                     \
+        }                                                                    \
+    }
+
+//! unary
+cb(AbsOp, abs);
+cb(NegOp, neg);
+cb(ExpOp, exp);
+cb(CosOp, cos);
+cb(CeilOp, ceil);
+cb(FloorOp, floor);
+cb(LogOp, log);
+cb(SinOp, sin);
+cb(TanhOp, tanh);
+
+//! binary
+cb(AddOp, add);
+cb(MaxOp, max);
+cb(MinOp, min);
+cb(SubOp, sub);
+cb(MulOp, mul);
+cb(ModOp, mod);
+cb(TrueDivOp, div);
+#undef cb
+
+/////////////////////////// unary op ///////////////////////////
+
+//! max(x, 0)
+template <>
+struct StandardOp<jit::ReluOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.max(operands[0], helper.const_val(0.f));
+    }
+};
+
+//! x * (27.f + x * x) / (27.f + 9.f * x * x)
+template <>
+struct StandardOp<jit::FastTanhOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto square = helper.mul(operands[0], operands[0]);
+        return helper.div(
+                helper.mul(operands[0],
+                           helper.add(helper.const_val(27.f), square)),
+                helper.add(helper.const_val(27.f),
+                           helper.mul(helper.const_val(9.f), square)));
+    }
+};
+
+//! x * clip(x + 3, 0, 6) / 6
+template <>
+struct StandardOp<jit::HswishOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto const_3 = helper.const_val(3.f);
+        auto const_0 = helper.const_val(0.f);
+        auto const_6 = helper.const_val(6.f);
+        auto tmp = helper.add(operands[0], const_3);
+        return helper.div(
+                helper.mul(operands[0],
+                           helper.min(helper.max(tmp, const_0), const_6)),
+                const_6);
+    }
+};
+
+//! log(1 + x)
+template <>
+struct StandardOp<jit::Log1POp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.log(helper.add(operands[0], helper.const_val(1.f)));
+    }
+};
+
+//! 1.f / (expf(-x) + 1.f)
+template <>
+struct StandardOp<jit::SigmoidOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.div(helper.const_val(1.f),
+                          helper.add(helper.exp(helper.neg(operands[0])),
+                                     helper.const_val(1.f)));
+    }
+};
+
+//! exp(x) - 1
+template <>
+struct StandardOp<jit::ExpM1Op> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.sub(helper.exp(operands[0]), helper.const_val(1.f));
+    }
+};
+
+template <>
+struct StandardOp<jit::RoundOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.select(
+                helper.gt(operands[0], helper.const_val(0.f)),
+                helper.floor(helper.add(operands[0], helper.const_val(0.5f))),
+                helper.ceil(helper.sub(operands[0], helper.const_val(0.5f))));
+    }
+};
+
+/////////////////////////// binary op ///////////////////////////
+
+//! binary: x > 0 ? y : -y
+template <>
+struct StandardOp<jit::AbsGradOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.select(helper.gt(operands[0], helper.const_val(0.f)),
+                             operands[1], helper.neg(operands[1]));
+    }
+};
+
+//! x * (1 - x) * y
+template <>
+struct StandardOp<jit::SigmoidGradOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.mul(
+                helper.mul(operands[0],
+                           helper.sub(helper.const_val(1.f), operands[0])),
+                operands[1]);
+    }
+};
+
+//! (x > 0) * y
+template <>
+struct StandardOp<jit::SwishGt0Op> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.select(helper.gt(operands[0], helper.const_val(0.f)),
+                             operands[1], helper.const_val(0.f));
+    }
+};
+
+//! (1 - x * x) * y
+template <>
+struct StandardOp<jit::TanhGradOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.mul(helper.sub(helper.const_val(1.0f),
+                                     helper.mul(operands[0], operands[0])),
+                          operands[1]);
+    }
+};
+
+#define cb(op, fun)                                                          \
+    template <>                                                              \
+    struct StandardOp<jit::op> {                                             \
+        mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, \
+                               ValueRange operands) {                        \
+            ValueBuilderHelper helper(builder, loc);                         \
+            return helper.select(helper.fun(operands[0], operands[1]),       \
+                                 helper.const_val(1.f),                      \
+                                 helper.const_val(0.f));                     \
+        }                                                                    \
+    }
+
+cb(LtOp, lt);
+cb(LeqOp, le);
+cb(EqOp, eq);
+#undef cb
+
+//! (x + y) <= ctype(0) ? ctype(0) : (x + y)
+template <>
+struct StandardOp<jit::FuseAddReluOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto sum = helper.add(operands[0], operands[1]);
+        return helper.max(sum, helper.const_val(0.f));
+    }
+};
+
+//! log(exp(x) + exp(y))
+template <>
+struct StandardOp<jit::LogSumExpOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.log(
+                helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
+    }
+};
+
+//! floor(x / y)
+template <>
+struct StandardOp<jit::FloorDivOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.floor(helper.div(operands[0], operands[1]));
+    }
+};
+
+//! tanh(x + y)
+template <>
+struct StandardOp<jit::FuseAddTanhOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.tanh(helper.add(operands[0], operands[1]));
+    }
+};
+
+//! ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (9.f * (3.f + x * x)) * y
+template <>
+struct StandardOp<jit::FastTanhGradOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto x_pow2 = helper.mul(operands[0], operands[0]);
+        auto deno = helper.add(helper.const_val(3.f), x_pow2);
+        return helper.mul(
+                helper.div(
+                        helper.add(
+                                helper.add(
+                                        helper.div(helper.mul(helper.const_val(
+                                                                      -48.f),
+                                                              x_pow2),
+                                                   deno),
+                                        helper.const_val(27.f)),
+                                x_pow2),
+                        helper.mul(deno, helper.const_val(9.f))),
+                operands[1]);
+    }
+};
+
+//! 1.f / (expf(-(x + y)) + 1.f)
+template <>
+struct StandardOp<jit::FuseAddSigmoidOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.div(helper.const_val(1.f),
+                          helper.add(helper.exp(helper.neg(helper.add(
+                                             operands[0], operands[1]))),
+                                     helper.const_val(1.f)));
+    }
+};
+
+//! x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
+template <>
+struct StandardOp<jit::HswishGradOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.select(
+                helper.lt(operands[0], helper.const_val(-3.f)),
+                helper.const_val(0.f),
+                helper.select(
+                        helper.gt(operands[0], helper.const_val(3.f)),
+                        operands[1],
+                        helper.mul(
+                                helper.div(
+                                        helper.add(helper.mul(helper.const_val(
+                                                                      2.f),
+                                                              operands[0]),
+                                                   helper.const_val(3.f)),
+                                        helper.const_val(6.f)),
+                                operands[1])));
+    }
+};
+
+//! (x + y) * min(max(x + y + 3, 0), 6) * (1/6)
+template <>
+struct StandardOp<jit::FuseAddHswishOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto sum = helper.add(operands[0], operands[1]);
+        auto const_3 = helper.const_val(3.f);
+        auto const_0 = helper.const_val(0.f);
+        auto const_6 = helper.const_val(6.f);
+        auto tmp = helper.add(sum, const_3);
+        return helper.div(
+                helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)),
+                const_6);
+    }
+};
+
+/////////////////////////// ternary op ///////////////////////////
+
+//! x <= y ? z : ctype(0)
+template <>
+struct StandardOp<jit::CondLeqMovOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.select(helper.le(operands[0], operands[1]), operands[2],
+                             helper.const_val(0.f));
+    }
+};
+
+//! x * y + z
+template <>
+struct StandardOp<jit::FuseMulAdd3Op> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
+    }
+};
+
+}  // namespace jit
+}  // namespace mgb
+
+#endif  // MGB_JIT && MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
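These functors are what the lowering patterns (both the affine and the GPU paths below) instantiate: a pattern loads its scalar operands, default-constructs the matching StandardOp specialization, and invokes it. A hedged sketch of that dispatch, with an illustrative wrapper name that is not in this patch:

    // Hypothetical: emit the scalar computation for RELU via the functor table.
    mlir::Value emit_relu_scalar(mlir::OpBuilder& builder, mlir::Location loc,
                                 mlir::Value loaded_input) {
        mgb::jit::StandardOp<mgb::jit::ReluOp> functor;
        return functor(builder, loc, {loaded_input});  // helper.max(x, 0.f)
    }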
@@ -0,0 +1,33 @@
+/**
+ * \file src/jit/impl/mlir/ir/interfaces.td
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#ifndef MGB_MLIR_INTERFACES
+#define MGB_MLIR_INTERFACES
+
+#ifndef OP_BASE
+include "mlir/IR/OpBase.td"
+#endif
+
+def GenericBuilderInterface : OpInterface<"GenericBuilder"> {
+    let methods = [
+        StaticInterfaceMethod<"TODO", "Type", "getResultType", (ins "ArrayRef<Value>":$operands)>,
+        StaticInterfaceMethod<"TODO", "Operation*", "create", (ins
+            "OpBuilder*":$builder,
+            "Location":$loc,
+            "ArrayRef<Value>":$operands
+        )>,
+    ];
+}
+
+def ElemwiseOpInterface : OpInterface<"ElemwiseOp">;
+
+#endif
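mlir-tblgen turns each StaticInterfaceMethod into a static member that the concrete ops must provide; in this change, `ElemwiseBuilderImpl_create` in the ops.td hunk below supplies `create`, so generic code can construct any element-wise op from a mode-erased operand list. A hedged sketch of a call site (hypothetical helper, not in this patch):

    // Hypothetical: build an mgb.add generically via the interface contract.
    // AddOp::create comes from ElemwiseBuilderImpl_create (see ops.td below).
    mlir::Operation* build_add(mlir::OpBuilder& builder, mlir::Location loc,
                               mlir::ValueRange operands) {
        return mgb::jit::AddOp::create(&builder, loc, operands);
    }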
@@ -16,11 +16,11 @@
 #include "megbrain/common.h"
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/passes.h"
+#include "megbrain/jit/mlir/ir/utils.h"
 
-#include "./common.h"
+#include "./each_mode.h"
 
 #include <mlir/Dialect/Affine/IR/AffineOps.h>
-#include <mlir/Dialect/StandardOps/IR/Ops.h>
 #include <mlir/Pass/Pass.h>
 #include <mlir/Transforms/DialectConversion.h>
 
@@ -57,10 +57,40 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
     rewriter.replaceOp(op, alloc);
 }
 
-template <typename BinaryOp, typename LoweredBinaryOp>
+template <typename Op, typename LoweredOp>
+struct UnaryOpLowering : public ConversionPattern {
+    UnaryOpLowering(MLIRContext* ctx)
+            : ConversionPattern(Op::getOperationName(), 1, ctx) {}
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        lower_op_to_loops(
+                op, operands, rewriter,
+                [loc](OpBuilder& builder, ValueRange memref_operands,
+                      ValueRange loop_ivs) {
+                    typename Op::Adaptor binary_adaptor(memref_operands);
+                    LoweredOp lower_op;
+                    auto loaded_lhs = builder.create<AffineLoadOp>(
+                            loc, binary_adaptor.lhs(), loop_ivs);
+                    return lower_op(builder, loc, {loaded_lhs});
+                });
+        return success();
+    }
+};
+
+#define cb(_op, _) \
+    using _op##Lowering = UnaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
+MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
+#undef cb
+
+template <typename Op, typename LoweredOp>
 struct BinaryOpLowering : public ConversionPattern {
     BinaryOpLowering(MLIRContext* ctx)
-            : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+            : ConversionPattern(Op::getOperationName(), 1, ctx) {}
 
     LogicalResult matchAndRewrite(
             Operation* op, ArrayRef<Value> operands,
@@ -70,20 +100,61 @@ struct BinaryOpLowering : public ConversionPattern {
                 op, operands, rewriter,
                 [loc](OpBuilder& builder, ValueRange memref_operands,
                       ValueRange loop_ivs) {
-                    typename BinaryOp::Adaptor binary_adaptor(memref_operands);
+                    typename Op::Adaptor binary_adaptor(memref_operands);
+                    LoweredOp lower_op;
                     auto loaded_lhs = builder.create<AffineLoadOp>(
                             loc, binary_adaptor.lhs(), loop_ivs);
                     auto loaded_rhs = builder.create<AffineLoadOp>(
                             loc, binary_adaptor.rhs(), loop_ivs);
-                    return builder.create<LoweredBinaryOp>(loc, loaded_lhs,
-                                                           loaded_rhs);
+                    return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
                 });
         return success();
     }
 };
 
-using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;
+#define cb(_op, _) \
+    using _op##Lowering = BinaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
+MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
+#undef cb
+
+template <typename Op, typename LoweredOp>
+struct TernaryOpLowering : public ConversionPattern {
+    TernaryOpLowering(MLIRContext* ctx)
+            : ConversionPattern(Op::getOperationName(), 1, ctx) {}
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        lower_op_to_loops(
+                op, operands, rewriter,
+                [loc](OpBuilder& builder, ValueRange memref_operands,
+                      ValueRange loop_ivs) {
+                    typename Op::Adaptor ternary_adaptor(memref_operands);
+                    LoweredOp lower_op;
+                    auto loaded_x = builder.create<AffineLoadOp>(
+                            loc, ternary_adaptor.x(), loop_ivs);
+                    auto loaded_y = builder.create<AffineLoadOp>(
+                            loc, ternary_adaptor.y(), loop_ivs);
+                    auto loaded_z = builder.create<AffineLoadOp>(
+                            loc, ternary_adaptor.z(), loop_ivs);
+                    return lower_op(builder, loc,
+                                    {loaded_x, loaded_y, loaded_z});
+                });
+        return success();
+    }
+};
+
+#define cb(_op, _)        \
+    using _op##Lowering = \
+            TernaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
+MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
+#undef cb
 
 struct AssignOpLowering : public ConversionPattern {
     AssignOpLowering(MLIRContext* ctx)
@@ -126,21 +197,18 @@ class MgbToAffineLoweringPass
         : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
 public:
     void runOnFunction() override final {
-        auto function = getFunction();
-
-        // Verify that the given main has no inputs and results.
-        if (function.getType().getNumResults()) {
-            mgb_log_error("expected 'main' to have 0 results");
-            return signalPassFailure();
-        }
-
         ConversionTarget target(getContext());
         target.addLegalDialect<AffineDialect, StandardOpsDialect>();
         target.addIllegalDialect<MgbDialect>();
 
         OwningRewritePatternList patterns;
-        patterns.insert<AddOpLowering, ReturnOpLowering, AssignOpLowering>(
-                &getContext());
+#define cb(_op, _) _op##Lowering,
+        patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(
+                cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
+                        MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
+                                ReturnOpLowering,
+                        AssignOpLowering>(&getContext());
+#undef cb
 
         if (failed(applyPartialConversion(getFunction(), target, patterns))) {
             signalPassFailure();
@@ -13,11 +13,11 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
 
+#include "./each_mode.h"
+
 #include "megbrain/common.h"
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "megbrain/jit/mlir/ir/passes.h"
-#include "../utils.h"
+#include "megbrain/jit/mlir/ir/utils.h"
 
 #include <mlir/Dialect/GPU/GPUDialect.h>
 #include <mlir/Dialect/SCF/SCF.h>
@@ -62,10 +62,43 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
     return index;
 }
 
-template <typename BinaryOp, typename LoweredBinaryOp>
+template <typename Op, typename LoweredOp>
+struct UnaryOpLowering : public ConversionPattern {
+    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+            : ConversionPattern(Op::getOperationName(), 1, ctx),
+              m_launch_op{launch_op} {}
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        typename Op::Adaptor binary_adaptor(operands);
+        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+
+        auto index = get_tid(rewriter, loc);
+        auto loaded_lhs =
+                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
+
+        LoweredOp lower_op;
+        rewriter.replaceOp(op, lower_op(rewriter, loc, {loaded_lhs}));
+        return success();
+    }
+
+private:
+    gpu::LaunchOp* m_launch_op;
+};
+
+#define cb(_op, _) \
+    using _op##Lowering = UnaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
+MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
+#undef cb
+
+template <typename Op, typename LoweredOp>
 struct BinaryOpLowering : public ConversionPattern {
     BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
-            : ConversionPattern(BinaryOp::getOperationName(), 1, ctx),
+            : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
 
     LogicalResult matchAndRewrite(
@@ -73,7 +106,7 @@ struct BinaryOpLowering : public ConversionPattern {
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
-        typename BinaryOp::Adaptor binary_adaptor(operands);
+        typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
 
         auto index = get_tid(rewriter, loc);
@@ -82,10 +115,48 @@ struct BinaryOpLowering : public ConversionPattern {
         auto loaded_rhs =
                 get_operand(rewriter, loc, binary_adaptor.rhs(), index);
 
-        auto binary_op =
-                rewriter.create<LoweredBinaryOp>(loc, loaded_lhs, loaded_rhs);
-        rewriter.replaceOp(op, binary_op.getResult());
+        LoweredOp lower_op;
+        rewriter.replaceOp(op,
+                           lower_op(rewriter, loc, {loaded_lhs, loaded_rhs}));
+        return success();
+    }
+
+private:
+    gpu::LaunchOp* m_launch_op;
+};
+
+#define cb(_op, _) \
+    using _op##Lowering = BinaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
+MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
+#undef cb
+
+template <typename Op, typename LoweredOp>
+struct TernaryOpLowering : public ConversionPattern {
+    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+            : ConversionPattern(Op::getOperationName(), 1, ctx),
+              m_launch_op{launch_op} {}
+
+    LogicalResult matchAndRewrite(
+            Operation* op, ArrayRef<Value> operands,
+            ConversionPatternRewriter& rewriter) const final {
+        auto loc = op->getLoc();
+        typename Op::Adaptor ternary_adaptor(operands);
+        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+
+        auto index = get_tid(rewriter, loc);
+        auto loaded_x =
+                get_operand(rewriter, loc, ternary_adaptor.x(), index);
+        auto loaded_y =
+                get_operand(rewriter, loc, ternary_adaptor.y(), index);
+        auto loaded_z =
+                get_operand(rewriter, loc, ternary_adaptor.z(), index);
+
+        LoweredOp lower_op;
+        rewriter.replaceOp(
+                op, lower_op(rewriter, loc, {loaded_x, loaded_y, loaded_z}));
         return success();
     }
 
@@ -93,7 +164,11 @@ private:
     gpu::LaunchOp* m_launch_op;
 };
 
-using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;
+#define cb(_op, _)        \
+    using _op##Lowering = \
+            TernaryOpLowering<jit::_op, jit::StandardOp<jit::_op>>;
+MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
+#undef cb
 
 struct ReturnOpLowering : public ConversionPattern {
     ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
@@ -194,6 +269,14 @@ public:
         patterns.insert<AddOpLowering, AssignOpLowering, ReturnOpLowering>(
                 &getContext(), &launch_op);
 
+#define cb(_op, _) _op##Lowering,
+        patterns.insert<MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(
+                cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
+                        MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
+                                ReturnOpLowering,
+                        AssignOpLowering>(&getContext(), &launch_op);
+#undef cb
+
         if (failed(applyPartialConversion(func_op, target, patterns))) {
             signalPassFailure();
         }
@@ -21,6 +21,7 @@
 #include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
 #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
 #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
+#include <mlir/Dialect/StandardOps/Transforms/Passes.h>
 
 using namespace mgb;
 using namespace jit;
@@ -39,6 +40,7 @@ class AffineToLLVMLoweringPass : public PassWrapper<AffineToLLVMLoweringPass,
         populateAffineToStdConversionPatterns(patterns, &getContext());
         populateLoopToStdConversionPatterns(patterns, &getContext());
         populateStdToLLVMConversionPatterns(typeConverter, patterns);
+        populateExpandTanhPattern(patterns, &getContext());
 
         auto module = getOperation();
         if (failed(applyFullConversion(module, target, patterns)))
@@ -16,40 +16,158 @@
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-include "./shape_inference_interface.td"
+include "./interfaces.td"
 
 def Mgb_Dialect : Dialect {
   let name = "mgb";
   let cppNamespace = "mgb::jit";
 }
 
-class ElemwiseOp<string mnemonic, list<OpTrait> traits = []> :
-    Op<Mgb_Dialect, mnemonic, traits>;
+class ElemwiseBuilderImpl {
+  code ElemwiseBuilderImpl_create = [{
+    static Operation* create(OpBuilder* builder, Location loc, ValueRange operands) {
+        OperationState state(loc, getOperationName());
+        state.addOperands(operands);
+        state.addTypes(getResultType(operands));
+        return builder->createOperation(state);
+    }
+  }];
+}
+
+class ElemwiseOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
+    Op<Mgb_Dialect, mnemonic, !listconcat(traits, [ElemwiseOpInterface,
+       GenericBuilderInterface])>, ElemwiseBuilderImpl;
 
 class GenericOp<string mnemonic, list<OpTrait> traits = []> :
     Op<Mgb_Dialect, mnemonic, traits>;
 
-def AddOp : ElemwiseOp<"add",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let summary = "element-wise addition operation";
-  let description = [{
-    The "add" operation performs element-wise addition between two tensors.
-    The shapes of the tensor operands are expected to match.
-  }];
+class ElemwiseUnaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
+    ElemwiseOp<mnemonic, traits> {
+  let arguments = (ins F32MemRef:$lhs);
+  let results = (outs F32MemRef);
+
+  let builders = [OpBuilder<
+    "Builder* builder, OperationState& result, ValueRange operands", [{
+      result.addOperands(operands);
+      result.addTypes(getResultType(operands));
+    }]>, OpBuilder<
+    "OpBuilder& builder, OperationState& result, Value lhs", [{
+      result.addOperands(lhs);
+      result.addTypes(getResultType({lhs}));
+    }]
+  >];
+
+  let extraClassDeclaration = [{
+    static Type getResultType(ValueRange operands) {
+        return deduce_result_type(operands);
+    }
+  }] # ElemwiseBuilderImpl_create;
+}
+
+def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>;
+def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>;
+def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>;
+/* ACOS */
+/* ASIN */
+def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>;
+def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>;
+def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>;
+def ExpM1Op : ElemwiseUnaryOp<"expm1", [NoSideEffect]>;
+def FloorOp : ElemwiseUnaryOp<"floor", [NoSideEffect]>;
+def LogOp : ElemwiseUnaryOp<"log", [NoSideEffect]>;
+def Log1POp : ElemwiseUnaryOp<"log1p", [NoSideEffect]>;
+def SigmoidOp: ElemwiseUnaryOp<"sigmoid", [NoSideEffect]>;
+def SinOp : ElemwiseUnaryOp<"sin", [NoSideEffect]>;
+def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>;
+def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>;
+def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>;
+def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;
+/* ERF */
+/* ERFINV */
+/* ERFC */
+/* ERFCINV */
+
+class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
+    ElemwiseOp<mnemonic, traits> {
   let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
   let results = (outs F32MemRef);
 
-  // Specify a parser and printer method.
-  let parser = [{ return ::parseBinaryOp(parser, result); }];
-  let printer = [{ return ::printBinaryOp(p, *this); }];
+  let builders = [OpBuilder<
+    "Builder* builder, OperationState& result, ValueRange operands", [{
+      result.addOperands(operands);
+      result.addTypes(getResultType(operands));
+    }]
+  >, OpBuilder<
+    "OpBuilder& builder, OperationState& result, Value lhs, Value rhs", [{
+      result.addOperands(lhs);
+      result.addOperands(rhs);
+      result.addTypes(getResultType({lhs, rhs}));
+    }]
+  >];
+
+  let extraClassDeclaration = [{
+    static Type getResultType(ValueRange operands) {
+        return deduce_result_type(operands);
+    }
+  }] # ElemwiseBuilderImpl_create;
+}
+
+def AbsGradOp : ElemwiseBinaryOp<"abs_grad", [NoSideEffect]>;
+def AddOp : ElemwiseBinaryOp<"add", [Commutative, NoSideEffect]>;
+def FloorDivOp : ElemwiseBinaryOp<"floor_div", [NoSideEffect]>;
+def MaxOp : ElemwiseBinaryOp<"max", [Commutative, NoSideEffect]>;
+def MinOp : ElemwiseBinaryOp<"min", [Commutative, NoSideEffect]>;
+def ModOp : ElemwiseBinaryOp<"mod", [NoSideEffect]>;
+def MulOp : ElemwiseBinaryOp<"mul", [Commutative, NoSideEffect]>;
+def SubOp : ElemwiseBinaryOp<"sub", [NoSideEffect]>;
+def SigmoidGradOp : ElemwiseBinaryOp<"sigmoid_grad", [NoSideEffect]>;
+def SwishGt0Op : ElemwiseBinaryOp<"switch_gt0", [NoSideEffect]>;
+def TanhGradOp : ElemwiseBinaryOp<"tanh_grad", [NoSideEffect]>;
+def LtOp : ElemwiseBinaryOp<"lt", [NoSideEffect]>;
+def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>;
+def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>;
+def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>;
+def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>;
+/* POW */
+def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>;
+def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>;
+def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>;
+def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>;
+def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>;
+def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;
+/* ATAN2 */
+
+class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
+    ElemwiseOp<mnemonic, traits> {
+  let arguments = (ins F32MemRef:$x, F32MemRef:$y, F32MemRef:$z);
+  let results = (outs F32MemRef);
+
+  let builders = [OpBuilder<
+    "Builder* builder, OperationState& result, ValueRange operands", [{
+      result.addOperands(operands);
+      result.addTypes(getResultType(operands));
+    }]
+  >, OpBuilder<
+    "OpBuilder& builder, OperationState& result, Value x, Value y, Value z", [{
+      result.addOperands(x);
+      result.addOperands(y);
+      result.addOperands(z);
+      result.addTypes(getResultType({x, y, z}));
+    }]
+  >];
 
-  // Allow building an AddOp from the two input operands.
-  let builders = [
-    OpBuilder<"OpBuilder &b, OperationState &state, Value lhs, Value rhs">
-  ];
+  let extraClassDeclaration = [{
+    static Type getResultType(ValueRange operands) {
+        return deduce_result_type(operands);
+    }
+  }] # ElemwiseBuilderImpl_create;
 }
 
+def CondLeqMovOp: ElemwiseTernaryOp<"cond_leq_mov", [NoSideEffect]>;
+def FuseMulAdd3Op: ElemwiseTernaryOp<"fuse_mul_add3", [NoSideEffect]>;
+
 def ReturnOp : GenericOp<"return",
     [NoSideEffect, HasParent<"FuncOp">, Terminator]> {
   let summary = "return operation";
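For orientation, each OpBuilder entry above becomes a generated `build` overload on the op class; a hand-written approximation of what mlir-tblgen might emit for the ValueRange builder of a binary op (the real generated code differs in details, and the exact builder signature depends on the MLIR version):

    // Approximation of the generated builder; deduce_result_type is the
    // helper referenced by getResultType in ops.td.
    void AddOp::build(mlir::Builder* builder, mlir::OperationState& result,
                      mlir::ValueRange operands) {
        result.addOperands(operands);
        result.addTypes(getResultType(operands));  // deduce_result_type(operands)
    }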
| @@ -1,30 +0,0 @@ | |||||
| /** | |||||
| * \file src/jit/impl/mlir/ir/shape_inference_interface.td | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #ifndef MGB_JIT_SHAPE_INFERENCE_INTERFACE | |||||
| #define MGB_JIT_SHAPE_INFERENCE_INTERFACE | |||||
| include "mlir/IR/OpBase.td" | |||||
| def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { | |||||
| let description = [{ | |||||
| Interface to access a registered method to infer the return types for an | |||||
| operation that can be used during type inference. | |||||
| }]; | |||||
| let methods = [ | |||||
| InterfaceMethod<"Infer and set the output shape for the current operation.", | |||||
| "void", "infer_shapes"> | |||||
| ]; | |||||
| } | |||||
| #endif // MGB_JIT_SHAPE_INFERENCE_INTERFACE | |||||
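For context, an op adopting this interface would have listed it among its traits and implemented the declared hook; an illustrative sketch (the trait spelling and the AddOp body are assumptions, not taken from this change):

// TableGen side (illustrative):
//   def AddOp : ElemwiseBinaryOp<"add",
//       [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]>;
// C++ side: propagate the operand's memref type to the result.
void AddOp::infer_shapes() {
    getResult().setType(getOperand(0).getType());
}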
| @@ -1,100 +0,0 @@ | |||||
| /** | |||||
| * \file src/jit/impl/mlir/ir/shape_inference_pass.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain_build_config.h" | |||||
| #if MGB_JIT && MGB_JIT_MLIR | |||||
| #include "megbrain/common.h" | |||||
| #include "megbrain/jit/mlir/ir/dialect.h" | |||||
| #include "megbrain/jit/mlir/ir/passes.h" | |||||
| #include "megbrain/jit/mlir/ir/shape_inference_interface.h" | |||||
| #include <llvm/ADT/SmallPtrSet.h> | |||||
| #include <mlir/IR/StandardTypes.h> | |||||
| #include <mlir/Pass/Pass.h> | |||||
| using namespace mgb; | |||||
| using namespace jit; | |||||
| #include "megbrain/jit/mlir/ir/shape_inference_interface.cpp.inc" | |||||
| namespace { | |||||
| class ShapeInferencePass | |||||
| : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> { | |||||
| public: | |||||
| void runOnFunction() override { | |||||
| auto f = getFunction(); | |||||
| llvm::SmallPtrSet<mlir::Operation*, 16> op_worklist; | |||||
| f.walk([&](mlir::Operation* op) { | |||||
| if (returns_dynamic_shape(op)) | |||||
| op_worklist.insert(op); | |||||
| }); | |||||
| // Iterate on the operations in the worklist until all operations have | |||||
| // been inferred or no change happened (fix point). | |||||
| while (!op_worklist.empty()) { | |||||
| // Find the next operation ready for inference, that is an operation | |||||
| // with all operands already resolved (non-generic). | |||||
| auto nextop = llvm::find_if(op_worklist, all_operands_inferred); | |||||
| if (nextop == op_worklist.end()) | |||||
| break; | |||||
| Operation* op = *nextop; | |||||
| op_worklist.erase(op); | |||||
| if (auto shapeOp = dyn_cast<ShapeInference>(op)) { | |||||
| shapeOp.infer_shapes(); | |||||
| } else { | |||||
| mgb_log_error( | |||||
| "unable to infer shape of operation without shape " | |||||
| "inference interface"); | |||||
| return signalPassFailure(); | |||||
| } | |||||
| } | |||||
| // If the operation worklist isn't empty, this indicates a failure. | |||||
| if (!op_worklist.empty()) { | |||||
| mgb_log_error( | |||||
| "Shape inference failed, %zu operations couldn't be " | |||||
| "inferred", | |||||
| op_worklist.size()); | |||||
| signalPassFailure(); | |||||
| } | |||||
| } | |||||
| //! A utility method that returns if the given operation has all of its | |||||
| //! operands inferred. | |||||
| static bool all_operands_inferred(Operation* op) { | |||||
| return llvm::all_of(op->getOperandTypes(), [](Type operandType) { | |||||
| return operandType.isa<mlir::MemRefType>(); | |||||
| }); | |||||
| } | |||||
| //! A utility method that returns if the given operation has a dynamically | |||||
| //! shaped result. | |||||
| static bool returns_dynamic_shape(Operation* op) { | |||||
| return llvm::any_of(op->getResultTypes(), [](Type resultType) { | |||||
| return !resultType.isa<mlir::MemRefType>(); | |||||
| }); | |||||
| } | |||||
| }; | |||||
| } // namespace | |||||
| std::unique_ptr<mlir::Pass> mgb::jit::create_shape_inference_pass() { | |||||
| return std::make_unique<ShapeInferencePass>(); | |||||
| } | |||||
| #endif // MGB_JIT && MGB_JIT_MLIR | |||||
| // vim: syntax=cpp.doxygen | |||||
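With this pass removed, result types are computed eagerly at op-construction time instead: the getResultType() hook in the TableGen builders above calls deduce_result_type(), which is added in utils.cpp below, so no separate fix-point inference pass is needed.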
| @@ -0,0 +1,111 @@ | |||||
| /** | |||||
| * \file src/jit/impl/mlir/ir/utils.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megbrain_build_config.h" | |||||
| #if MGB_JIT && MGB_JIT_MLIR | |||||
| #include "megbrain/common.h" | |||||
| #include "megbrain/exception.h" | |||||
| #include "megbrain/jit/mlir/ir/utils.h" | |||||
| #include "megdnn/oprs/general.h" | |||||
| #include "megdnn/basic_types.h" | |||||
| #include <mlir/Dialect/Affine/IR/AffineOps.h> | |||||
| #include <mlir/IR/Builders.h> | |||||
| #include <mlir/IR/StandardTypes.h> | |||||
| #include <mlir/IR/Types.h> | |||||
| #include <mlir/Support/LLVM.h> | |||||
| using namespace mgb; | |||||
| using namespace jit; | |||||
| mlir::Value jit::insert_alloc_and_dealloc(mlir::MemRefType type, | |||||
| mlir::Location loc, | |||||
| mlir::PatternRewriter& rewriter) { | |||||
| auto alloc = rewriter.create<mlir::AllocOp>(loc, type); | |||||
| // Make sure to allocate at the beginning of the block. | |||||
| auto* parent_block = alloc.getOperation()->getBlock(); | |||||
| alloc.getOperation()->moveBefore(&parent_block->front()); | |||||
| // Make sure to deallocate this alloc at the end of the block. This is fine | |||||
| // as these generated JIT functions have no control flow. | |||||
| auto dealloc = rewriter.create<mlir::DeallocOp>(loc, alloc); | |||||
| dealloc.getOperation()->moveBefore(&parent_block->back()); | |||||
| return alloc; | |||||
| } | |||||
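A minimal sketch of how a lowering pattern might use this helper; the pattern name, target op, and elided loop body are illustrative assumptions:

#include <mlir/IR/PatternMatch.h>
#include "megbrain/jit/mlir/ir/dialect.h"  // assumed: jit::AddOp
#include "megbrain/jit/mlir/ir/utils.h"

// Sketch: allocate the output buffer for the op being lowered, emit the
// computation into it (elided), then replace the op with the buffer.
struct LowerAddSketch : public mlir::OpRewritePattern<mgb::jit::AddOp> {
    using mlir::OpRewritePattern<mgb::jit::AddOp>::OpRewritePattern;
    mlir::LogicalResult matchAndRewrite(
            mgb::jit::AddOp op,
            mlir::PatternRewriter& rewriter) const override {
        auto memref = op.getType().cast<mlir::MemRefType>();
        mlir::Value out = mgb::jit::insert_alloc_and_dealloc(
                memref, op.getLoc(), rewriter);
        // ... affine loops storing element-wise results into `out` ...
        rewriter.replaceOp(op, out);
        return mlir::success();
    }
};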
| mlir::Type jit::deduce_result_type(mlir::ValueRange operands) { | |||||
| megdnn::TensorShapeArray srcs; | |||||
| megdnn::TensorShape dst; | |||||
| for (auto operand : operands) { | |||||
| auto type = operand.getType().dyn_cast_or_null<mlir::MemRefType>(); | |||||
| mgb_assert(type, "currently only supports MemRefType"); | |||||
| srcs.push_back(mlir_type_to_layout(type)); | |||||
| } | |||||
| megdnn::Elemwise::deduce_shape(srcs, dst); | |||||
| mlir::Builder builder(operands[0].getContext()); | |||||
| return layout_to_mlir_type({dst, mlir_type_to_dtype(operands[0].getType())}, | |||||
| builder); | |||||
| } | |||||
| megdnn::TensorLayout jit::mlir_type_to_layout(mlir::Type type) { | |||||
| megdnn::TensorLayout ret; | |||||
| if (auto real_type = type.dyn_cast_or_null<mlir::MemRefType>()) { | |||||
| ret.ndim = real_type.getRank(); | |||||
| for (size_t i = 0; i < ret.ndim; i++) { | |||||
| ret.shape[i] = real_type.getDimSize(i); | |||||
| } | |||||
| ret.dtype = mlir_type_to_dtype(real_type.getElementType()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) { | |||||
| mlir::Type element_type = type; | |||||
| if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { | |||||
| element_type = cast.getElementType(); | |||||
| } | |||||
| switch (element_type.getKind()) { | |||||
| case mlir::StandardTypes::F32: | |||||
| return megdnn::dtype::Float32{}; | |||||
| default: | |||||
| mgb_throw(InternalError, | |||||
| "Unsupport mlir type for MemRefType, got: %s\n", | |||||
| mlir_type_to_string(type).c_str()); | |||||
| } | |||||
| return {}; | |||||
| } | |||||
| mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, | |||||
| mlir::Builder& builder) { | |||||
| std::vector<int64_t> shape; | |||||
| for (size_t i = 0; i < layout.ndim; i++) { | |||||
| shape.push_back(layout[i]); | |||||
| } | |||||
| switch (layout.dtype.enumv()) { | |||||
| case megdnn::DTypeEnum::Float32: | |||||
| return mlir::MemRefType::get(shape, builder.getF32Type()); | |||||
| default: | |||||
| mgb_throw(InternalError, "No supported dtype: %s", | |||||
| layout.dtype.name()); | |||||
| } | |||||
| } | |||||
| #endif // MGB_JIT && MGB_JIT_MLIR | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
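A small usage sketch for the conversion helpers above (assumes a valid MLIRContext; Float32 is the only dtype these conversions support):

#include <mlir/IR/Builders.h>
#include <mlir/IR/MLIRContext.h>
#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/utils.h"

// Sketch: TensorLayout -> MemRefType -> TensorLayout round-trips the shape
// for the supported float32 case.
void roundtrip_example(mlir::MLIRContext* ctx) {
    mlir::Builder builder(ctx);
    megdnn::TensorLayout layout({23, 42}, megdnn::dtype::Float32{});
    mlir::MemRefType type = mgb::jit::layout_to_mlir_type(layout, builder);
    megdnn::TensorLayout back = mgb::jit::mlir_type_to_layout(type);
    mgb_assert(back.eq_shape(layout));
}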
| @@ -14,8 +14,10 @@ | |||||
| #if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
| #include "./mlir_gen.h" | #include "./mlir_gen.h" | ||||
| #include "./utils.h" | |||||
| #include "./ir/each_mode.h" | |||||
| #include "megbrain/jit/mlir/ir/dialect.h" | #include "megbrain/jit/mlir/ir/dialect.h" | ||||
| #include "megbrain/jit/mlir/ir/utils.h" | |||||
| #include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| @@ -118,7 +120,7 @@ private: | |||||
| if (!return_op) { | if (!return_op) { | ||||
| m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc()); | m_builder.create<jit::ReturnOp>(m_builder.getUnknownLoc()); | ||||
| } | } | ||||
| std::string op_content = to_string(func_op); | |||||
| std::string op_content = mlir_type_to_string(func_op); | |||||
| func_op.setName( | func_op.setName( | ||||
| ssprintf("jit_mlir_%" PRIx64, | ssprintf("jit_mlir_%" PRIx64, | ||||
| XXHash{}.update(op_content.data(), op_content.size()) | XXHash{}.update(op_content.data(), op_content.size()) | ||||
| @@ -140,7 +142,8 @@ private: | |||||
| mgb_assert( | mgb_assert( | ||||
| mlir::succeeded(declare(opr->output(0)->name(), out))); | mlir::succeeded(declare(opr->output(0)->name(), out))); | ||||
| } | } | ||||
| }}.add(internal_graph.output()); | |||||
| }} | |||||
| .add(internal_graph.output()); | |||||
| m_builder.create<AssignOp>(m_builder.getUnknownLoc(), | m_builder.create<AssignOp>(m_builder.getUnknownLoc(), | ||||
| get(internal_graph.output()), | get(internal_graph.output()), | ||||
| get(args.outputs[0].from)); | get(args.outputs[0].from)); | ||||
| @@ -150,11 +153,31 @@ private: | |||||
| mlir::Value gen_op(const opr::Elemwise& opr) { | mlir::Value gen_op(const opr::Elemwise& opr) { | ||||
| switch (opr.param().mode) { | switch (opr.param().mode) { | ||||
| case opr::Elemwise::Mode::ADD: | |||||
| return m_builder.create<AddOp>(m_builder.getUnknownLoc(), | |||||
| get(opr.input(0)), | |||||
| get(opr.input(1))); | |||||
| break; | |||||
| #define cb(mlir_op, mgb_mode) \ | |||||
| case opr::Elemwise::Mode::mgb_mode: \ | |||||
| return m_builder.create<jit::mlir_op>(m_builder.getUnknownLoc(), \ | |||||
| get(opr.input(0)), \ | |||||
| get(opr.input(1))); \ | |||||
| break; | |||||
| MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) | |||||
| #undef cb | |||||
| #define cb(mlir_op, mgb_mode) \ | |||||
| case opr::Elemwise::Mode::mgb_mode: \ | |||||
| return m_builder.create<jit::mlir_op>(m_builder.getUnknownLoc(), \ | |||||
| get(opr.input(0))); \ | |||||
| break; | |||||
| MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb) | |||||
| #undef cb | |||||
| #define cb(mlir_op, mgb_mode) \ | |||||
| case opr::Elemwise::Mode::mgb_mode: \ | |||||
| return m_builder.create<jit::mlir_op>( \ | |||||
| m_builder.getUnknownLoc(), get(opr.input(0)), \ | |||||
| get(opr.input(1)), get(opr.input(2))); \ | |||||
| break; | |||||
| MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) | |||||
| #undef cb | |||||
| default: | default: | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
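For readability: assuming each_mode.h lists an entry such as cb(AddOp, ADD) under MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY, the first cb above expands to a plain case label (shown as a fragment for illustration only):

// Hypothetical expansion of cb(AddOp, ADD):
case opr::Elemwise::Mode::ADD:
    return m_builder.create<jit::AddOp>(m_builder.getUnknownLoc(),
                                        get(opr.input(0)),
                                        get(opr.input(1)));
    break;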
| @@ -162,19 +185,7 @@ private: | |||||
| } | } | ||||
| mlir::Type get_type(const TensorLayout& layout) { | mlir::Type get_type(const TensorLayout& layout) { | ||||
| std::vector<int64_t> shape; | |||||
| for (size_t i = 0; i < layout.ndim; i++) { | |||||
| shape.push_back(layout[i]); | |||||
| } | |||||
| mgb_assert(layout.ndim != 0); | |||||
| switch (layout.dtype.enumv()) { | |||||
| case DTypeEnum::Float32: | |||||
| return mlir::MemRefType::get(shape, m_builder.getF32Type()); | |||||
| default: | |||||
| mgb_throw(InternalError, "No supported dtype: %s", | |||||
| layout.dtype.name()); | |||||
| } | |||||
| return mlir::UnrankedMemRefType::get(m_builder.getNoneType(), 0); | |||||
| return layout_to_mlir_type(layout, m_builder); | |||||
| } | } | ||||
| mlir::Value get(const VarNode* var) { | mlir::Value get(const VarNode* var) { | ||||
| @@ -9,12 +9,12 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "megbrain/jit/utils.h" | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_JIT | #if MGB_JIT | ||||
| #include "megbrain/utils/debug.h" | #include "megbrain/utils/debug.h" | ||||
| #include "megbrain/jit/utils.h" | |||||
| #include <atomic> | #include <atomic> | ||||
| @@ -15,13 +15,14 @@ | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
| #include <mlir/IR/OpDefinition.h> | |||||
| #include "megbrain/jit/mlir/ir/interfaces.h" | |||||
| #include "megbrain/jit/mlir/ir/utils.h" | |||||
| #include <mlir/IR/Dialect.h> | #include <mlir/IR/Dialect.h> | ||||
| #include <mlir/IR/Function.h> | #include <mlir/IR/Function.h> | ||||
| #include <mlir/IR/OpDefinition.h> | |||||
| #include <mlir/Interfaces/SideEffectInterfaces.h> | #include <mlir/Interfaces/SideEffectInterfaces.h> | ||||
| #include "megbrain/jit/mlir/ir/shape_inference_interface.h" | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace jit { | namespace jit { | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * \file src/jit/impl/mlir/ir/shape_inference_interface.h | |||||
| * \file src/jit/include/mlir/ir/interfaces.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| @@ -13,21 +13,16 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_JIT && MGB_JIT_MLIR | |||||
| #if MGB_JIT_MLIR | |||||
| #include "mlir/IR/OpDefinition.h" | |||||
| namespace mgb { | |||||
| namespace jit { | |||||
| #include <mlir/IR/OpDefinition.h> | |||||
| #include <mlir/IR/Types.h> | |||||
| namespace mlir { | |||||
| /// Include the auto-generated declarations. | /// Include the auto-generated declarations. | ||||
| #include "megbrain/jit/mlir/ir/shape_inference_interface.h.inc" | |||||
| } // end namespace toy | |||||
| } // end namespace mlir | |||||
| #include "megbrain/jit/mlir/ir/interfaces.h.inc" | |||||
| } // namespace mlir | |||||
| #endif // MGB_JIT && MGB_JIT_MLIR | |||||
| #endif // MGB_JIT_MLIR | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -13,19 +13,15 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #include <mlir/IR/Module.h> | |||||
| #include "megbrain_build_config.h" | |||||
| #if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
| #include <memory> | #include <memory> | ||||
| #include <mlir/IR/Module.h> | |||||
| #include <mlir/Pass/Pass.h> | #include <mlir/Pass/Pass.h> | ||||
| namespace mgb { | namespace mgb { | ||||
| namespace jit { | namespace jit { | ||||
| std::unique_ptr<mlir::Pass> create_shape_inference_pass(); | |||||
| /** | /** | ||||
| * \brief Create a pass for lowering to operations in the `Affine` and `Std` | * \brief Create a pass for lowering to operations in the `Affine` and `Std` | ||||
| * dialects, for a subset of the megbrain IR. | * dialects, for a subset of the megbrain IR. | ||||
| @@ -1,13 +1,12 @@ | |||||
| /** | /** | ||||
| * \file src/jit/impl/mlir/utils.h | |||||
| * \file src/jit/include/megbrain/mlir/ir/utils.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -15,28 +14,37 @@ | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
| #include "megbrain/common.h" | |||||
| #include "megbrain/exception.h" | |||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| #include <string> | |||||
| #include <mlir/ExecutionEngine/CRunnerUtils.h> | |||||
| #include <llvm/Support/raw_ostream.h> | |||||
| #include <mlir/IR/PatternMatch.h> | |||||
| #include <mlir/IR/StandardTypes.h> | |||||
| #include <mlir/IR/Value.h> | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace jit { | namespace jit { | ||||
| template <typename T> | template <typename T> | ||||
| std::string to_string(T&& t) { | |||||
| std::string mlir_type_to_string(T&& t) { | |||||
| std::string ret; | std::string ret; | ||||
| llvm::raw_string_ostream stream(ret); | llvm::raw_string_ostream stream(ret); | ||||
| t.print(stream); | t.print(stream); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| mlir::Value insert_alloc_and_dealloc(mlir::MemRefType type, mlir::Location loc, | |||||
| mlir::PatternRewriter& rewriter); | |||||
| mlir::Type deduce_result_type(mlir::ValueRange operands); | |||||
| /** | |||||
| * \brief convert mlir type to TensorLayout | |||||
| */ | |||||
| megdnn::TensorLayout mlir_type_to_layout(mlir::Type type); | |||||
| megdnn::DType mlir_type_to_dtype(mlir::Type type); | |||||
| mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout, | |||||
| mlir::Builder& builder); | |||||
| } // namespace jit | } // namespace jit | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -9,9 +9,11 @@ | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include <memory> | |||||
| #include "./helper.h" | #include "./helper.h" | ||||
| #include "megbrain/jit/executor_opr.h" | #include "megbrain/jit/executor_opr.h" | ||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/basic_arith_wrapper.h" | #include "megbrain/opr/basic_arith_wrapper.h" | ||||
| #include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| @@ -129,7 +131,7 @@ void run_mlir(CompNode cn) { | |||||
| HostTensorGenerator<dtype::Float32> gen; | HostTensorGenerator<dtype::Float32> gen; | ||||
| auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn), | auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn), | ||||
| host_x2 = gen({23, 42}, cn); | |||||
| host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn); | |||||
| auto a = opr::Host2DeviceCopy::make(*graph, host_x0), | auto a = opr::Host2DeviceCopy::make(*graph, host_x0), | ||||
| b = opr::Host2DeviceCopy::make(*graph, host_x1), | b = opr::Host2DeviceCopy::make(*graph, host_x1), | ||||
| @@ -137,7 +139,6 @@ void run_mlir(CompNode cn) { | |||||
| auto y = a + b + c; | auto y = a + b + c; | ||||
| VarNodeArray inputs{a.node(), b.node(), c.node()}, outputs{y.node()}; | |||||
| auto ig_gen = | auto ig_gen = | ||||
| std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | ||||
| @@ -157,6 +158,48 @@ void run_mlir(CompNode cn) { | |||||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | ||||
| } | } | ||||
| template <typename tag, int arity> | |||||
| void run_mlir_mode(CompNode cn) { | |||||
| set_backend(Backend::MLIR); | |||||
| auto graph = ComputingGraph::make(); | |||||
| float low = 0.f, high = 1.f; | |||||
| if (tag::mode == opr::Elemwise::Mode::LOG) { | |||||
| low = 0.1f; | |||||
| high = 4.f; | |||||
| } | |||||
| HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen(low, | |||||
| high); | |||||
| SmallVector<std::shared_ptr<HostTensorND>> hosts; | |||||
| VarNodeArray input_vars; | |||||
| for (int i = 0; i < arity; i++) { | |||||
| hosts.push_back(gen({23, 42}, cn)); | |||||
| input_vars.push_back( | |||||
| opr::Host2DeviceCopy::make(*graph, hosts[i]).node()); | |||||
| } | |||||
| auto y = opr::Elemwise::make(input_vars, tag::mode); | |||||
| auto ig_gen = | |||||
| std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||||
| for (auto i : get_rev_topo_order(y)) { | |||||
| if (!i->template same_type<opr::Host2DeviceCopy>()) { | |||||
| ig_gen->add_opr(i); | |||||
| } | |||||
| } | |||||
| auto igraph = ig_gen->generate(); | |||||
| auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); | |||||
| HostTensorND host_y, host_y_jit; | |||||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||||
| make_callback_copy(y_jit, host_y_jit)}); | |||||
| func->execute(); | |||||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | |||||
| } | |||||
| #endif | #endif | ||||
| } // anonymous namespace | } // anonymous namespace | ||||
| @@ -191,6 +234,117 @@ TEST(TestJITMlirCodeGen, BasicGPU) { | |||||
| run_mlir(cn); | run_mlir(cn); | ||||
| } | } | ||||
| ///////////////////////// unary /////////////////////////////// | |||||
| // clang-format off | |||||
| #define FOREACH_UNARY_MODE(cb) \ | |||||
| cb(RELU) \ | |||||
| cb(ABS) \ | |||||
| cb(NEGATE) \ | |||||
| cb(CEIL) \ | |||||
| cb(EXP) \ | |||||
| cb(FLOOR) \ | |||||
| cb(LOG) \ | |||||
| cb(LOG1P) \ | |||||
| cb(SIN) \ | |||||
| cb(TANH) \ | |||||
| cb(FAST_TANH) \ | |||||
| cb(H_SWISH) \ | |||||
| cb(SIGMOID) \ | |||||
| cb(EXPM1) \ | |||||
| cb(ROUND) | |||||
| // clang-format on | |||||
| template <typename tag> | |||||
| class TestJITMlirUnaryElemwise : public ::testing::Test {}; | |||||
| #define def_tag(x) \ | |||||
| struct x { \ | |||||
| static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \ | |||||
| }; | |||||
| FOREACH_UNARY_MODE(def_tag) | |||||
| #undef def_tag | |||||
| #define t(n) n, | |||||
| using mlir_elemwise_unary_types = | |||||
| ::testing::Types<FOREACH_UNARY_MODE(t) ABS>; | |||||
| #undef t | |||||
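The trailing ABS absorbs the comma left by the last t(n) expansion (the same trick is used with ADD and COND_LEQ_MOV below); ::testing::Types accepts the duplicate, which merely runs that mode's test twice.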
| TYPED_TEST_CASE(TestJITMlirUnaryElemwise, mlir_elemwise_unary_types); | |||||
| TYPED_TEST(TestJITMlirUnaryElemwise, run) { | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| run_mlir_mode<TypeParam, 1>(cn); | |||||
| } | |||||
| ///////////////////////// binary /////////////////////////////// | |||||
| // clang-format off | |||||
| #define FOREACH_BINARY_MODE(cb) \ | |||||
| cb(ADD) \ | |||||
| cb(FLOOR_DIV) \ | |||||
| cb(MUL) \ | |||||
| cb(MAX) \ | |||||
| cb(MIN) \ | |||||
| cb(MOD) \ | |||||
| cb(SUB) \ | |||||
| cb(TRUE_DIV) \ | |||||
| cb(ABS_GRAD) \ | |||||
| cb(SIGMOID_GRAD) \ | |||||
| cb(SWITCH_GT0) \ | |||||
| cb(TANH_GRAD) \ | |||||
| cb(LT) \ | |||||
| cb(LEQ) \ | |||||
| cb(EQ) \ | |||||
| cb(FUSE_ADD_RELU) \ | |||||
| cb(LOG_SUM_EXP) \ | |||||
| cb(FUSE_ADD_TANH) \ | |||||
| cb(FAST_TANH_GRAD) \ | |||||
| cb(FUSE_ADD_SIGMOID) \ | |||||
| cb(H_SWISH_GRAD) \ | |||||
| cb(FUSE_ADD_H_SWISH) | |||||
| // clang-format on | |||||
| template <typename tag> | |||||
| class TestJITMlirBinaryElemwise : public ::testing::Test {}; | |||||
| #define def_tag(x) \ | |||||
| struct x { \ | |||||
| static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \ | |||||
| }; | |||||
| FOREACH_BINARY_MODE(def_tag) | |||||
| #undef def_tag | |||||
| #define t(n) n, | |||||
| using mlir_elemwise_binary_types = | |||||
| ::testing::Types<FOREACH_BINARY_MODE(t) ADD>; | |||||
| #undef t | |||||
| TYPED_TEST_CASE(TestJITMlirBinaryElemwise, mlir_elemwise_binary_types); | |||||
| TYPED_TEST(TestJITMlirBinaryElemwise, run) { | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| run_mlir_mode<TypeParam, 2>(cn); | |||||
| } | |||||
| ///////////////////////// ternary /////////////////////////////// | |||||
| // clang-format off | |||||
| #define FOREACH_TERNARY_MODE(cb) \ | |||||
| cb(COND_LEQ_MOV) \ | |||||
| cb(FUSE_MUL_ADD3) | |||||
| // clang-format on | |||||
| template <typename tag> | |||||
| class TestJITMlirTernaryElemwise : public ::testing::Test {}; | |||||
| #define def_tag(x) \ | |||||
| struct x { \ | |||||
| static constexpr opr::Elemwise::Mode mode = opr::Elemwise::Mode::x; \ | |||||
| }; | |||||
| FOREACH_TERNARY_MODE(def_tag) | |||||
| #undef def_tag | |||||
| #define t(n) n, | |||||
| using mlir_elemwise_ternary_types = | |||||
| ::testing::Types<FOREACH_TERNARY_MODE(t) COND_LEQ_MOV>; | |||||
| #undef t | |||||
| TYPED_TEST_CASE(TestJITMlirTernaryElemwise, mlir_elemwise_ternary_types); | |||||
| TYPED_TEST(TestJITMlirTernaryElemwise, run) { | |||||
| auto cn = CompNode::load("cpu0"); | |||||
| run_mlir_mode<TypeParam, 3>(cn); | |||||
| } | |||||
| #endif | #endif | ||||
| #endif // MGB_JIT | #endif // MGB_JIT | ||||
| @@ -57,6 +57,7 @@ namespace opr { | |||||
| EL2(and_, AND) | EL2(and_, AND) | ||||
| EL2(or_, OR) | EL2(or_, OR) | ||||
| EL2(xor_, XOR) | EL2(xor_, XOR) | ||||
| EL2(mod, MOD) | |||||
| #undef EL2 | #undef EL2 | ||||