GitOrigin-RevId: 5883c68804
tags/v1.1.0
| @@ -29,6 +29,8 @@ cb(sub, SubFOp); | |||
| cb(mul, MulFOp); | |||
| cb(div, DivFOp); | |||
| cb(mod, RemFOp); | |||
| cb(bit_and, AndOp); | |||
| cb(bit_or, OrOp); | |||
| #undef cb | |||
| #define cb(name, mode) \ | |||
| @@ -72,6 +74,7 @@ cb(exp, ExpOp); | |||
| cb(exp2, Exp2Op); | |||
| cb(log10, Log10Op); | |||
| cb(log2, Log2Op); | |||
| cb(log, LogOp); | |||
| cb(rsqrt, RsqrtOp); | |||
| cb(sin, SinOp); | |||
| cb(sqrt, SqrtOp); | |||
| @@ -79,7 +82,8 @@ cb(tanh, TanhOp); | |||
| #undef cb | |||
| mlir::Value ValueBuilderHelper::abs(mlir::Value lhs) { | |||
| return max(lhs, const_val(0.f)); | |||
| auto zero = const_val(0.f); | |||
| return select(ge(lhs, zero), lhs, sub(zero, lhs)); | |||
| } | |||
| mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) { | |||
| @@ -87,11 +91,6 @@ mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) { | |||
| return neg(ceil(neg(lhs))); | |||
| } | |||
| mlir::Value ValueBuilderHelper::log(mlir::Value lhs) { | |||
| // math.log10(math.e) = 0.4342944819032518f | |||
| return div(log10(lhs), const_val(0.4342944819032518f)); | |||
| } | |||
| mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val, | |||
| mlir::Value false_val) { | |||
| return m_builder.create<mlir::SelectOp>(m_location, cond, true_val, | |||
| @@ -47,6 +47,8 @@ public: | |||
| cb(lt); | |||
| cb(le); | |||
| cb(eq); | |||
| cb(bit_and); | |||
| cb(bit_or); | |||
| #undef cb | |||
| mlir::Value const_val(float val); | |||
| @@ -18,6 +18,7 @@ | |||
| #include "megbrain/jit/mlir/ir/dialect.h" | |||
| #include "./common.h" | |||
| #include "./numerical.h" | |||
| #include <mlir/Dialect/StandardOps/IR/Ops.h> | |||
| #include <mlir/IR/Builders.h> | |||
| @@ -28,6 +29,8 @@ | |||
| cb(ReluOp, RELU) \ | |||
| cb(AbsOp, ABS) \ | |||
| cb(NegOp, NEGATE) \ | |||
| cb(AcosOp, ACOS) \ | |||
| cb(AsinOp, ASIN) \ | |||
| cb(CeilOp, CEIL) \ | |||
| cb(CosOp, COS) \ | |||
| cb(ExpOp, EXP) \ | |||
| @@ -40,7 +43,11 @@ | |||
| cb(FastTanhOp, FAST_TANH) \ | |||
| cb(HswishOp, H_SWISH) \ | |||
| cb(ExpM1Op, EXPM1) \ | |||
| cb(RoundOp, ROUND) | |||
| cb(RoundOp, ROUND) \ | |||
| cb(ErfOp, ERF) \ | |||
| cb(ErfInvOp, ERFINV) \ | |||
| cb(ErfCOp, ERFC) \ | |||
| cb(ErfCInvOp, ERFCINV) | |||
| #define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \ | |||
| cb(AbsGradOp, ABS_GRAD) \ | |||
| @@ -52,6 +59,7 @@ | |||
| cb(SubOp, SUB) \ | |||
| cb(MulOp, MUL) \ | |||
| cb(TrueDivOp, TRUE_DIV) \ | |||
| cb(PowOp, POW) \ | |||
| cb(SigmoidGradOp, SIGMOID_GRAD) \ | |||
| cb(SwishGt0Op, SWITCH_GT0) \ | |||
| cb(TanhGradOp, TANH_GRAD) \ | |||
| @@ -64,7 +72,8 @@ | |||
| cb(FastTanhGradOp, FAST_TANH_GRAD) \ | |||
| cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \ | |||
| cb(HswishGradOp, H_SWISH_GRAD) \ | |||
| cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) | |||
| cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \ | |||
| cb(Atan2Op, ATAN2) | |||
| #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ | |||
| cb(CondLeqMovOp, COND_LEQ_MOV) \ | |||
| @@ -197,6 +206,79 @@ struct StandardOp<jit::RoundOp> { | |||
| } | |||
| }; | |||
| //! pi / 2 - arctan2(x, sqrt(1 - x * x)) | |||
| template <> | |||
| struct StandardOp<jit::AcosOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| auto x = operands[0]; | |||
| auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x)); | |||
| auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); | |||
| auto pi_over_2 = helper.const_val(1.57079637f); | |||
| return helper.sub(pi_over_2, asin); | |||
| } | |||
| }; | |||
| //! arctan2(x, sqrt(1 - x * x)) | |||
| template <> | |||
| struct StandardOp<jit::AsinOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| auto x = operands[0]; | |||
| auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x)); | |||
| return atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); | |||
| } | |||
| }; | |||
| //! gauss error function | |||
| template <> | |||
| struct StandardOp<jit::ErfOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| return erf_approx(helper, operands[0]); | |||
| } | |||
| }; | |||
| //! inverse of gauss error function | |||
| //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c | |||
| template <> | |||
| struct StandardOp<jit::ErfInvOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| auto sqrt2 = helper.const_val(1.4142135623f); | |||
| auto x = helper.mul(helper.const_val(0.5f), | |||
| helper.add(operands[0], helper.const_val(1.f))); | |||
| return helper.div(ndtri_approx(helper, x), sqrt2); | |||
| } | |||
| }; | |||
| //! complementary error function | |||
| template <> | |||
| struct StandardOp<jit::ErfCOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| return helper.sub(helper.const_val(1.f), erf_approx(helper, operands[0])); | |||
| } | |||
| }; | |||
| //! inverse of complementary gauss error function | |||
| //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c | |||
| template <> | |||
| struct StandardOp<jit::ErfCInvOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| auto minus_sqrt2 = helper.const_val(-1.4142135623f); | |||
| auto x = helper.mul(helper.const_val(0.5f), operands[0]); | |||
| return helper.div(ndtri_approx(helper, x), minus_sqrt2); | |||
| } | |||
| }; | |||
| /////////////////////////// binary op /////////////////////////// | |||
| //! binary: x > 0 ? y : -y | |||
| @@ -210,6 +292,16 @@ struct StandardOp<jit::AbsGradOp> { | |||
| } | |||
| }; | |||
| //! x^y = exp(y * log(x)) | |||
| template <> | |||
| struct StandardOp<jit::PowOp> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| return helper.exp(helper.mul(operands[1], helper.log(operands[0]))); | |||
| } | |||
| }; | |||
| //! x * (1 - x) * y | |||
| template <> | |||
| struct StandardOp<jit::SigmoidGradOp> { | |||
| @@ -382,6 +474,16 @@ struct StandardOp<jit::FuseAddHswishOp> { | |||
| } | |||
| }; | |||
| //! arctan | |||
| template <> | |||
| struct StandardOp<jit::Atan2Op> { | |||
| mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, | |||
| ValueRange operands) { | |||
| ValueBuilderHelper helper(builder, loc); | |||
| return atan2_approx(helper, operands[0], operands[1]); | |||
| } | |||
| }; | |||
| /////////////////////////// ternary op /////////////////////////// | |||
| //! x <= y ? z : ctype(0) | |||
| template <> | |||
| @@ -0,0 +1,248 @@ | |||
| /** | |||
| * \file src/jit/impl/mlir/ir/numerical.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain_build_config.h" | |||
| #if MGB_JIT && MGB_JIT_MLIR | |||
| #include "numerical.h" | |||
| namespace mgb { | |||
| namespace jit { | |||
| mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x, | |||
| std::vector<mlir::Value>& coeff) { | |||
| size_t n = coeff.size(); | |||
| if (n == 0) { | |||
| return helper.const_val(0); | |||
| } | |||
| mlir::Value r = coeff[0]; | |||
| for (size_t i = 1; i < n; i++) { | |||
| r = helper.add(helper.mul(r, x), coeff[i]); | |||
| } | |||
| return r; | |||
| } | |||
| // polynomial approximation of arctangent | |||
| // atan(t) = t + c3 * t^3 + c5 * t^5 + ... + c17 * t^17 | |||
| // original paper: | |||
| // https://arxiv.org/pdf/1508.03211.pdf | |||
| mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, | |||
| mlir::Value x) { | |||
| auto atan_poly = [&](mlir::Value t) { | |||
| std::vector<mlir::Value> coeff = { | |||
| helper.const_val(2.90188402868807315826416015625E-3), | |||
| helper.const_val(-1.62907354533672332763671875E-2), | |||
| helper.const_val(4.3082617223262786865234375E-2), | |||
| helper.const_val(-7.5408883392810821533203125E-2), | |||
| helper.const_val(0.1066047251224517822265625), | |||
| helper.const_val(-0.14209578931331634521484375), | |||
| helper.const_val(0.19993579387664794921875), | |||
| helper.const_val(-0.3333314359188079833984375)}; | |||
| auto t2 = helper.mul(t, t); | |||
| auto p = polynomial(helper, t2, coeff); | |||
| return helper.add(helper.mul(helper.mul(p, t2), t), t); | |||
| }; | |||
| // constants | |||
| auto zero = helper.const_val(0); | |||
| auto pi = helper.const_val(3.141592653589793); | |||
| auto pi_over_2 = helper.const_val(1.570796326794897); | |||
| // transform the angle into interval [0, pi/4] | |||
| auto ax = helper.abs(x); | |||
| auto ay = helper.abs(y); | |||
| auto q = helper.div(helper.min(ax, ay), helper.max(ax, ay)); | |||
| // get approximation for interval [0, pi/4] | |||
| auto r = atan_poly(q); | |||
| // [0, pi/4] => [0, pi/2] | |||
| r = helper.select(helper.le(ax, ay), helper.sub(pi_over_2, r), r); | |||
| // [0, pi/2] => [0, pi] | |||
| r = helper.select(helper.le(x, zero), helper.sub(pi, r), r); | |||
| // [0, pi] => [-pi, pi] | |||
| r = helper.select(helper.le(y, zero), helper.sub(zero, r), r); | |||
| return r; | |||
| } | |||
| // numerical approximation of gauss error function | |||
| // https://en.wikipedia.org/wiki/Error_function#Polynomial | |||
| // original book: | |||
| // Numerical Recipes in Fortran 77: The Art of Scientific Computing | |||
| mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) { | |||
| auto zero = helper.const_val(0); | |||
| auto one = helper.const_val(1); | |||
| auto half = helper.const_val(0.5); | |||
| auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x)))); | |||
| std::vector<mlir::Value> coeff = { | |||
| helper.const_val(0.17087277), | |||
| helper.const_val(-0.82215223), | |||
| helper.const_val(1.48851587), | |||
| helper.const_val(-1.13520398), | |||
| helper.const_val(0.27886807), | |||
| helper.const_val(-0.18628806), | |||
| helper.const_val(0.09678418), | |||
| helper.const_val(0.37409196), | |||
| helper.const_val(1.00002368), | |||
| helper.const_val(-1.26551223)}; | |||
| auto p = polynomial(helper, t, coeff); | |||
| auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x)))); | |||
| return helper.select(helper.ge(x, zero), | |||
| helper.sub(one, r), | |||
| helper.sub(r, one)); | |||
| } | |||
| // numerical approximation of the inverse of normal distribution function | |||
| // original algorithm: | |||
| // https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c | |||
| // case 1: 0 < x < exp(-2) | |||
| // z = sqrt(-2 * log(x)) | |||
| // t = 1 / z | |||
| // res = log(z) / z - z + t * P(t) / Q(t) | |||
| // where coefficients of P and Q are different | |||
| // for z < 8 and for z >= 8 | |||
| // | |||
| // case2: exp(-2) <= x <= 1 - exp(-2) | |||
| // w = x - 0.5 | |||
| // res = sqrt(2pi) * (w + w^3 * R(w^2) / S(w^2)) | |||
| // | |||
| // case3: 1 - exp(-2) < x < 1 | |||
| // 0 < 1 - x < exp(-2) | |||
| // ndtri(x) = -ndtri(1 - x) | |||
| // fallback to case 1 | |||
| mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { | |||
| // polynomial P | |||
| auto P = [&](mlir::Value i, mlir::Value cond) { | |||
| std::vector<mlir::Value> coeff0 = { | |||
| helper.const_val(4.05544892305962419923E0), | |||
| helper.const_val(3.15251094599893866154E1), | |||
| helper.const_val(5.71628192246421288162E1), | |||
| helper.const_val(4.40805073893200834700E1), | |||
| helper.const_val(1.46849561928858024014E1), | |||
| helper.const_val(2.18663306850790267539E0), | |||
| helper.const_val(-1.40256079171354495875E-1), | |||
| helper.const_val(-3.50424626827848203418E-2), | |||
| helper.const_val(-8.57456785154685413611E-4)}; | |||
| std::vector<mlir::Value> coeff1 = { | |||
| helper.const_val(3.23774891776946035970E0), | |||
| helper.const_val(6.91522889068984211695E0), | |||
| helper.const_val(3.93881025292474443415E0), | |||
| helper.const_val(1.33303460815807542389E0), | |||
| helper.const_val(2.01485389549179081538E-1), | |||
| helper.const_val(1.23716634817820021358E-2), | |||
| helper.const_val(3.01581553508235416007E-4), | |||
| helper.const_val(2.65806974686737550832E-6), | |||
| helper.const_val(6.23974539184983293730E-9)}; | |||
| return helper.select(cond, | |||
| polynomial(helper, i, coeff0), | |||
| polynomial(helper, i, coeff1)); | |||
| }; | |||
| // polynomial Q | |||
| auto Q = [&](mlir::Value i, mlir::Value cond) { | |||
| std::vector<mlir::Value> coeff0 = { | |||
| helper.const_val(1.f), | |||
| helper.const_val(1.57799883256466749731E1), | |||
| helper.const_val(4.53907635128879210584E1), | |||
| helper.const_val(4.13172038254672030440E1), | |||
| helper.const_val(1.50425385692907503408E1), | |||
| helper.const_val(2.50464946208309415979E0), | |||
| helper.const_val(-1.42182922854787788574E-1), | |||
| helper.const_val(-3.80806407691578277194E-2), | |||
| helper.const_val(-9.33259480895457427372E-4)}; | |||
| std::vector<mlir::Value> coeff1 = { | |||
| helper.const_val(1.f), | |||
| helper.const_val(6.02427039364742014255E0), | |||
| helper.const_val(3.67983563856160859403E0), | |||
| helper.const_val(1.37702099489081330271E0), | |||
| helper.const_val(2.16236993594496635890E-1), | |||
| helper.const_val(1.34204006088543189037E-2), | |||
| helper.const_val(3.28014464682127739104E-4), | |||
| helper.const_val(2.89247864745380683936E-6), | |||
| helper.const_val(6.79019408009981274425E-9)}; | |||
| return helper.select(cond, | |||
| polynomial(helper, i, coeff0), | |||
| polynomial(helper, i, coeff1)); | |||
| }; | |||
| // polynomial R | |||
| auto R = [&](mlir::Value i) { | |||
| std::vector<mlir::Value> coeff = { | |||
| helper.const_val(-5.99633501014107895267E1), | |||
| helper.const_val(9.80010754185999661536E1), | |||
| helper.const_val(-5.66762857469070293439E1), | |||
| helper.const_val(1.39312609387279679503E1), | |||
| helper.const_val(-1.23916583867381258016E0)}; | |||
| return polynomial(helper, i, coeff); | |||
| }; | |||
| // polynomial S | |||
| auto S = [&](mlir::Value i) { | |||
| std::vector<mlir::Value> coeff = { | |||
| helper.const_val(1.f), | |||
| helper.const_val(1.95448858338141759834E0), | |||
| helper.const_val(4.67627912898881538453E0), | |||
| helper.const_val(8.63602421390890590575E1), | |||
| helper.const_val(-2.25462687854119370527E2), | |||
| helper.const_val(2.00260212380060660359E2), | |||
| helper.const_val(-8.20372256168333339912E1), | |||
| helper.const_val(1.59056225126211695515E1), | |||
| helper.const_val(-1.18331621121330003142E0)}; | |||
| return polynomial(helper, i, coeff); | |||
| }; | |||
| // constants | |||
| auto zero = helper.const_val(0); | |||
| auto one = helper.const_val(1); | |||
| auto half = helper.const_val(0.5); | |||
| auto eight = helper.const_val(8); | |||
| auto minus_2 = helper.const_val(-2); | |||
| auto exp_minus_2 = helper.const_val(0.135335283236); // exp(-2) | |||
| auto sqrt_2pi = helper.const_val(2.506628274631); // sqrt(2pi) | |||
| // conditions | |||
| auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2) | |||
| auto case3 = helper.gt(x, helper.sub(one, exp_minus_2)); // x > 1 - exp(-2) | |||
| auto case13 = helper.bit_or(case1, case3); | |||
| // case1 or case3 | |||
| auto x13 = helper.select(case1, x, helper.sub(one, x)); // x or (1 - x) | |||
| auto z = helper.sqrt(helper.mul(minus_2, helper.log(x13))); | |||
| auto z_lt_8 = helper.lt(z, eight); | |||
| auto t = helper.div(one, z); | |||
| auto res1 = helper.add(helper.sub(helper.div(helper.log(z), z), z), | |||
| helper.div(helper.mul(t, P(t, z_lt_8)), Q(t, z_lt_8))); | |||
| auto res13 = helper.select(case1, res1, helper.sub(zero, res1)); | |||
| // case2 | |||
| auto w = helper.sub(x, half); | |||
| auto w2 = helper.mul(w, w); | |||
| auto w3 = helper.mul(w, w2); | |||
| auto res2 = helper.mul( | |||
| sqrt_2pi, helper.add(w, helper.div(helper.mul(w3, R(w2)), S(w2)))); | |||
| return helper.select(case13, res13, res2); | |||
| } | |||
| } // namespace jit | |||
| } // namespace mgb | |||
| #endif // MGB_JIT && MGB_JIT_MLIR | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * \file src/jit/impl/mlir/ir/numerical.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain_build_config.h" | |||
| #if MGB_JIT && MGB_JIT_MLIR | |||
| #include <vector> | |||
| #include "./common.h" | |||
| namespace mgb { | |||
| namespace jit { | |||
| /*! polynomial of degree N: | |||
| * C_0 + C_1 * x + C_2 * x^2 + ... + C_N * x^N | |||
| * where coeff = [C_N, ..., C_2, C_1, C_0] | |||
| */ | |||
| mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x, | |||
| std::vector<mlir::Value>& coeff); | |||
| //! numerical approximation of arctangent | |||
| mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, mlir::Value x); | |||
| //! numerical approximation of gauss error function | |||
| mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x); | |||
| //! numerical approximation of the inverse of normal distribution function | |||
| mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x); | |||
| } // namespace jit | |||
| } // namespace mgb | |||
| #endif // MGB_JIT && MGB_JIT_MLIR | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -68,8 +68,8 @@ class ElemwiseUnaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||
| def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>; | |||
| def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>; | |||
| def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>; | |||
| /* ACOS */ | |||
| /* ASIN */ | |||
| def AcosOp : ElemwiseUnaryOp<"acos", [NoSideEffect]>; | |||
| def AsinOp : ElemwiseUnaryOp<"asin", [NoSideEffect]>; | |||
| def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>; | |||
| def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>; | |||
| def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>; | |||
| @@ -83,10 +83,10 @@ def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>; | |||
| def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>; | |||
| def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>; | |||
| def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>; | |||
| /* ERF */ | |||
| /* ERFINV */ | |||
| /* ERFC */ | |||
| /* ERFCINV */ | |||
| def ErfOp : ElemwiseUnaryOp<"erf", [NoSideEffect]>; | |||
| def ErfInvOp : ElemwiseUnaryOp<"erfinv", [NoSideEffect]>; | |||
| def ErfCOp : ElemwiseUnaryOp<"erfc", [NoSideEffect]>; | |||
| def ErfCInvOp : ElemwiseUnaryOp<"erfcinv", [NoSideEffect]>; | |||
| class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||
| ElemwiseOp<mnemonic, traits> { | |||
| @@ -130,14 +130,14 @@ def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>; | |||
| def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>; | |||
| def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>; | |||
| def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>; | |||
| /* POW */ | |||
| def PowOp : ElemwiseBinaryOp<"pow", [NoSideEffect]>; | |||
| def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>; | |||
| def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>; | |||
| def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>; | |||
| def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>; | |||
| def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>; | |||
| def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>; | |||
| /* ATAN2 */ | |||
| def Atan2Op : ElemwiseBinaryOp<"atan2", [NoSideEffect]>; | |||
| class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> : | |||
| ElemwiseOp<mnemonic, traits> { | |||
| @@ -159,22 +159,48 @@ void run_mlir(CompNode cn) { | |||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | |||
| } | |||
| struct MlirTestOpt { | |||
| float low; | |||
| float high; | |||
| float maxerr; | |||
| }; | |||
| struct MlirTestOpt get_mode_opt(opr::Elemwise::Mode mode) { | |||
| struct MlirTestOpt opt = {0, 1, 1e-6}; | |||
| if (mode == opr::Elemwise::Mode::ABS) { | |||
| opt.low = -10; | |||
| opt.high = 10; | |||
| } else if (mode == opr::Elemwise::Mode::LOG) { | |||
| opt.low = 0.1; | |||
| opt.high = 4; | |||
| } else if (mode == opr::Elemwise::Mode::ERF or | |||
| mode == opr::Elemwise::Mode::ERFC) { | |||
| opt.low = -5; | |||
| opt.high = 5; | |||
| } else if (mode == opr::Elemwise::Mode::ERFINV) { | |||
| opt.low = -0.999; | |||
| opt.high = 0.999; | |||
| opt.maxerr = 1e-4; | |||
| } else if (mode == opr::Elemwise::Mode::ERFCINV) { | |||
| opt.low = 0.001; | |||
| opt.high = 1.999; | |||
| opt.maxerr = 1e-4; | |||
| } | |||
| return opt; | |||
| } | |||
| template <typename tag, int arity> | |||
| void run_mlir_mode(CompNode cn) { | |||
| set_backend(Backend::MLIR); | |||
| auto graph = ComputingGraph::make(); | |||
| float low = 0.f, high = 1.f; | |||
| if (tag::mode == opr::Elemwise::Mode::LOG) { | |||
| low = 0.1; | |||
| high = 4; | |||
| } | |||
| HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen(low, | |||
| high); | |||
| auto opt = get_mode_opt(tag::mode); | |||
| HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen(opt.low, | |||
| opt.high); | |||
| SmallVector<std::shared_ptr<HostTensorND>> hosts; | |||
| VarNodeArray input_vars; | |||
| for (int i = 0; i < arity; i++) { | |||
| hosts.push_back(gen({23, 42}, cn)); | |||
| hosts.push_back(gen({2323, 4242}, cn)); | |||
| input_vars.push_back( | |||
| opr::Host2DeviceCopy::make(*graph, hosts[i]).node()); | |||
| } | |||
| @@ -198,7 +224,7 @@ void run_mlir_mode(CompNode cn) { | |||
| make_callback_copy(y_jit, host_y_jit)}); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); | |||
| MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, opt.maxerr); | |||
| } | |||
| #endif | |||
| @@ -240,18 +266,25 @@ TEST(TestJITMlirCodeGen, BasicGPU) { | |||
| cb(RELU) \ | |||
| cb(ABS) \ | |||
| cb(NEGATE) \ | |||
| cb(ACOS) \ | |||
| cb(ASIN) \ | |||
| cb(CEIL) \ | |||
| cb(EXP) \ | |||
| cb(FLOOR) \ | |||
| cb(LOG) \ | |||
| cb(LOG1P) \ | |||
| cb(SIN) \ | |||
| cb(COS) \ | |||
| cb(TANH) \ | |||
| cb(FAST_TANH) \ | |||
| cb(H_SWISH) \ | |||
| cb(SIGMOID) \ | |||
| cb(EXPM1) \ | |||
| cb(ROUND) | |||
| cb(ROUND) \ | |||
| cb(ERF) \ | |||
| cb(ERFINV) \ | |||
| cb(ERFC) \ | |||
| cb(ERFCINV) | |||
| // clang-format on | |||
| template <typename tag> | |||
| class TestJITMlirUnaryElemwise : public ::testing::Test {}; | |||
| @@ -268,21 +301,27 @@ FOREACH_UNARY_MODE(def_tag) | |||
| ::testing::Types<FOREACH_UNARY_MODE(t) ABS>; | |||
| #undef t | |||
| TYPED_TEST_CASE(TestJITMlirUnaryElemwise, mlir_elemwise_unary_types); | |||
| TYPED_TEST(TestJITMlirUnaryElemwise, run) { | |||
| auto cn = CompNode::load("cpu0"); | |||
| run_mlir_mode<TypeParam, 1>(cn); | |||
| } | |||
| #define SKIP_MODE(_mode) \ | |||
| if (TypeParam::mode == opr::Elemwise::Mode::_mode) { \ | |||
| printf("skip\n"); \ | |||
| return; \ | |||
| } | |||
| TYPED_TEST(TestJITMlirUnaryElemwise, run) { | |||
| auto cn = CompNode::load("cpu0"); | |||
| SKIP_MODE(ROUND); | |||
| run_mlir_mode<TypeParam, 1>(cn); | |||
| } | |||
| TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { | |||
| REQUIRE_GPU(1); | |||
| auto cn = CompNode::load("gpu0"); | |||
| SKIP_MODE(SIN); | |||
| SKIP_MODE(ROUND); | |||
| run_mlir_mode<TypeParam, 1>(cn); | |||
| } | |||
| @@ -298,6 +337,7 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { | |||
| cb(MOD) \ | |||
| cb(SUB) \ | |||
| cb(TRUE_DIV) \ | |||
| cb(POW) \ | |||
| cb(ABS_GRAD) \ | |||
| cb(SIGMOID_GRAD) \ | |||
| cb(SWITCH_GT0) \ | |||
| @@ -311,7 +351,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { | |||
| cb(FAST_TANH_GRAD) \ | |||
| cb(FUSE_ADD_SIGMOID) \ | |||
| cb(H_SWISH_GRAD) \ | |||
| cb(FUSE_ADD_H_SWISH) | |||
| cb(FUSE_ADD_H_SWISH) \ | |||
| cb(ATAN2) | |||
| // clang-format on | |||
| template <typename tag> | |||
| class TestJITMlirBinaryElemwise : public ::testing::Test {}; | |||
| @@ -336,6 +377,9 @@ TYPED_TEST(TestJITMlirBinaryElemwise, run) { | |||
| TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { | |||
| REQUIRE_GPU(1); | |||
| auto cn = CompNode::load("gpu0"); | |||
| SKIP_MODE(MOD); | |||
| run_mlir_mode<TypeParam, 2>(cn); | |||
| } | |||
| @@ -373,7 +417,7 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) { | |||
| #undef SKIP_MODE | |||
| #endif | |||
| #endif // MGB_JIT_MLIR | |||
| #endif // MGB_JIT | |||