#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./common.h"
#include "./each_mode.h"
#include "./numerical.h"
#include "./types.h"

#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/jit/mlir/ir/dialect.h"

// NOTE(review): the names of these two angle-bracket headers were missing in
// the copy under review; they are restored from the symbols this file uses
// (llvm::raw_string_ostream and the mlir std cast ops) -- confirm against VCS.
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>

namespace mgb {
namespace jit {

using Mode = megdnn::param::Elemwise::Mode;

/*!
 * \brief build the ops that compute one elemwise mode
 *
 * Only declared here; every supported mode is provided as an explicit
 * specialization in this file.
 *
 * \param builder builder used to create the ops
 * \param loc location attached to the created ops
 * \param operands input values; specializations read operands[0..2]
 *      depending on the arity of the mode
 * \return value holding the result of the mode
 */
template <Mode mode>
mlir::Value lower_mode(
        mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands);

/* ===================== trivial implementations ===================== */

//! specialize lower_mode<Mode::mode> by forwarding all operands to the
//! ValueBuilderHelper method \p fun
#define cb(mode, fun)                                                       \
    template <>                                                             \
    mlir::Value lower_mode<Mode::mode>(                                     \
            mlir::OpBuilder& builder, mlir::Location loc,                   \
            ValueRange operands) {                                          \
        ValueBuilderHelper helper(builder, loc);                            \
        return helper.fun(operands);                                        \
    }

//! unary
cb(ABS, abs);
cb(CEIL, ceil);
cb(COS, cos);
cb(EXP, exp);
cb(FLOOR, floor);
cb(LOG, log);
cb(NEGATE, neg);
cb(SIN, sin);
cb(TANH, tanh);

//! binary
cb(ADD, add);
cb(MAX, max);
cb(MIN, min);
cb(MOD, mod);
cb(MUL, mul);
cb(SUB, sub);
cb(TRUE_DIV, div);

#undef cb

/* ===================== unary op ===================== */

//! ACOS: pi / 2 - arctan2(x, sqrt(1 - x * x))
template <>
mlir::Value lower_mode<Mode::ACOS>(
        mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
    ValueBuilderHelper helper(builder, loc);
    auto x = operands[0];
    auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
    // same expression as the ASIN lowering below
    auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
    auto pi_over_2 = helper.const_f32(1.57079637f);
    return helper.sub(pi_over_2, asin);
}

//! ASIN: arctan2(x, sqrt(1 - x * x))
template <>
mlir::Value lower_mode<Mode::ASIN>(
        mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
    ValueBuilderHelper helper(builder, loc);
    auto x = operands[0];
    auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
    return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
}

//! ERFCINV: inverse of complementary gauss error function
//!
https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto minus_sqrt2 = helper.const_f32(-1.4142135623f); auto x = helper.mul(helper.const_f32(0.5f), operands[0]); return helper.div(ndtri_approx(helper, x), minus_sqrt2); } //! ERFC: complementary error function template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0])); } //! ERFINV: inverse of gauss error function //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto sqrt2 = helper.const_f32(1.4142135623f); auto x = helper.mul( helper.const_f32(0.5f), helper.add(operands[0], helper.const_f32(1.f))); return helper.div(ndtri_approx(helper, x), sqrt2); } //! ERF: gauss error function template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return erf_approx(helper, operands[0]); } //! EXPM1: exp(x) - 1 template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f)); } //! FAST_TANH: x * (27.f + x * x) / (27.f + 9.f * x * x); template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto square = helper.mul(operands[0], operands[0]); return helper.div( helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)), helper.add( helper.const_f32(27.f), helper.mul(helper.const_f32(9.f), square))); } //! 
H_SWISH: x * clip(x + 3, 0, 6) / 6 template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto const_3 = helper.const_f32(3.f); auto const_0 = helper.const_f32(0.f); auto const_6 = helper.const_f32(6.f); auto tmp = helper.add(operands[0], const_3); return helper.div( helper.mul(operands[0], helper.min(helper.max(tmp, const_0), const_6)), const_6); } //! LOG1P: log(1 + p) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.log(helper.add(operands[0], helper.const_f32(1.f))); } //! RELU: max(x, 0) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.max(operands[0], helper.const_f32(0.f)); } //! ROUND template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.gt(operands[0], helper.const_f32(0.f)), helper.floor(helper.add(operands[0], helper.const_f32(0.5f))), helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f)))); } //! SIGMOID: 1.f / (expf(-y) + 1.f)) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.div( helper.const_f32(1.f), helper.add(helper.exp(helper.neg(operands[0])), helper.const_f32(1.f))); } /* ===================== binary op ===================== */ //! ABS_GRAD: x > 0 ? y : -y template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.gt(operands[0], helper.const_f32(0.f)), operands[1], helper.neg(operands[1])); } //! 
ATAN2 template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return atan2_approx(helper, operands[0], operands[1]); } //! EQ: x == y ? 1 : 0 template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.eq(operands[0], operands[1]), helper.const_f32(1.f), helper.const_f32(0.f)); } //! FAST_TANH_GRAD: ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x //! * x) * y template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto x_pow2 = helper.mul(operands[0], operands[0]); auto deno = helper.add(helper.const_f32(3.f), x_pow2); return helper.mul( helper.div( helper.add( helper.add( helper.div( helper.mul(helper.const_f32(-48.f), x_pow2), deno), helper.const_f32(27.f)), x_pow2), helper.mul(deno, helper.const_f32(9.f))), operands[1]); } //! FLOOR_DIV: floor(x/y) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.floor(helper.div(operands[0], operands[1])); } //! FUSE_ADD_H_SWISH: (x+y) * min(max(x + y + 3, 0), 6) * (1/6) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto sum = helper.add(operands[0], operands[1]); auto const_3 = helper.const_f32(3.f); auto const_0 = helper.const_f32(0.f); auto const_6 = helper.const_f32(6.f); auto tmp = helper.add(sum, const_3); return helper.div( helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)), const_6); } //! FUSE_ADD_RELU: (x + y) <= ctype(0) ? 
ctype(0) : (x + y) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); auto sum = helper.add(operands[0], operands[1]); return helper.max(sum, helper.const_f32(0.f)); } //! FUSE_ADD_SIGMOID: 1.f / (expf(-(x+y)) + 1.f)) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.div( helper.const_f32(1.f), helper.add( helper.exp(helper.neg(helper.add(operands[0], operands[1]))), helper.const_f32(1.f))); } //! FUSE_ADD_TANH: tanh(x + y) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.tanh(helper.add(operands[0], operands[1])); } //! H_SWISH_GRAD: x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.lt(operands[0], helper.const_f32(-3.f)), helper.const_f32(0.f), helper.select( helper.gt(operands[0], helper.const_f32(3.f)), operands[1], helper.mul( helper.div( helper.add( helper.mul( helper.const_f32(2.f), operands[0]), helper.const_f32(3.f)), helper.const_f32(6.f)), operands[1]))); } //! LEQ: x <= y ? 1 : 0 template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.le(operands[0], operands[1]), helper.const_f32(1.f), helper.const_f32(0.f)); } //! LOG_SUM_EXP: log(exp(x) + exp(y)) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.log(helper.add(helper.exp(operands[0]), helper.exp(operands[1]))); } //! LT: x < y ? 
1 : 0 template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.lt(operands[0], operands[1]), helper.const_f32(1.f), helper.const_f32(0.f)); } //! POW: x^y = exp(y * log(x)) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.exp(helper.mul(operands[1], helper.log(operands[0]))); } //! SIGMOID_GRAD: x * (1 - x) * y template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.mul( helper.mul(operands[0], helper.sub(helper.const_f32(1.f), operands[0])), operands[1]); } //! SWITCH_GT0: (x > 0) * y template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.gt(operands[0], helper.const_f32(0.f)), operands[1], helper.const_f32(0.f)); } //! TANH_GRAD: (1 - x * x) * y template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.mul( helper.sub(helper.const_f32(1.0f), helper.mul(operands[0], operands[0])), operands[1]); } /* ===================== ternary op ===================== */ //! COND_LEQ_MOV: x <= y ? z : ctype(0) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); } //! COND_LT_MOV: x < y ? z : ctype(0) template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.select( helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); } //! 
FUSE_MUL_ADD3: x * y + z template <> mlir::Value lower_mode( mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { ValueBuilderHelper helper(builder, loc); return helper.add(helper.mul(operands[0], operands[1]), operands[2]); } /* ===================== elemwise ===================== */ mlir::Value lower_elemwise_to_std( mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { auto mode = llvm::dyn_cast(op).mode(); switch (mode) { #define cb(_, _mode) \ case Mode::_mode: \ return lower_mode(builder, loc, operands); MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb); MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb); MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb); default: return nullptr; } #undef cb } /* ===================== typecvt ===================== */ mlir::Value lower_typecvt_to_std( mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc, mlir::Value input) { auto&& typecvt = llvm::dyn_cast(op); mlir::Type idtype = typecvt.idtype(); mlir::Type odtype = megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext()); mlir::Type itype = input.getType(); mlir::Type otype = signless(odtype); mgb_assert(signless(idtype) == itype); if (mlir::FPExtOp::areCastCompatible(itype, otype)) { return builder.create(loc, otype, input); } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) { return builder.create(loc, otype, input); } else if ( mlir::FPToSIOp::areCastCompatible(itype, otype) and odtype.isSignedInteger()) { return builder.create(loc, otype, input); } else if ( mlir::FPToUIOp::areCastCompatible(itype, otype) and odtype.isUnsignedInteger()) { return builder.create(loc, otype, input); } else if ( mlir::SIToFPOp::areCastCompatible(itype, otype) and idtype.isSignedInteger()) { return builder.create(loc, otype, input); } else if ( mlir::UIToFPOp::areCastCompatible(itype, otype) and idtype.isUnsignedInteger()) { return builder.create(loc, otype, input); } else { std::string tmp; llvm::raw_string_ostream os(tmp); os << 
"cannot convert from " << idtype << " to " << odtype; mgb_throw_raw(InternalError{tmp}); } return nullptr; } } // namespace jit } // namespace mgb #endif // MGB_JIT && MGB_JIT_MLIR // vim: syntax=cpp.doxygen