GitOrigin-RevId: 27b1649c04
@@ -101,9 +101,9 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
             compiler = std::make_unique<CudaCompiler>();
             break;
         }
-#endif
             mgb_throw(InternalError, "No compiler support for cuda");
             break;
+#endif
         case CompNode::DeviceType::CPU:
 #if MGB_JIT_MLIR
             if (!backend || !strcmp(backend, "MLIR")) {
@@ -20,6 +20,10 @@
 #if MGB_JIT
+#if MGB_JIT_MLIR
+#include "./mlir/ir/each_mode.h"
+#endif
+
 using namespace mgb;
 using namespace gopt;
 using namespace jit;
@@ -339,35 +343,76 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
         return false;
     }
 
+    //! The MLIR backend has some constraints.
+    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
     // float elemwise
     if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
-        return ast_c::check_elem_mode(elem->param().mode) &&
+        bool ret = true;
+#if MGB_JIT_MLIR
+        if (!strcmp(backend, "MLIR")) {
+            switch (elem->param().mode) {
+#define cb(_, _mode)                     \
+    case opr::Elemwise::Mode::_mode:     \
+        ret = true;                      \
+        break;
+                MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
+                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
+                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
+                default:
+                    ret = false;
+#undef cb
+            }
+#define FOREACH_ELEMWISE_SKIP_MODE(cb) cb(SIN)
+            //! FIXME: mlir on cuda doesn't support sin currently.
+            if (opr->output(0)->comp_node().device_type() ==
+                CompNode::DeviceType::CUDA) {
+                switch (elem->param().mode) {
+#define cb(_mode)                        \
+    case opr::Elemwise::Mode::_mode:     \
+        ret = false;                     \
+        break;
+                    FOREACH_ELEMWISE_SKIP_MODE(cb)
+                    default:
+                        break;
+#undef cb
+                }
+            }
+#undef FOREACH_ELEMWISE_SKIP_MODE
+        }
+#endif  // MGB_JIT_MLIR
+        return ret && ast_c::check_elem_mode(elem->param().mode) &&
                elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
     }
 
-    if (opr->same_type<opr::PowC>()) {
-        return true;
-    }
+    if (strcmp(backend, "MLIR")) {
+        if (opr->same_type<opr::PowC>()) {
+            return true;
+        }
 
-    // float typecvt (e.g. used in f16 training)
-    if (opr->same_type<opr::TypeCvt>()) {
-        auto category = opr->input(0)->dtype().category();
-        if (category != opr->output(0)->dtype().category())
-            return false;
-        return category == DTypeCategory::FLOAT;
-    }
+        // float typecvt (e.g. used in f16 training)
+        if (opr->same_type<opr::TypeCvt>()) {
+            auto category = opr->input(0)->dtype().category();
+            if (category != opr->output(0)->dtype().category())
+                return false;
+            return category == DTypeCategory::FLOAT;
+        }
 
-    // float reduce
-    if ((m_feature_bits & JITFeatureBits::REDUCE) &&
-        opr->same_type<opr::Reduce>()) {
-        return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
-    }
+        // float reduce
+        if ((m_feature_bits & JITFeatureBits::REDUCE) &&
+            opr->same_type<opr::Reduce>()) {
+            return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
+        }
 
-    // dimshuffle
-    if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
-        opr->same_type<opr::Dimshuffle>()) {
-        auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
-        return param.pattern_len <= 4;
-    }
+        // dimshuffle
+        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
+            opr->same_type<opr::Dimshuffle>()) {
+            auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
+            return param.pattern_len <= 4;
+        }
+    }
 
     // existing JITExecutor
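
Under MGB_JIT_BACKEND=MLIR, only the elemwise modes enumerated by the MLIR_MGB_FOREACH_ELEMWISE_MODE_* lists are considered fusable, and SIN is additionally rejected on CUDA. As a reading aid, the CUDA skip-mode switch above expands to roughly the following after preprocessing (a sketch, assuming the skip list stays just cb(SIN)):

    // Post-preprocessing form of the CUDA skip-mode switch (illustrative):
    switch (elem->param().mode) {
        case opr::Elemwise::Mode::SIN:
            ret = false;  // sin is not yet lowered by the MLIR CUDA path
            break;
        default:
            break;
    }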
@@ -10,7 +10,6 @@
  * implied.
  */
 
-#include "llvm/Pass.h"
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
@@ -40,6 +39,7 @@
 #include <llvm/Support/TargetSelect.h>
 #include <llvm/IRReader/IRReader.h>
 #include <llvm/Linker/Linker.h>
+#include <llvm/Pass.h>
 
 #include <dlfcn.h>
 #include <dirent.h>
@@ -77,6 +77,16 @@ private:
     mlir::Location m_location;
 };
 
+template <typename Op>
+mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
+                        const mlir::Value& val, const mlir::ValueRange& index) {
+    if (val.getType().isa<mlir::MemRefType>()) {
+        return builder.create<Op>(loc, val, index);
+    } else {
+        return val;
+    }
+}
+
 }  // namespace jit
 }  // namespace mgb
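
The helper is templated on the load op so one definition serves both lowering passes: AffineLoadOp in the affine (CPU) pass and LoadOp in the GPU pass. A minimal usage sketch mirroring the call sites below — memref operands get an explicit load at the given indices, while scalar SSA values (such as the results of ConstantScalarOp) pass through unchanged:

    // adaptor.lhs() may be a memref block argument or a scalar constant;
    // only the former needs a load at the loop induction variables.
    mlir::Value loaded = get_operand<AffineLoadOp>(
            builder, loc, adaptor.lhs(), loop_ivs);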
@@ -14,6 +14,7 @@
 #if MGB_JIT && MGB_JIT_MLIR
 #include "megbrain/jit/mlir/ir/dialect.h"
+#include "./types.h"
 
 #include <mlir/IR/Builders.h>
 #include <mlir/IR/OpImplementation.h>
@@ -74,8 +74,8 @@ struct UnaryOpLowering : public ConversionPattern {
                 typename Op::Adaptor binary_adaptor(memref_operands);
                 LoweredOp lower_op;
 
-                auto loaded_lhs = builder.create<AffineLoadOp>(
-                        loc, binary_adaptor.lhs(), loop_ivs);
+                auto loaded_lhs = get_operand<AffineLoadOp>(
+                        builder, loc, binary_adaptor.lhs(), loop_ivs);
                 return lower_op(builder, loc, {loaded_lhs});
             });
@@ -104,10 +104,10 @@ struct BinaryOpLowering : public ConversionPattern {
                 typename Op::Adaptor binary_adaptor(memref_operands);
                 LoweredOp lower_op;
 
-                auto loaded_lhs = builder.create<AffineLoadOp>(
-                        loc, binary_adaptor.lhs(), loop_ivs);
-                auto loaded_rhs = builder.create<AffineLoadOp>(
-                        loc, binary_adaptor.rhs(), loop_ivs);
+                auto loaded_lhs = get_operand<AffineLoadOp>(
+                        builder, loc, binary_adaptor.lhs(), loop_ivs);
+                auto loaded_rhs = get_operand<AffineLoadOp>(
+                        builder, loc, binary_adaptor.rhs(), loop_ivs);
                 return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
             });
@@ -136,12 +136,12 @@ struct TernaryOpLowering : public ConversionPattern {
                 typename Op::Adaptor ternary_adaptor(memref_operands);
                 LoweredOp lower_op;
 
-                auto loaded_x = builder.create<AffineLoadOp>(
-                        loc, ternary_adaptor.x(), loop_ivs);
-                auto loaded_y = builder.create<AffineLoadOp>(
-                        loc, ternary_adaptor.y(), loop_ivs);
-                auto loaded_z = builder.create<AffineLoadOp>(
-                        loc, ternary_adaptor.z(), loop_ivs);
+                auto loaded_x = get_operand<AffineLoadOp>(
+                        builder, loc, ternary_adaptor.x(), loop_ivs);
+                auto loaded_y = get_operand<AffineLoadOp>(
+                        builder, loc, ternary_adaptor.y(), loop_ivs);
+                auto loaded_z = get_operand<AffineLoadOp>(
+                        builder, loc, ternary_adaptor.z(), loop_ivs);
                 return lower_op(builder, loc,
                                 {loaded_x, loaded_y, loaded_z});
@@ -193,6 +193,19 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
     }
 };
 
+struct ConstantScalarOpLowering
+        : public OpRewritePattern<jit::ConstantScalarOp> {
+    using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern;
+
+    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
+                                  PatternRewriter& rewriter) const final {
+        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
+        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
+                op, constant_scalar_adaptor.value());
+        return success();
+    }
+};
+
 class MgbToAffineLoweringPass
         : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
 public:
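
This pattern simply replaces the dialect-level jit.sconst with a standard-dialect constant carrying the same attribute. Roughly (illustrative IR; exact pretty-printing depends on the MLIR version in use):

    // before: %0 = "mgb.sconst"() {value = 3.000000e-01 : f32} : () -> f32
    // after:  %0 = constant 3.000000e-01 : f32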
@@ -207,7 +220,8 @@ public:
                                 cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                 MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                         ReturnOpLowering,
-                AssignOpLowering>(&getContext());
+                AssignOpLowering, ConstantScalarOpLowering>(
+                &getContext());
 #undef cb
 
         if (failed(applyPartialConversion(getFunction(), target, patterns))) {
@@ -38,16 +38,6 @@ using namespace jit;
 
 namespace {
 
-mlir::Value get_operand(ConversionPatternRewriter& rewriter,
-                        const mlir::Location& loc, const mlir::Value& val,
-                        const mlir::Value& index) {
-    if (val.getType().isa<mlir::MemRefType>()) {
-        return rewriter.create<LoadOp>(loc, val, index);
-    } else {
-        return val;
-    }
-}
-
 mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
     auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
             loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
@@ -64,7 +54,7 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
 template <typename Op, typename LoweredOp>
 struct UnaryOpLowering : public ConversionPattern {
-    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
@@ -74,11 +64,11 @@ struct UnaryOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         typename Op::Adaptor binary_adaptor(operands);
-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
         auto index = get_tid(rewriter, loc);
         auto loaded_lhs =
-                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
+                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
 
         LoweredOp lower_op;
@@ -87,7 +77,7 @@ struct UnaryOpLowering : public ConversionPattern {
     }
 
 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };
 
 #define cb(_op, _) \
@@ -97,7 +87,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
 template <typename Op, typename LoweredOp>
 struct BinaryOpLowering : public ConversionPattern {
-    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
@@ -107,13 +97,13 @@ struct BinaryOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         typename Op::Adaptor binary_adaptor(operands);
-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
         auto index = get_tid(rewriter, loc);
         auto loaded_lhs =
-                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
+                get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
         auto loaded_rhs =
-                get_operand(rewriter, loc, binary_adaptor.rhs(), index);
+                get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
 
         LoweredOp lower_op;
@@ -123,7 +113,7 @@ struct BinaryOpLowering : public ConversionPattern {
     }
 
 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };
 
 #define cb(_op, _) \
@@ -133,7 +123,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
 template <typename Op, typename LoweredOp>
 struct TernaryOpLowering : public ConversionPattern {
-    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(Op::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
@@ -143,15 +133,15 @@ struct TernaryOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
         typename Op::Adaptor ternary_adaptor(operands);
-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
         auto index = get_tid(rewriter, loc);
         auto loaded_x =
-                get_operand(rewriter, loc, ternary_adaptor.x(), index);
+                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
         auto loaded_y =
-                get_operand(rewriter, loc, ternary_adaptor.y(), index);
+                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
         auto loaded_z =
-                get_operand(rewriter, loc, ternary_adaptor.z(), index);
+                get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
 
         LoweredOp lower_op;
@@ -161,7 +151,7 @@ struct TernaryOpLowering : public ConversionPattern {
     }
 
 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };
 
 #define cb(_op, _) \
@@ -171,7 +161,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
 #undef cb
 
 struct ReturnOpLowering : public ConversionPattern {
-    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
               m_launch_op{launch_op} {}
@@ -182,10 +172,10 @@ struct ReturnOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
 
         //! remove the first gpu.terminator
-        m_launch_op->body().front().front().erase();
+        m_launch_op.body().front().front().erase();
 
         //! emit `if (tid >= nr_tid) { return; }` at the beginning of the block
-        rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToStart(&(m_launch_op.body().front()));
         Block* cond_block = rewriter.getInsertionBlock();
         Block::iterator op_position = rewriter.getInsertionPoint();
         Block* remaining_ops_block =
@@ -195,7 +185,7 @@ struct ReturnOpLowering : public ConversionPattern {
         auto index = get_tid(rewriter, loc);
         auto comparison = rewriter.create<mlir::CmpIOp>(
                 loc, CmpIPredicate::sge, index,
-                m_launch_op->getParentOfType<mlir::FuncOp>()
+                m_launch_op.getParentOfType<mlir::FuncOp>()
                         .getArguments()
                         .back());
@@ -216,11 +206,31 @@ struct ReturnOpLowering : public ConversionPattern {
     }
 
 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };
 
+struct ConstantScalarOpLowering
+        : public OpRewritePattern<jit::ConstantScalarOp> {
+    ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
+            : OpRewritePattern<jit::ConstantScalarOp>(ctx),
+              m_launch_op{launch_op} {}
+
+    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
+                                  PatternRewriter& rewriter) const final {
+        ConstantScalarOpAdaptor constant_scalar_adaptor(op);
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
+        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
+                op, constant_scalar_adaptor.value());
+        return success();
+    }
+
+private:
+    gpu::LaunchOp& m_launch_op;
+};
+
 struct AssignOpLowering : public ConversionPattern {
-    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
+    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
             : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
               m_launch_op{launch_op} {}
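
Unlike the affine variant, the GPU pattern first moves the insertion point to the end of the gpu.launch body, so the materialized constant is defined inside the kernel region and dominates its uses there. The resulting structure, sketched with assumed value names:

    // gpu.launch blocks(...) threads(...) {
    //     %tid = ...                            // from get_tid
    //     %cst = constant 3.000000e-01 : f32    // placed in the launch body
    //     ... uses of %cst ...
    // }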
@@ -230,12 +240,12 @@ struct AssignOpLowering : public ConversionPattern {
         auto loc = op->getLoc();
 
         AssignOpAdaptor assign_adaptor(operands);
-        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
+        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
 
         auto index = get_tid(rewriter, loc);
 
         auto loaded_lhs =
-                get_operand(rewriter, loc, assign_adaptor.lhs(), index);
+                get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
         rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);
 
         rewriter.eraseOp(op);
@@ -243,7 +253,7 @@ struct AssignOpLowering : public ConversionPattern {
     }
 
 private:
-    gpu::LaunchOp* m_launch_op;
+    gpu::LaunchOp& m_launch_op;
 };
 class MgbToGpuLoweringPass
@@ -271,7 +281,8 @@ public:
                                 cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                 MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                         ReturnOpLowering,
-                AssignOpLowering>(&getContext(), &launch_op);
+                ConstantScalarOpLowering, AssignOpLowering>(
+                &getContext(), launch_op);
 #undef cb
 
         if (failed(applyPartialConversion(func_op, target, patterns))) {
@@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "./interfaces.td"
+include "./predicates.td"
 
 def Mgb_Dialect : Dialect {
   let name = "mgb";
@@ -90,7 +91,7 @@ def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;
 class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
     ElemwiseOp<mnemonic, traits> {
-  let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
+  let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs);
   let results = (outs F32MemRef);
   let builders = [OpBuilder<
@@ -141,7 +142,7 @@ def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;
 class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
     ElemwiseOp<mnemonic, traits> {
-  let arguments = (ins F32MemRef:$x, F32MemRef:$y, F32MemRef:$z);
+  let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z);
   let results = (outs F32MemRef);
   let builders = [OpBuilder<
@@ -178,6 +179,25 @@ def ReturnOp : GenericOp<"return",
 }
 
+def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> {
+  let summary = "scalar constant";
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs F32:$result);
+
+  let builders = [OpBuilder<
+    "Builder* builder, OperationState& result, float value", [{
+      result.addAttribute("value", builder->getF32FloatAttr(value));
+      result.addTypes(builder->getF32Type());
+    }]
+  >];
+
+  let extraClassDeclaration = [{
+    Attribute getValue() { return getAttr("value"); }
+    FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); }
+  }];
+}
+
 def AssignOp : GenericOp<"assign", []> {
   let summary = "assign op";
   let description = [{
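
The generated builder is what mlir_gen invokes further down; a minimal sketch of materializing a 0.3f scalar in the mgb dialect (builder and insertion point assumed to be set up already):

    auto cst = m_builder.create<jit::ConstantScalarOp>(
            m_builder.getUnknownLoc(), m_builder.getF32Type(),
            m_builder.getF32FloatAttr(0.3f));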
@@ -0,0 +1,24 @@
+/**
+ * \file src/jit/impl/mlir/ir/predicates.td
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#ifndef MGB_MLIR_PREDICATES
+#define MGB_MLIR_PREDICATES
+
+#ifndef OP_BASE
+include "mlir/IR/OpBase.td"
+#endif
+
+def ElemwiseFloatAny : TypeConstraint<
+    CPred<"is_elemwise_float($_self)">, "elemwise-float">;
+
+#endif
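
Note that TableGen splices the CPred string verbatim into the generated verifier code, so is_elemwise_float must be in scope wherever that code is compiled — which is presumably why the lowering sources above gained #include "./types.h".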
@@ -0,0 +1,39 @@
+/**
+ * \file src/jit/impl/mlir/ir/types.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#pragma once
+
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include <mlir/IR/StandardTypes.h>
+
+namespace mgb {
+namespace jit {
+
+inline bool is_elemwise_float(const mlir::Type& dt) {
+    if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) {
+        if (cast.getElementType().getKind() == mlir::StandardTypes::F32) {
+            return true;
+        }
+    }
+    if (dt.isa<mlir::FloatType>()) {
+        return true;
+    }
+    return false;
+}
+
+}  // namespace jit
+}  // namespace mgb
+
+#endif  // MGB_JIT && MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
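
The predicate therefore accepts either an f32 memref or a scalar float. Expected outcomes, as a sketch (note the scalar branch matches any FloatType, not only f32):

    // is_elemwise_float(memref<?xf32>) -> true   (f32 element type)
    // is_elemwise_float(f32)           -> true   (scalar FloatType)
    // is_elemwise_float(memref<?xi32>) -> false  (non-f32 element type)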
@@ -49,6 +49,9 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
     megdnn::TensorShape dst;
     megdnn::DType dst_type;
     for (auto operand : operands) {
+        if (operand.getType().isa<mlir::FloatType>()) {
+            continue;
+        }
         auto type = operand.getType().dyn_cast_or_null<mlir::MemRefType>();
         mgb_assert(type, "currently only support MemRefType");
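
Scalar constants carry no shape information, so they are simply skipped when deducing the broadcast shape and dtype of the result memref; only MemRefType operands participate in the deduction.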
@@ -137,6 +137,27 @@ private:
             return;
         }
 
+        if (opr->same_type<opr::ImmutableTensor>()) {
+            auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
+            if (imm.valid()) {
+                auto dtype = imm->dtype();
+                float scalar_value;
+                if (dtype == dtype::Float32()) {
+                    scalar_value = imm->get<float>();
+                } else {
+                    mgb_throw(InternalError,
+                              "mlir backend currently only supports f32 "
+                              "dtype, but got %s",
+                              dtype.name());
+                }
+                auto&& out = m_builder.create<jit::ConstantScalarOp>(
+                        m_builder.getUnknownLoc(), m_builder.getF32Type(),
+                        m_builder.getF32FloatAttr(scalar_value));
+                mgb_assert(mlir::succeeded(
+                        declare(opr->output(0)->name(), out)));
+            }
+        }
+
         if (opr->same_type<opr::Elemwise>()) {
             auto&& out = gen_op(opr->cast_final<opr::Elemwise>());
             mgb_assert(
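
This is the producer side of the new path: each scalar ImmutableTensor in the fused subgraph becomes a jit::ConstantScalarOp, which the affine and GPU passes above then lower to a standard constant.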
@@ -137,7 +137,7 @@ void run_mlir(CompNode cn) {
          b = opr::Host2DeviceCopy::make(*graph, host_x1),
          c = opr::Host2DeviceCopy::make(*graph, host_x2);
 
-    auto y = a + b * c;
+    auto y = a + b * c + 0.3f;
 
     auto ig_gen =
             std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
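
The added `+ 0.3f` literal becomes a scalar ImmutableTensor in the graph, so this test now exercises the ConstantScalarOp path end to end on the tested comp nodes.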