GitOrigin-RevId: ce6f4ea42a
tags/v1.1.0
| @@ -34,7 +34,7 @@ public: | |||
| Property property() const override { | |||
| using F = Property::Flag; | |||
| return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM, | |||
| JITFeatureBits::NONE, 64}; | |||
| JITFeatureBits::DIMSHUFFLE, 64}; | |||
| } | |||
| size_t get_nr_workspace_outputs(JITExecutor* opr) const override; | |||
| @@ -62,6 +62,7 @@ struct ElemwiseLowering : public ConversionPattern { | |||
| ElemwiseLowering(MLIRContext* ctx) | |||
| : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1, | |||
| ctx) {} | |||
| LogicalResult matchAndRewrite( | |||
| Operation* op, ArrayRef<Value> operands, | |||
| ConversionPatternRewriter& rewriter) const final { | |||
| @@ -89,6 +90,7 @@ struct TypeCvtLowering : public ConversionPattern { | |||
| TypeCvtLowering(MLIRContext* ctx) | |||
| : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1, | |||
| ctx) {} | |||
| LogicalResult matchAndRewrite( | |||
| Operation* op, ArrayRef<Value> operands, | |||
| ConversionPatternRewriter& rewriter) const final { | |||
| @@ -105,6 +107,41 @@ struct TypeCvtLowering : public ConversionPattern { | |||
| } | |||
| }; | |||
| struct DimshuffleLowering : public ConversionPattern { | |||
| DimshuffleLowering(MLIRContext* ctx) | |||
| : ConversionPattern(mgb::dialect::Dimshuffle::getOperationName(), 1, | |||
| ctx) {} | |||
| static mlir::AffineMap get_affinemap_from_pattern( | |||
| const std::vector<int32_t>& pattern, mlir::MLIRContext* ctx) { | |||
| size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1; | |||
| std::vector<mlir::AffineExpr> exprs(ndim); | |||
| for (size_t i = 0; i < pattern.size(); i++) { | |||
| int32_t j = pattern[i]; | |||
| if (j >= 0) { | |||
| exprs[j] = mlir::getAffineDimExpr(i, ctx); | |||
| } | |||
| } | |||
| return mlir::AffineMap::get(pattern.size(), 0, exprs, ctx); | |||
| } | |||
| LogicalResult matchAndRewrite( | |||
| Operation* op, ArrayRef<Value> operands, | |||
| ConversionPatternRewriter& rewriter) const final { | |||
| auto loc = op->getLoc(); | |||
| auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern(); | |||
| auto map = get_affinemap_from_pattern(pattern, op->getContext()); | |||
| lower_op_to_loops( | |||
| op, operands, rewriter, | |||
| [loc, op, &map](OpBuilder& builder, ValueRange memref_operands, | |||
| ValueRange loop_ivs) { | |||
| return builder.create<AffineLoadOp>(loc, memref_operands[0], | |||
| map, loop_ivs); | |||
| }); | |||
| return success(); | |||
| } | |||
| }; | |||
| struct AssignOpLowering : public ConversionPattern { | |||
| AssignOpLowering(MLIRContext* ctx) | |||
| : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) { | |||
| @@ -172,9 +209,9 @@ public: | |||
| target.addIllegalDialect<MgbDialect>(); | |||
| OwningRewritePatternList patterns; | |||
| patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering, | |||
| AssignOpLowering, ConstantScalarOpLowering>( | |||
| &getContext()); | |||
| patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering, | |||
| ReturnOpLowering, AssignOpLowering, | |||
| ConstantScalarOpLowering>(&getContext()); | |||
| if (failed(applyPartialConversion(getFunction(), target, | |||
| std::move(patterns)))) { | |||
| @@ -152,6 +152,47 @@ private: | |||
| gpu::LaunchOp& m_launch_op; | |||
| }; | |||
| struct DimshuffleLowering : public ConversionPattern { | |||
| DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||
| : ConversionPattern(dialect::Dimshuffle::getOperationName(), 1, | |||
| ctx), | |||
| m_launch_op{launch_op} {} | |||
| static std::vector<mlir::Value> get_index_from_pattern( | |||
| const std::vector<int32_t>& pattern, | |||
| const std::vector<mlir::Value>& index) { | |||
| size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1; | |||
| std::vector<mlir::Value> res(ndim); | |||
| for (size_t i = 0; i < pattern.size(); i++) { | |||
| int32_t j = pattern[i]; | |||
| if (j >= 0) { | |||
| res[j] = index[i]; | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| LogicalResult matchAndRewrite( | |||
| Operation* op, ArrayRef<Value> operands, | |||
| ConversionPatternRewriter& rewriter) const final { | |||
| auto loc = op->getLoc(); | |||
| rewriter.setInsertionPointToEnd(&(m_launch_op.body().front())); | |||
| auto dst_layout = output_layout(m_launch_op); | |||
| auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout); | |||
| auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern(); | |||
| auto shuffled_index = get_index_from_pattern(pattern, index); | |||
| rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0], | |||
| shuffled_index)); | |||
| return success(); | |||
| } | |||
| private: | |||
| gpu::LaunchOp& m_launch_op; | |||
| }; | |||
| struct ReturnOpLowering : public ConversionPattern { | |||
| ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op) | |||
| : ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx), | |||
| @@ -275,9 +316,9 @@ public: | |||
| target.addLegalDialect<gpu::GPUDialect>(); | |||
| target.addIllegalDialect<MgbDialect>(); | |||
| patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering, | |||
| ConstantScalarOpLowering, AssignOpLowering>( | |||
| &getContext(), launch_op); | |||
| patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering, | |||
| ReturnOpLowering, ConstantScalarOpLowering, | |||
| AssignOpLowering>(&getContext(), launch_op); | |||
| if (failed(applyPartialConversion(func_op, target, | |||
| std::move(patterns)))) { | |||
| @@ -20,6 +20,7 @@ | |||
| #include "megbrain/jit/mlir/ir/dialect.h" | |||
| #include "megbrain/jit/mlir/ir/utils.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megdnn/dtype.h" | |||
| #include <mlir/Dialect/Affine/IR/AffineOps.h> | |||
| @@ -160,6 +161,10 @@ private: | |||
| mgb_assert( | |||
| mlir::succeeded(declare(opr->output(0)->name(), out))); | |||
| return; | |||
| } else if (opr->same_type<opr::Dimshuffle>()) { | |||
| auto&& out = gen_dimshuffle(opr->cast_final<opr::Dimshuffle>()); | |||
| mgb_assert( | |||
| mlir::succeeded(declare(opr->output(0)->name(), out))); | |||
| } else if (opr->same_type<opr::TypeCvt>()) { | |||
| auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>()); | |||
| mgb_assert( | |||
| @@ -186,18 +191,44 @@ private: | |||
| } | |||
| mlir::Value gen_typecvt(const opr::TypeCvt& opr) { | |||
| auto shape = get(opr.input(0)) | |||
| auto itype = get(opr.input(0)) | |||
| .getType() | |||
| .dyn_cast_or_null<mlir::MemRefType>() | |||
| .getShape(); | |||
| .dyn_cast_or_null<mlir::MemRefType>(); | |||
| mgb_assert(itype, "currently only support MemRefType"); | |||
| auto res_type = mlir::MemRefType::get( | |||
| shape, | |||
| itype.getShape(), | |||
| megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext())); | |||
| return m_builder.create<dialect::TypeCvt>( | |||
| m_builder.getUnknownLoc(), res_type, get(opr.input(0)), | |||
| opr.input(0)->dtype(), opr.param()); | |||
| } | |||
| mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) { | |||
| auto itype = get(opr.input(0)) | |||
| .getType() | |||
| .dyn_cast_or_null<mlir::MemRefType>(); | |||
| mgb_assert(itype, "the input type of Dimshuffle must be MemRefType"); | |||
| auto ishape = itype.getShape(); | |||
| auto param = opr.param(); | |||
| std::vector<int32_t> pattern; | |||
| std::vector<int64_t> oshape; | |||
| for (size_t i = 0; i < param.pattern_len; i++) { | |||
| int32_t j = param.pattern[i]; | |||
| pattern.push_back(j); | |||
| if (j < 0) { | |||
| oshape.push_back(1); | |||
| } else { | |||
| oshape.push_back(ishape[j]); | |||
| } | |||
| } | |||
| auto res_type = mlir::MemRefType::get(oshape, itype.getElementType()); | |||
| return m_builder.create<dialect::Dimshuffle>( | |||
| m_builder.getUnknownLoc(), res_type, get(opr.input(0)), | |||
| pattern); | |||
| } | |||
| mlir::Type get_type(const TensorLayout& layout) { | |||
| return layout_to_mlir_type(layout, m_builder); | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| #include "megbrain/jit/executor_opr.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/basic_arith_wrapper.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/test/helper.h" | |||
| #include "megdnn/dtype.h" | |||
| @@ -539,6 +540,51 @@ add_typecvt_gtest(Uint8, Float32); | |||
| #undef add_typecvt_gtest | |||
| /* ===================== TestJITMlirDimshuffle ===================== */ | |||
| void run_dimshuffle(CompNode cn, TensorShape ishape, | |||
| const std::vector<int>& pattern) { | |||
| set_backend(Backend::MLIR); | |||
| auto graph = ComputingGraph::make(); | |||
| HostTensorGenerator<> gen; | |||
| auto host_x = gen(ishape, cn); | |||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
| auto y = opr::Dimshuffle::make(x, pattern); | |||
| auto ig_gen = std::make_unique<InternalGraphGenerator>(y.node()->owner_opr()); | |||
| for (auto i : get_rev_topo_order(y)) { | |||
| if (!i->template same_type<opr::Host2DeviceCopy>()) { | |||
| ig_gen->add_opr(i); | |||
| } | |||
| } | |||
| auto igraph = ig_gen->generate(); | |||
| auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps()); | |||
| HostTensorND host_y, host_y_jit; | |||
| auto func = graph->compile({make_callback_copy(y, host_y), | |||
| make_callback_copy(y_jit, host_y_jit)}); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | |||
| } | |||
| void run_dimshuffle_cases(CompNode cn) { | |||
| run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1}); | |||
| run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2}); | |||
| } | |||
| TEST(TestJITMlirDimshuffle, Basic) { | |||
| run_dimshuffle_cases(CompNode::load("cpu0")); | |||
| } | |||
| TEST(TestJITMlirDimshuffle, BasicGPU) { | |||
| REQUIRE_GPU(1); | |||
| run_dimshuffle_cases(CompNode::load("gpu0")); | |||
| } | |||
| #endif // MGB_JIT_MLIR | |||
| #endif // MGB_JIT | |||