| @@ -22,6 +22,7 @@ | |||||
| #include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
| #include "megbrain/jit/mlir/ir/dialect.h" | #include "megbrain/jit/mlir/ir/dialect.h" | ||||
| #include <llvm/Support/raw_ostream.h> | |||||
| #include <mlir/Dialect/StandardOps/IR/Ops.h> | #include <mlir/Dialect/StandardOps/IR/Ops.h> | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -442,31 +443,35 @@ mlir::Value lower_elemwise_to_std(mlir::Operation* op, mlir::OpBuilder& builder, | |||||
| mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder, | mlir::Value lower_typecvt_to_std(mlir::Operation* op, mlir::OpBuilder& builder, | ||||
| mlir::Location loc, mlir::Value input) { | mlir::Location loc, mlir::Value input) { | ||||
| auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op); | auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op); | ||||
| megdnn::DType idtype = typecvt.idtype(); | |||||
| megdnn::DType odtype = typecvt.odtype(); | |||||
| mlir::Type idtype = typecvt.idtype(); | |||||
| mlir::Type odtype = | |||||
| megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext()); | |||||
| mlir::Type itype = input.getType(); | mlir::Type itype = input.getType(); | ||||
| mlir::Type otype = megdnn_dtype_to_mlir_type(odtype, builder.getContext()); | |||||
| mlir::Type otype = signless(odtype); | |||||
| mgb_assert(signless(idtype) == itype); | |||||
| if (mlir::FPExtOp::areCastCompatible(itype, otype)) { | if (mlir::FPExtOp::areCastCompatible(itype, otype)) { | ||||
| return builder.create<mlir::FPExtOp>(loc, otype, input); | return builder.create<mlir::FPExtOp>(loc, otype, input); | ||||
| } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { | } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { | ||||
| return builder.create<mlir::FPTruncOp>(loc, otype, input); | return builder.create<mlir::FPTruncOp>(loc, otype, input); | ||||
| } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and | } else if (mlir::FPToSIOp::areCastCompatible(itype, otype) and | ||||
| is_signed_int_dtype(odtype)) { | |||||
| odtype.isSignedInteger()) { | |||||
| return builder.create<mlir::FPToSIOp>(loc, otype, input); | return builder.create<mlir::FPToSIOp>(loc, otype, input); | ||||
| } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and | } else if (mlir::FPToUIOp::areCastCompatible(itype, otype) and | ||||
| is_unsigned_int_dtype(odtype)) { | |||||
| odtype.isUnsignedInteger()) { | |||||
| return builder.create<mlir::FPToUIOp>(loc, otype, input); | return builder.create<mlir::FPToUIOp>(loc, otype, input); | ||||
| } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and | } else if (mlir::SIToFPOp::areCastCompatible(itype, otype) and | ||||
| is_signed_int_dtype(idtype)) { | |||||
| idtype.isSignedInteger()) { | |||||
| return builder.create<mlir::SIToFPOp>(loc, otype, input); | return builder.create<mlir::SIToFPOp>(loc, otype, input); | ||||
| } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and | } else if (mlir::UIToFPOp::areCastCompatible(itype, otype) and | ||||
| is_unsigned_int_dtype(idtype)) { | |||||
| idtype.isUnsignedInteger()) { | |||||
| return builder.create<mlir::UIToFPOp>(loc, otype, input); | return builder.create<mlir::UIToFPOp>(loc, otype, input); | ||||
| } else { | } else { | ||||
| mgb_throw(InternalError, "cannot convert from %s to %s", idtype.name(), | |||||
| odtype.name()); | |||||
| std::string tmp; | |||||
| llvm::raw_string_ostream os(tmp); | |||||
| os << "cannot convert from " << idtype << " to " << odtype; | |||||
| mgb_throw_raw(InternalError{tmp}); | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| @@ -28,13 +28,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, | |||||
| case megdnn::DTypeEnum::Float32: | case megdnn::DTypeEnum::Float32: | ||||
| return mlir::FloatType::getF32(ctx); | return mlir::FloatType::getF32(ctx); | ||||
| case megdnn::DTypeEnum::Uint8: | case megdnn::DTypeEnum::Uint8: | ||||
| return mlir::IntegerType::get(8, ctx); | |||||
| return mlir::IntegerType::get(8, mlir::IntegerType::Unsigned, ctx); | |||||
| case megdnn::DTypeEnum::Int8: | case megdnn::DTypeEnum::Int8: | ||||
| return mlir::IntegerType::get(8, ctx); | |||||
| return mlir::IntegerType::get(8, mlir::IntegerType::Signed, ctx); | |||||
| case megdnn::DTypeEnum::Int16: | case megdnn::DTypeEnum::Int16: | ||||
| return mlir::IntegerType::get(16, ctx); | |||||
| return mlir::IntegerType::get(16, mlir::IntegerType::Signed, ctx); | |||||
| case megdnn::DTypeEnum::Int32: | case megdnn::DTypeEnum::Int32: | ||||
| return mlir::IntegerType::get(32, ctx); | |||||
| return mlir::IntegerType::get(32, mlir::IntegerType::Signed, ctx); | |||||
| case megdnn::DTypeEnum::IntB1: | case megdnn::DTypeEnum::IntB1: | ||||
| return mlir::IntegerType::get(1, ctx); | return mlir::IntegerType::get(1, ctx); | ||||
| case megdnn::DTypeEnum::IntB2: | case megdnn::DTypeEnum::IntB2: | ||||
| @@ -57,6 +57,13 @@ mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, | |||||
| } | } | ||||
| } | } | ||||
| mlir::Type signless(mlir::Type type) { | |||||
| if (auto intty = type.dyn_cast<mlir::IntegerType>()) { | |||||
| return mlir::IntegerType::get(intty.getWidth(), type.getContext()); | |||||
| } | |||||
| return type; | |||||
| } | |||||
| megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { | megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { | ||||
| mlir::Type element_type = type; | mlir::Type element_type = type; | ||||
| if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { | if (auto cast = type.dyn_cast_or_null<mlir::MemRefType>()) { | ||||
| @@ -91,22 +98,6 @@ megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type) { | |||||
| return megdnn::DType::from_enum(enumv); | return megdnn::DType::from_enum(enumv); | ||||
| } | } | ||||
| bool is_signed_int_dtype(megdnn::DType type) { | |||||
| auto enumv = type.enumv(); | |||||
| return enumv == megdnn::DTypeEnum::Int8 or | |||||
| enumv == megdnn::DTypeEnum::Int16 or | |||||
| enumv == megdnn::DTypeEnum::Int32 or | |||||
| enumv == megdnn::DTypeEnum::IntB1 or | |||||
| enumv == megdnn::DTypeEnum::IntB2 or | |||||
| enumv == megdnn::DTypeEnum::IntB4; | |||||
| } | |||||
| bool is_unsigned_int_dtype(megdnn::DType type) { | |||||
| auto enumv = type.enumv(); | |||||
| return enumv == megdnn::DTypeEnum::Uint8 or | |||||
| enumv == megdnn::DTypeEnum::UintB4; | |||||
| } | |||||
| } // namespace jit | } // namespace jit | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -35,13 +35,10 @@ namespace jit { | |||||
| mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, | mlir::Type megdnn_dtype_to_mlir_type(megdnn::DType type, | ||||
| mlir::MLIRContext* ctx); | mlir::MLIRContext* ctx); | ||||
| mlir::Type signless(mlir::Type type); | |||||
| megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); | megdnn::DType mlir_type_to_megdnn_dtype(mlir::Type type); | ||||
| bool is_signed_int_dtype(megdnn::DType type); | |||||
| bool is_unsigned_int_dtype(megdnn::DType type); | |||||
| } // namespace jit | } // namespace jit | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -87,7 +87,7 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout, | |||||
| shape.push_back(layout[i]); | shape.push_back(layout[i]); | ||||
| } | } | ||||
| mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext()); | mlir::Type type = megdnn_dtype_to_mlir_type(layout.dtype, builder.getContext()); | ||||
| return mlir::MemRefType::get(shape, type); | |||||
| return mlir::MemRefType::get(shape, signless(type)); | |||||
| } | } | ||||
| #endif // MGB_JIT && MGB_JIT_MLIR | #endif // MGB_JIT && MGB_JIT_MLIR | ||||
| @@ -197,12 +197,15 @@ private: | |||||
| .getType() | .getType() | ||||
| .dyn_cast_or_null<mlir::MemRefType>(); | .dyn_cast_or_null<mlir::MemRefType>(); | ||||
| mgb_assert(itype, "currently only support MemRefType"); | mgb_assert(itype, "currently only support MemRefType"); | ||||
| auto output_type = megdnn_dtype_to_mlir_type(opr.param(), | |||||
| m_builder.getContext()); | |||||
| auto res_type = mlir::MemRefType::get( | auto res_type = mlir::MemRefType::get( | ||||
| itype.getShape(), | |||||
| megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext())); | |||||
| itype.getShape(), signless(output_type)); | |||||
| auto inp_type = megdnn_dtype_to_mlir_type(opr.input(0)->dtype(), | |||||
| m_builder.getContext()); | |||||
| return m_builder.create<dialect::TypeCvt>( | return m_builder.create<dialect::TypeCvt>( | ||||
| m_builder.getUnknownLoc(), res_type, get(opr.input(0)), | m_builder.getUnknownLoc(), res_type, get(opr.input(0)), | ||||
| opr.input(0)->dtype(), opr.param()); | |||||
| mlir::TypeAttr::get(inp_type), opr.param()); | |||||
| } | } | ||||
| mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) { | mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) { | ||||
| @@ -15,7 +15,10 @@ | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_JIT && MGB_JIT_MLIR | #if MGB_JIT && MGB_JIT_MLIR | ||||
| #include "megdnn/basic_types.h" | |||||
| #include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
| #include "megbrain/opr/param_defs.h" | |||||
| #include "megbrain/comp_node.h" | |||||
| #include <mlir/IR/Dialect.h> | #include <mlir/IR/Dialect.h> | ||||
| #include <mlir/IR/Function.h> | #include <mlir/IR/Function.h> | ||||
| @@ -15,6 +15,8 @@ | |||||
| include "ops.td" | include "ops.td" | ||||
| include "mlir/Interfaces/SideEffectInterfaces.td" | |||||
| class GenericOp<string mnemonic, list<OpTrait> traits = []> : | class GenericOp<string mnemonic, list<OpTrait> traits = []> : | ||||
| Op<Mgb_Dialect, mnemonic, traits>; | Op<Mgb_Dialect, mnemonic, traits>; | ||||