| @@ -2,6 +2,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| import megengine.autodiff as ad | |||||
| import megengine.functional as F | import megengine.functional as F | ||||
| import megengine.functional.elemwise as elemwise | import megengine.functional.elemwise as elemwise | ||||
| from megengine import tensor | from megengine import tensor | ||||
| @@ -293,3 +294,25 @@ def test_empty_tensor(is_trace): | |||||
| run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) | run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) | ||||
| run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) | run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) | ||||
| run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) | run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) | ||||
@pytest.mark.parametrize("is_trace", [True, False])
def test_maximum_grad_consistency(is_trace):
    """Check that the gradient of ``F.maximum(x, x)`` w.r.t. ``x`` is all ones.

    The check is repeated three times on the same tensor, either eagerly or
    through ``trace`` with ``symbolic`` set to ``False`` and then ``True``.
    """
    # NOTE(review): presumably this guards the tie case (both operands equal)
    # against the gradient being counted for both inputs -- confirm against
    # the Elemwise MAX/MIN grad rule.

    def grad_of_self_maximum(x):
        # Differentiate maximum(x, x) with respect to x.
        with ad.GradManager() as manager:
            manager.attach(x)
            manager.backward(F.maximum(x, x))
        grad = x.grad
        # Reset so the next call on the same tensor starts without a stored grad.
        x.grad = None
        return grad

    def check(fn):
        inp = F.arange(10)
        expected = np.ones(10)
        # Same input evaluated repeatedly; every round must give the same grad.
        for _ in range(3):
            np.testing.assert_equal(fn(inp).numpy(), expected)

    if not is_trace:
        check(grad_of_self_maximum)
        return

    for symbolic in (False, True):
        check(trace(symbolic=symbolic)(grad_of_self_maximum))
| @@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { | |||||
| // misc | // misc | ||||
| ENTRY(COND_LEQ_MOV, | ENTRY(COND_LEQ_MOV, | ||||
| ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]), | ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]), | ||||
| ENTRY(COND_LT_MOV, | |||||
| ASTPtr::make<BinaryAST>("<", inps[0], inps[1]) * inps[2]), | |||||
| ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]), | ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]), | ||||
| ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]), | ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]), | ||||
| ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})), | ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})), | ||||
| @@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode( | |||||
| // ternary | // ternary | ||||
| case Mode::COND_LEQ_MOV: | case Mode::COND_LEQ_MOV: | ||||
| return Halide::select(inp(0) <= inp(1), inp(2), cv(0)); | return Halide::select(inp(0) <= inp(1), inp(2), cv(0)); | ||||
| case Mode::COND_LT_MOV: | |||||
| return Halide::select(inp(0) < inp(1), inp(2), cv(0)); | |||||
| case Mode::FUSE_MUL_ADD3: | case Mode::FUSE_MUL_ADD3: | ||||
| return inp(0) * inp(1) + inp(2); | return inp(0) * inp(1) + inp(2); | ||||
| case Mode::FUSE_MUL_ADD4: | case Mode::FUSE_MUL_ADD4: | ||||
| @@ -388,6 +388,15 @@ mlir::Value lower_mode<Mode::COND_LEQ_MOV>( | |||||
| helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); | helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); | ||||
| } | } | ||||
| //! COND_LT_MOV: x < y ? z : ctype(0) | |||||
| template <> | |||||
| mlir::Value lower_mode<Mode::COND_LT_MOV>( | |||||
| mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { | |||||
| ValueBuilderHelper helper(builder, loc); | |||||
| return helper.select( | |||||
| helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); | |||||
| } | |||||
| //! FUSE_MUL_ADD3: x * y + z | //! FUSE_MUL_ADD3: x * y + z | ||||
| template <> | template <> | ||||
| mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>( | mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>( | ||||
| @@ -60,6 +60,7 @@ | |||||
| #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ | #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ | ||||
| cb(CondLeqMovOp, COND_LEQ_MOV) \ | cb(CondLeqMovOp, COND_LEQ_MOV) \ | ||||
| cb(CondLtMovOp, COND_LT_MOV) \ | |||||
| cb(FuseMulAdd3Op, FUSE_MUL_ADD3) | cb(FuseMulAdd3Op, FUSE_MUL_ADD3) | ||||
| // clang-format on | // clang-format on | ||||
| @@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { | |||||
| // clang-format off | // clang-format off | ||||
| #define FOREACH_TERNARY_MODE(cb) \ | #define FOREACH_TERNARY_MODE(cb) \ | ||||
| cb(COND_LEQ_MOV) \ | cb(COND_LEQ_MOV) \ | ||||
| cb(COND_LT_MOV) \ | |||||
| cb(FUSE_MUL_ADD3) \ | cb(FUSE_MUL_ADD3) \ | ||||
| // clang-format on | // clang-format on | ||||
| template <typename tag> | template <typename tag> | ||||
| @@ -452,6 +452,7 @@ void run<all_oprs>(Backend backend, CompNode cn) { | |||||
| CHECK_ELEM2(ATAN2, true, gt0); | CHECK_ELEM2(ATAN2, true, gt0); | ||||
| CHECK_ELEM3(COND_LEQ_MOV, false, none); | CHECK_ELEM3(COND_LEQ_MOV, false, none); | ||||
| CHECK_ELEM3(COND_LT_MOV, false, none); | |||||
| CHECK_ELEM3(FUSE_MUL_ADD3, true, none); | CHECK_ELEM3(FUSE_MUL_ADD3, true, none); | ||||
| CHECK_ELEM4(FUSE_MUL_ADD4, true, none); | CHECK_ELEM4(FUSE_MUL_ADD4, true, none); | ||||
| @@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
| case Mode::FLOOR_DIV: | case Mode::FLOOR_DIV: | ||||
| return nullptr; | return nullptr; | ||||
| case Mode::MAX: | case Mode::MAX: | ||||
| RET(EL3(COND_LEQ_MOV, i[!wrt_idx], i[wrt_idx], og)); | |||||
| if (wrt_idx) { | |||||
| RET(EL3(COND_LT_MOV, i[0], i[1], og)); | |||||
| } else { | |||||
| RET(EL3(COND_LEQ_MOV, i[1], i[0], og)); | |||||
| } | |||||
| case Mode::MIN: | case Mode::MIN: | ||||
| RET(EL3(COND_LEQ_MOV, i[wrt_idx], i[!wrt_idx], og)); | |||||
| if (wrt_idx) { | |||||
| RET(EL3(COND_LT_MOV, i[1], i[0], og)); | |||||
| } else { | |||||
| RET(EL3(COND_LEQ_MOV, i[0], i[1], og)); | |||||
| } | |||||
| case Mode::MOD: | case Mode::MOD: | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| RET(og); | RET(og); | ||||
| @@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
| if (wrt_idx <= 1) | if (wrt_idx <= 1) | ||||
| return nullptr; | return nullptr; | ||||
| RET(EL3(COND_LEQ_MOV, i0, i1, og)); | RET(EL3(COND_LEQ_MOV, i0, i1, og)); | ||||
| case Mode::COND_LT_MOV: | |||||
| if (wrt_idx <= 1) | |||||
| return nullptr; | |||||
| RET(EL3(COND_LT_MOV, i0, i1, og)); | |||||
| // fuse oprs | // fuse oprs | ||||
| case Mode::FUSE_MUL_ADD3: | case Mode::FUSE_MUL_ADD3: | ||||
| if (wrt_idx < 2) { | if (wrt_idx < 2) { | ||||
| @@ -571,6 +571,8 @@ struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {}; | |||||
| /* ======================= ternary config ======================= */ | /* ======================= ternary config ======================= */ | ||||
| template <> | template <> | ||||
| struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {}; | struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {}; | ||||
| template <> | |||||
| struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {}; | |||||
| /* ======================= test runner ======================= */ | /* ======================= test runner ======================= */ | ||||
| namespace detail { | namespace detail { | ||||
| @@ -13,6 +13,7 @@ | |||||
| #define _ALLOW_FLOAT true | #define _ALLOW_FLOAT true | ||||
| #define _ALLOW_INT true | #define _ALLOW_INT true | ||||
| DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) | DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) | ||||
| DEF_TRAIT(COND_LT_MOV, x < y ? z : 0) | |||||
| DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) | DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) | ||||
| #undef _ALLOW_INT | #undef _ALLOW_INT | ||||
| #undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
| @@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) { | |||||
| switch (mode) { | switch (mode) { | ||||
| MAKE_TERNARY(FUSE_MUL_ADD3); | MAKE_TERNARY(FUSE_MUL_ADD3); | ||||
| MAKE_TERNARY(COND_LEQ_MOV); | MAKE_TERNARY(COND_LEQ_MOV); | ||||
| MAKE_TERNARY(COND_LT_MOV); | |||||
| default: | default: | ||||
| mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | ||||
| break; | break; | ||||
| @@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) { | |||||
| switch (mode) { | switch (mode) { | ||||
| MAKE_TERNARY(FUSE_MUL_ADD3); | MAKE_TERNARY(FUSE_MUL_ADD3); | ||||
| MAKE_TERNARY(COND_LEQ_MOV); | MAKE_TERNARY(COND_LEQ_MOV); | ||||
| MAKE_TERNARY(COND_LT_MOV); | |||||
| default: | default: | ||||
| mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | ||||
| break; | break; | ||||