| @@ -17,16 +17,16 @@ MODES = { | |||||
| 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | ||||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
| 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | ||||
| 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
| 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
| } | } | ||||
| QINT4_MODES = { | QINT4_MODES = { | ||||
| 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | ||||
| 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | ||||
| 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
| 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
| 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
| 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
| 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | ||||
| 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
| 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
| } | } | ||||
| QINT32_MODES = { | QINT32_MODES = { | ||||
| @@ -16,7 +16,7 @@ MODES = { | |||||
| (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | (2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | ||||
| 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | 'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | ||||
| 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | 'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | ||||
| (3, 'INT'): ['COND_LEQ_MOV'], | |||||
| (3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'], | |||||
| (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | (1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | ||||
| 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | 'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | ||||
| @@ -28,7 +28,7 @@ MODES = { | |||||
| 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | 'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | ||||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
| 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'], | ||||
| (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
| (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'], | |||||
| (1, 'BOOL'): ['NOT'], | (1, 'BOOL'): ['NOT'], | ||||
| (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | ||||
| (3, 'BOOL'): [] | (3, 'BOOL'): [] | ||||
| @@ -420,6 +420,7 @@ pdef('Elemwise').add_enum( | |||||
| Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'), | Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'), | ||||
| Doc('GELU = 58', 'unary: x Phi(x)'), | Doc('GELU = 58', 'unary: x Phi(x)'), | ||||
| Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), | Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), | ||||
| Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'), | |||||
| ) | ) | ||||
| pdef('ElemwiseMultiType').add_enum( | pdef('ElemwiseMultiType').add_enum( | ||||
| @@ -510,7 +511,8 @@ pdef('ElemwiseMultiType').add_enum( | |||||
| 'and the result is float32.'), | 'and the result is float32.'), | ||||
| Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56', | Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56', | ||||
| 'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' | 'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and ' | ||||
| '``c`` float32, and the result is float32.') | |||||
| '``c`` float32, and the result is float32.'), | |||||
| Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'), | |||||
| ) | ) | ||||
| pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | ||||
| @@ -92,7 +92,9 @@ | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | ||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| @@ -265,6 +265,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); | |||||
| // int and float | // int and float | ||||
| DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); | DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); | ||||
| DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0)); | |||||
| DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); | DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); | ||||
| #undef KERN_SIG | #undef KERN_SIG | ||||
| @@ -219,6 +219,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| CB_MODE(Mode::SILU_GRAD); | CB_MODE(Mode::SILU_GRAD); | ||||
| CB_MODE(Mode::GELU); | CB_MODE(Mode::GELU); | ||||
| CB_MODE(Mode::GELU_GRAD); | CB_MODE(Mode::GELU_GRAD); | ||||
| CB_MODE(Mode::COND_LT_MOV); | |||||
| default: | default: | ||||
| megdnn_assert( | megdnn_assert( | ||||
| 0, | 0, | ||||
| @@ -239,6 +239,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); | SET(init_quantized_ternary_op, QFUSE_MUL_ADD3); | ||||
| SET(init_quantized_ternary_op, QCOND_LEQ_MOV); | SET(init_quantized_ternary_op, QCOND_LEQ_MOV); | ||||
| SET(init_quantized_ternary_op, QCOND_LT_MOV); | |||||
| #undef SET | #undef SET | ||||
| } | } | ||||
| @@ -95,6 +95,7 @@ void ElemwiseMultiTypeImplHelper::exec( | |||||
| ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); | ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); | ||||
| ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); | ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); | ||||
| ON_QUANTIZED_MODE(COND_LT_MOV, 3); | |||||
| default: | default: | ||||
| megdnn_throw("invalid mode"); | megdnn_throw("invalid mode"); | ||||
| } | } | ||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_qint8 | |||||
| #define KERN_IMPL_DTYPE dt_qint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -25,6 +25,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode( | |||||
| DISPATCH(FUSE_MUL_ADD3); | DISPATCH(FUSE_MUL_ADD3); | ||||
| DISPATCH(COND_LEQ_MOV); | DISPATCH(COND_LEQ_MOV); | ||||
| DISPATCH(COND_LT_MOV); | |||||
| #undef DISPATCH | #undef DISPATCH | ||||
| default: | default: | ||||
| megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,7 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float16 | |||||
| #include "../kern_impl.inl" | |||||
| #endif | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_float32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int16 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_int8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,5 @@ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_CTYPE dt_uint8 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -179,6 +179,35 @@ DEF_TEST(ternary_non_contig) { | |||||
| checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); | checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); | ||||
| } | } | ||||
| DEF_TEST(ternary_lt) { | |||||
| using Mode = ElemwiseForward::Param::Mode; | |||||
| Checker<ElemwiseForward> checker(handle); | |||||
| checker.set_param(Mode::COND_LT_MOV); | |||||
| checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); | |||||
| checker.set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_dtype(2, dtype::Float32()) | |||||
| .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); | |||||
| checker.set_dtype(0, dtype::Float16()) | |||||
| .set_dtype(1, dtype::Float16()) | |||||
| .set_dtype(2, dtype::Float16()) | |||||
| .set_dtype(3, dtype::Float16()) | |||||
| .execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}}); | |||||
| checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}}); | |||||
| checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}}); | |||||
| ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError); | |||||
| ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError); | |||||
| } | |||||
| DEF_TEST(ternary_lt_non_contig) { | |||||
| using Mode = ElemwiseForward::Param::Mode; | |||||
| Checker<ElemwiseForward> checker(handle); | |||||
| checker.set_param(Mode::COND_LT_MOV); | |||||
| TensorLayout ly{{2, 3}, dtype::Float32()}; | |||||
| ly.stride[0] = 4; | |||||
| checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}}); | |||||
| } | |||||
| DEF_TEST(fuse_mul_add3) { | DEF_TEST(fuse_mul_add3) { | ||||
| using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
| Checker<ElemwiseForward> checker(handle); | Checker<ElemwiseForward> checker(handle); | ||||
| @@ -16,6 +16,8 @@ namespace elemwise { | |||||
| cb(binary_non_contig) \ | cb(binary_non_contig) \ | ||||
| cb(ternary) \ | cb(ternary) \ | ||||
| cb(ternary_non_contig) \ | cb(ternary_non_contig) \ | ||||
| cb(ternary_lt) \ | |||||
| cb(ternary_lt_non_contig) \ | |||||
| cb(fuse_mul_add3) \ | cb(fuse_mul_add3) \ | ||||
| cb(fuse_mul_add3_non_contig) \ | cb(fuse_mul_add3_non_contig) \ | ||||
| cb(fuse_mul_add4) \ | cb(fuse_mul_add4) \ | ||||
| @@ -207,7 +207,7 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) { | |||||
| using Mode = ElemwiseMultiType::Param::Mode; | using Mode = ElemwiseMultiType::Param::Mode; | ||||
| Checker<ElemwiseMultiType> checker(handle_cuda()); | Checker<ElemwiseMultiType> checker(handle_cuda()); | ||||
| for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) { | |||||
| for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) { | |||||
| UniformIntRNG rng_int8{-127, 127}; | UniformIntRNG rng_int8{-127, 127}; | ||||
| UniformIntRNG rng_uint8{0, 225}; | UniformIntRNG rng_uint8{0, 225}; | ||||
| checker.set_param({mode}) | checker.set_param({mode}) | ||||
| @@ -368,7 +368,7 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_QUANTIZED_MODE_TENARY) { | |||||
| CUBenchmarker<ElemwiseMultiType> bencher(handle_cuda()); | CUBenchmarker<ElemwiseMultiType> bencher(handle_cuda()); | ||||
| UniformIntRNG rng{-128, 127}; | UniformIntRNG rng{-128, 127}; | ||||
| for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) { | |||||
| for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) { | |||||
| printf("Benchmark mode: %d\n", (int)mode); | printf("Benchmark mode: %d\n", (int)mode); | ||||
| bencher.set_param({mode}) | bencher.set_param({mode}) | ||||
| .set_rng(0, &rng) | .set_rng(0, &rng) | ||||
| @@ -59,6 +59,7 @@ Elemwise::Mode get_elem_mode(ElemwiseMultiType::Mode mode) { | |||||
| MODE(FAST_TANH_GRAD); | MODE(FAST_TANH_GRAD); | ||||
| MODE(ATAN2); | MODE(ATAN2); | ||||
| MODE(COND_LEQ_MOV); | MODE(COND_LEQ_MOV); | ||||
| MODE(COND_LT_MOV); | |||||
| MODE(H_SWISH_GRAD); | MODE(H_SWISH_GRAD); | ||||
| MODE(FUSE_ADD_H_SWISH); | MODE(FUSE_ADD_H_SWISH); | ||||
| @@ -231,7 +232,9 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||||
| .set_dtype(1, dtype::QuantizedS8(0.2f)) | .set_dtype(1, dtype::QuantizedS8(0.2f)) | ||||
| .set_dtype(2, dtype::QuantizedS8(0.3f)); | .set_dtype(2, dtype::QuantizedS8(0.3f)); | ||||
| for (auto mode : {Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV}) { | |||||
| for (auto mode : | |||||
| {Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV, | |||||
| Param::Mode::QCOND_LT_MOV}) { | |||||
| Param param{mode}; | Param param{mode}; | ||||
| checker.set_param(param); | checker.set_param(param); | ||||