GitOrigin-RevId: 27b1649c04
@@ -101,9 +101,9 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
                 compiler = std::make_unique<CudaCompiler>();
                 break;
             }
+#endif
             mgb_throw(InternalError, "No compiler support for cuda");
             break;
-#endif
         case CompNode::DeviceType::CPU:
 #if MGB_JIT_MLIR
             if (!backend || !strcmp(backend, "MLIR")) {
@@ -20,6 +20,10 @@
 #if MGB_JIT

+#if MGB_JIT_MLIR
+#include "./mlir/ir/each_mode.h"
+#endif
+
 using namespace mgb;
 using namespace gopt;
 using namespace jit;
@@ -339,35 +343,76 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
         return false;
     }

+    //! the MLIR backend has some extra constraints, so check which backend is selected
+    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
     // float elemwise
     if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
-        return ast_c::check_elem_mode(elem->param().mode) &&
+        bool ret = true;
+#if MGB_JIT_MLIR
+        if (backend && !strcmp(backend, "MLIR")) {
+            switch (elem->param().mode) {
+#define cb(_, _mode)                     \
+    case opr::Elemwise::Mode::_mode:     \
+        ret = true;                      \
+        break;
+                MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
+                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
+                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
+                default:
+                    ret = false;
+#undef cb
+            }
+
+#define FOREACH_ELEMWISE_SKIP_MODE(cb) cb(SIN)
+            //! FIXME: MLIR on CUDA doesn't support sin currently.
+            if (opr->output(0)->comp_node().device_type() ==
+                CompNode::DeviceType::CUDA) {
+                switch (elem->param().mode) {
+#define cb(_mode)                        \
+    case opr::Elemwise::Mode::_mode:     \
+        ret = false;                     \
+        break;
+                    FOREACH_ELEMWISE_SKIP_MODE(cb)
+                    default:
+                        break;
+#undef cb
+                }
+            }
+#undef FOREACH_ELEMWISE_SKIP_MODE
+        }
+#endif  // MGB_JIT_MLIR
+        return ret && ast_c::check_elem_mode(elem->param().mode) &&
                elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
     }
-    if (opr->same_type<opr::PowC>()) {
-        return true;
-    }
-
-    // float typecvt (e.g. used in f16 training)
-    if (opr->same_type<opr::TypeCvt>()) {
-        auto category = opr->input(0)->dtype().category();
-        if (category != opr->output(0)->dtype().category())
-            return false;
-        return category == DTypeCategory::FLOAT;
-    }
-
-    // float reduce
-    if ((m_feature_bits & JITFeatureBits::REDUCE) &&
-        opr->same_type<opr::Reduce>()) {
-        return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
-    }
-
-    // dimshuffle
-    if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
-        opr->same_type<opr::Dimshuffle>()) {
-        auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
-        return param.pattern_len <= 4;
+    if (!backend || strcmp(backend, "MLIR")) {
+        if (opr->same_type<opr::PowC>()) {
+            return true;
+        }
+
+        // float typecvt (e.g. used in f16 training)
+        if (opr->same_type<opr::TypeCvt>()) {
+            auto category = opr->input(0)->dtype().category();
+            if (category != opr->output(0)->dtype().category())
+                return false;
+            return category == DTypeCategory::FLOAT;
+        }
+
+        // float reduce
+        if ((m_feature_bits & JITFeatureBits::REDUCE) &&
+            opr->same_type<opr::Reduce>()) {
+            return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
+        }
+
+        // dimshuffle
+        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
+            opr->same_type<opr::Dimshuffle>()) {
+            auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
+            return param.pattern_len <= 4;
+        }
     }

     // existing JITExecutor
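Note on the gate above: the `cb` X-macros expand `MLIR_MGB_FOREACH_ELEMWISE_MODE_{UNARY,BINARY,TERNARY}` into plain `case` labels. A macro-free sketch of what the expansion amounts to; the concrete mode list is assumed, with `ADD`/`MUL` standing in for whichever modes the FOREACH macros actually enumerate:

    // Illustrative only -- not part of the patch.
    static bool mlir_can_fuse_elemwise(opr::Elemwise::Mode mode, bool on_cuda) {
        using Mode = opr::Elemwise::Mode;
        switch (mode) {
            case Mode::SIN:   // listed as supported, but skipped on CUDA (see FIXME)
                return !on_cuda;
            case Mode::ADD:   // representative entries; the real list comes from
            case Mode::MUL:   // MLIR_MGB_FOREACH_ELEMWISE_MODE_*
                return true;
            default:
                return false;
        }
    }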
@@ -10,7 +10,6 @@
  * implied.
  */

-#include "llvm/Pass.h"
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR

@@ -40,6 +39,7 @@
 #include <llvm/Support/TargetSelect.h>
 #include <llvm/IRReader/IRReader.h>
 #include <llvm/Linker/Linker.h>
+#include <llvm/Pass.h>

 #include <dlfcn.h>
 #include <dirent.h>
@@ -77,6 +77,16 @@ private:
     mlir::Location m_location;
 };

+template <typename Op>
+mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
+                        const mlir::Value& val, const mlir::ValueRange& index) {
+    if (val.getType().isa<mlir::MemRefType>()) {
+        return builder.create<Op>(loc, val, index);
+    } else {
+        return val;
+    }
+}
+
 }  // namespace jit
 }  // namespace mgb
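The shared `get_operand` helper above is what lets scalar constants flow through the element-wise lowerings: a memref operand is materialized with a load op (`AffineLoadOp` on the affine path, `LoadOp` on the GPU path, chosen via the template parameter), while a plain SSA value such as the f32 produced by the new `mgb.sconst` op is passed through untouched. A hedged usage sketch, with `adaptor` and `loop_ivs` as in the lowering patterns below:

    // memref<?xf32> operand -> emits an affine.load and yields the loaded f32
    // f32 operand (scalar constant) -> returned as-is, no load emitted
    mlir::Value lhs = get_operand<mlir::AffineLoadOp>(
            builder, loc, adaptor.lhs(), loop_ivs);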
@@ -14,6 +14,7 @@
 #if MGB_JIT && MGB_JIT_MLIR

 #include "megbrain/jit/mlir/ir/dialect.h"
+#include "./types.h"

 #include <mlir/IR/Builders.h>
 #include <mlir/IR/OpImplementation.h>
@@ -74,8 +74,8 @@ struct UnaryOpLowering : public ConversionPattern {
                     typename Op::Adaptor binary_adaptor(memref_operands);
                     LoweredOp lower_op;

-                    auto loaded_lhs = builder.create<AffineLoadOp>(
-                            loc, binary_adaptor.lhs(), loop_ivs);
+                    auto loaded_lhs = get_operand<AffineLoadOp>(
+                            builder, loc, binary_adaptor.lhs(), loop_ivs);
                     return lower_op(builder, loc, {loaded_lhs});
                 });

@@ -104,10 +104,10 @@ struct BinaryOpLowering : public ConversionPattern {
                     typename Op::Adaptor binary_adaptor(memref_operands);
                     LoweredOp lower_op;

-                    auto loaded_lhs = builder.create<AffineLoadOp>(
-                            loc, binary_adaptor.lhs(), loop_ivs);
-                    auto loaded_rhs = builder.create<AffineLoadOp>(
-                            loc, binary_adaptor.rhs(), loop_ivs);
+                    auto loaded_lhs = get_operand<AffineLoadOp>(
+                            builder, loc, binary_adaptor.lhs(), loop_ivs);
+                    auto loaded_rhs = get_operand<AffineLoadOp>(
+                            builder, loc, binary_adaptor.rhs(), loop_ivs);
                     return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
                 });

@@ -136,12 +136,12 @@ struct TernaryOpLowering : public ConversionPattern {
                     typename Op::Adaptor ternary_adaptor(memref_operands);
                     LoweredOp lower_op;

-                    auto loaded_x = builder.create<AffineLoadOp>(
-                            loc, ternary_adaptor.x(), loop_ivs);
-                    auto loaded_y = builder.create<AffineLoadOp>(
-                            loc, ternary_adaptor.y(), loop_ivs);
-                    auto loaded_z = builder.create<AffineLoadOp>(
-                            loc, ternary_adaptor.z(), loop_ivs);
+                    auto loaded_x = get_operand<AffineLoadOp>(
+                            builder, loc, ternary_adaptor.x(), loop_ivs);
+                    auto loaded_y = get_operand<AffineLoadOp>(
+                            builder, loc, ternary_adaptor.y(), loop_ivs);
+                    auto loaded_z = get_operand<AffineLoadOp>(
+                            builder, loc, ternary_adaptor.z(), loop_ivs);
                     return lower_op(builder, loc,
                                     {loaded_x, loaded_y, loaded_z});
                 });
@@ -193,6 +193,19 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
     }
 };

+struct ConstantScalarOpLowering
+        : public OpRewritePattern<jit::ConstantScalarOp> {
+    using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern;
+
+    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
+                                  PatternRewriter& rewriter) const final {
+        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
+        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
+                op, constant_scalar_adaptor.value());
+        return success();
+    }
+};
+
 class MgbToAffineLoweringPass
         : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
 public:

@@ -207,7 +220,8 @@ public:
                 cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                         MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                                 ReturnOpLowering,
-                        AssignOpLowering>(&getContext());
+                        AssignOpLowering, ConstantScalarOpLowering>(
+                        &getContext());
 #undef cb

         if (failed(applyPartialConversion(getFunction(), target, patterns))) {
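Note: `ConstantScalarOpLowering` simply re-emits the op's `value` attribute as a standard-dialect constant. In IR terms (illustrative textual form, assuming the MLIR printer syntax of this era):

    // before:  %0 = "mgb.sconst"() {value = 3.000000e-01 : f32} : () -> f32
    // after:   %0 = constant 3.000000e-01 : f32    // std.constant, same attribute

The GPU-side twin of this pattern (further down) additionally moves the insertion point into the `gpu.launch` body, so the materialized constant dominates its uses inside the kernel.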
@@ -38,16 +38,6 @@ using namespace jit;

 namespace {

-mlir::Value get_operand(ConversionPatternRewriter& rewriter,
-                        const mlir::Location& loc, const mlir::Value& val,
-                        const mlir::Value& index) {
-    if (val.getType().isa<mlir::MemRefType>()) {
-        return rewriter.create<LoadOp>(loc, val, index);
-    } else {
-        return val;
-    }
-}
-
 mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
     auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
             loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
@@ -64,7 +54,7 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {

 template <typename Op, typename LoweredOp>
 struct UnaryOpLowering : public ConversionPattern {
-    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}

@@ -74,11 +64,11 @@ struct UnaryOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         typename Op::Adaptor binary_adaptor(operands);

-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

         auto index = get_tid(rewriter, loc);
         auto loaded_lhs =
-                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
+                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);

         LoweredOp lower_op;

@@ -87,7 +77,7 @@ struct UnaryOpLowering : public ConversionPattern {
     }

 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };

 #define cb(_op, _) \

@@ -97,7 +87,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)

 template <typename Op, typename LoweredOp>
 struct BinaryOpLowering : public ConversionPattern {
-    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}

@@ -107,13 +97,13 @@ struct BinaryOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         typename Op::Adaptor binary_adaptor(operands);

-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

         auto index = get_tid(rewriter, loc);
         auto loaded_lhs =
-                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
+                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
         auto loaded_rhs =
-                get_operand(rewriter, loc, binary_adaptor.rhs(), index);
+                get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);

         LoweredOp lower_op;

@@ -123,7 +113,7 @@ struct BinaryOpLowering : public ConversionPattern {
     }

 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };

 #define cb(_op, _) \
@@ -133,7 +123,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)

 template <typename Op, typename LoweredOp>
 struct TernaryOpLowering : public ConversionPattern {
-    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}

@@ -143,15 +133,15 @@ struct TernaryOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         typename Op::Adaptor ternary_adaptor(operands);

-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

         auto index = get_tid(rewriter, loc);
         auto loaded_x =
-                get_operand(rewriter, loc, ternary_adaptor.x(), index);
+                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
         auto loaded_y =
-                get_operand(rewriter, loc, ternary_adaptor.y(), index);
+                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
         auto loaded_z =
-                get_operand(rewriter, loc, ternary_adaptor.z(), index);
+                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);

         LoweredOp lower_op;

@@ -161,7 +151,7 @@ struct TernaryOpLowering : public ConversionPattern {
     }

 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };

 #define cb(_op, _) \
@@ -171,7 +161,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
 #undef cb

 struct ReturnOpLowering : public ConversionPattern {
-    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}

@@ -182,10 +172,10 @@ struct ReturnOpLowering : public ConversionPattern {
         auto loc = op->getLoc();

         //! remove the first gpu.terminator
-        m_launch_op->body().front().front().erase();
+        m_launch_op.body().front().front().erase();

         //! insert "if (tid >= nr_tid) { return; }" at the beginning of the block
-        rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToStart(&(m_launch_op.body().front()));

         Block* cond_block = rewriter.getInsertionBlock();
         Block::iterator op_position = rewriter.getInsertionPoint();
         Block* remaining_ops_block =

@@ -195,7 +185,7 @@ struct ReturnOpLowering : public ConversionPattern {
         auto index = get_tid(rewriter, loc);
         auto comparison = rewriter.create<mlir::CmpIOp>(
                 loc, CmpIPredicate::sge, index,
-                m_launch_op->getParentOfType<mlir::FuncOp>()
+                m_launch_op.getParentOfType<mlir::FuncOp>()
                         .getArguments()
                         .back());

@@ -216,11 +206,31 @@ struct ReturnOpLowering : public ConversionPattern {
     }

 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };

+struct ConstantScalarOpLowering
+        : public OpRewritePattern<jit::ConstantScalarOp> {
+    ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : OpRewritePattern<jit::ConstantScalarOp>(ctx),
+              m_launch_op{launch_op} {}
+
+    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
+                                  PatternRewriter& rewriter) const final {
+        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
+
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
+        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
+                op, constant_scalar_adaptor.value());
+        return success();
+    }
+
+private:
+    gpu::LaunchOp& m_launch_op;
+};
+
 struct AssignOpLowering : public ConversionPattern {
-    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
               m_launch_op{launch_op} {}
@@ -230,12 +240,12 @@ struct AssignOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         AssignOpAdaptor assign_adaptor(operands);

-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));

         auto index = get_tid(rewriter, loc);

         auto loaded_lhs =
-                get_operand(rewriter, loc, assign_adaptor.lhs(), index);
+                get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
         rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);

         rewriter.eraseOp(op);

@@ -243,7 +253,7 @@ struct AssignOpLowering : public ConversionPattern {
     }

 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };

 class MgbToGpuLoweringPass
@@ -271,7 +281,8 @@ public:
                 cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                         MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                                 ReturnOpLowering,
-                        AssignOpLowering>(&getContext(), &launch_op);
+                        ConstantScalarOpLowering, AssignOpLowering>(
+                        &getContext(), launch_op);
 #undef cb

         if (failed(applyPartialConversion(func_op, target, patterns))) {
@@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"

 include "./interfaces.td"
+include "./predicates.td"

 def Mgb_Dialect : Dialect {
   let name = "mgb";

@@ -90,7 +91,7 @@ def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;

 class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
     ElemwiseOp<mnemonic, traits> {
-  let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
+  let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs);
   let results = (outs F32MemRef);

   let builders = [OpBuilder<

@@ -141,7 +142,7 @@ def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;

 class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
     ElemwiseOp<mnemonic, traits> {
-  let arguments = (ins F32MemRef:$x, F32MemRef:$y, F32MemRef:$z);
+  let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z);
   let results = (outs F32MemRef);

   let builders = [OpBuilder<
@@ -178,6 +179,25 @@ def ReturnOp : GenericOp<"return",
 }

+def ConstantScalarOp : GenericOp<"sconst", [NoSideEffect]> {
+  let summary = "scalar constant";
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs F32:$result);
+
+  let builders = [OpBuilder<
+    "Builder* builder, OperationState& result, float value", [{
+      result.addAttribute("value", builder->getF32FloatAttr(value));
+      result.addTypes(builder->getF32Type());
+    }]
+  >];
+
+  let extraClassDeclaration = [{
+    Attribute getValue() { return getAttr("value"); }
+    FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); }
+  }];
+}
+
 def AssignOp : GenericOp<"assign", []> {
   let summary = "assign op";
   let description = [{
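A short note on the op definition: `sconst` carries its payload as a `value` attribute and produces a plain `F32` SSA result rather than a memref, which is why the element-wise argument constraints above had to be relaxed to `ElemwiseFloatAny`. The custom TableGen builder makes construction from C++ a one-liner (hedged sketch; `loc` is assumed):

    // Uses the ODS builder that takes a float directly:
    auto cst = m_builder.create<jit::ConstantScalarOp>(loc, 0.3f);
    // Equivalent to the explicit type/attribute form used in mlir_gen below.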
@@ -0,0 +1,24 @@
+/**
+ * \file src/jit/impl/mlir/ir/predicates.td
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#ifndef MGB_MLIR_PREDICATES
+#define MGB_MLIR_PREDICATES
+
+#ifndef OP_BASE
+include "mlir/IR/OpBase.td"
+#endif
+
+def ElemwiseFloatAny : TypeConstraint<
+  CPred<"is_elemwise_float($_self)">, "elemwise-float">;
+
+#endif
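`ElemwiseFloatAny` is a plain `CPred`-based constraint, so TableGen splices the C++ expression into the generated op verifier, roughly like this (generated-code sketch; the exact shape of ODS output is assumed):

    // Approximate fragment of the ODS-generated verifier:
    if (!(is_elemwise_float(op->getOperand(0).getType())))
        return op->emitOpError("operand #0 must be elemwise-float");

This is also why `is_elemwise_float` must be visible wherever the generated op definitions are compiled -- hence the `#include "./types.h"` added next to the dialect include above.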
@@ -0,0 +1,39 @@
+/**
+ * \file src/jit/impl/mlir/ir/types.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#pragma once
+
+#include "megbrain_build_config.h"
+
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include <mlir/IR/StandardTypes.h>
+
+namespace mgb {
+namespace jit {
+
+inline bool is_elemwise_float(const mlir::Type& dt) {
+    if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
+        if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
+            return true;
+        }
+    }
+    if (dt.isa<mlir::FloatType>()) {
+        return true;
+    }
+    return false;
+}
+
+}  // namespace jit
+}  // namespace mgb
+
+#endif  // MGB_JIT && MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
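A hedged example of the predicate's behavior (MLIR API of this era assumed):

    #include <mlir/IR/MLIRContext.h>
    #include <mlir/IR/StandardTypes.h>

    void example(mlir::MLIRContext* ctx) {
        auto f32 = mlir::FloatType::getF32(ctx);      // accepted: scalar float
        auto mem = mlir::MemRefType::get({-1}, f32);  // accepted: memref of f32
        auto i32 = mlir::IntegerType::get(32, ctx);   // rejected: not a float
        mgb::jit::is_elemwise_float(f32);  // true
        mgb::jit::is_elemwise_float(mem);  // true
        mgb::jit::is_elemwise_float(i32);  // false
    }

Both the scalar and memref cases must be accepted because, after this patch, an element-wise operand can be either a tensor memref or the f32 result of `mgb.sconst`.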
@@ -49,6 +49,9 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
     megdnn::TensorShape dst;
     megdnn::DType dst_type;
     for (auto operand : operands) {
+        if (operand.getType().isa<mlir::FloatType>()) {
+            continue;
+        }
         auto type = operand.getType().dyn_cast_or_null<mlir::MemRefType>();
         mgb_assert(type, "currently only support MemRefType");
@@ -137,6 +137,27 @@ private:
             return;
         }

+        if (opr->same_type<opr::ImmutableTensor>()) {
+            auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
+            if (imm.valid()) {
+                auto dtype = imm->dtype();
+                float scalar_value;
+                if (dtype == dtype::Float32()) {
+                    scalar_value = imm->get<float>();
+                } else {
+                    mgb_throw(InternalError,
+                              "mlir backend currently only supports f32 "
+                              "dtype, but got %s",
+                              dtype.name());
+                }
+                auto&& out = m_builder.create<jit::ConstantScalarOp>(
+                        m_builder.getUnknownLoc(), m_builder.getF32Type(),
+                        m_builder.getF32FloatAttr(scalar_value));
+                mgb_assert(mlir::succeeded(
+                        declare(opr->output(0)->name(), out)));
+            }
+        }
+
         if (opr->same_type<opr::Elemwise>()) {
             auto&& out = gen_op(opr->cast_final<opr::Elemwise>());
             mgb_assert(
@@ -137,7 +137,7 @@ void run_mlir(CompNode cn) {
       b = opr::Host2DeviceCopy::make(*graph, host_x1),
       c = opr::Host2DeviceCopy::make(*graph, host_x2);

-    auto y = a + b * c;
+    auto y = a + b * c + 0.3f;

     auto ig_gen =
             std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
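A closing note on the test change: the added `+ 0.3f` becomes an `opr::ImmutableTensor` holding a scalar, which exercises the new path end to end -- `can_be_fused` admits it, `mlir_gen` emits `mgb.sconst` via the ImmutableTensor branch above, and the affine/GPU passes lower it to a `std.constant`. The MLIR route is only taken when the `MGB_JIT_BACKEND` environment variable is set to `MLIR`, as the fusion-pass check at the top of this patch requires.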