| @@ -30,6 +30,6 @@ MODES = { | |||||
| 'FUSE_ADD_H_SWISH'], | 'FUSE_ADD_H_SWISH'], | ||||
| (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
| (1, 'BOOL'): ['NOT'], | (1, 'BOOL'): ['NOT'], | ||||
| (2, 'BOOL'): ['AND', 'OR', 'XOR'], | |||||
| (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||||
| (3, 'BOOL'): [] | (3, 'BOOL'): [] | ||||
| } | } | ||||
| @@ -45,6 +45,9 @@ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | ||||
| @@ -173,6 +173,9 @@ namespace megdnn { | |||||
| DEF_KERN_ALL(LT, x < y); | DEF_KERN_ALL(LT, x < y); | ||||
| DEF_KERN_ALL(LEQ, x <= y); | DEF_KERN_ALL(LEQ, x <= y); | ||||
| DEF_KERN_ALL(EQ, x == y); | DEF_KERN_ALL(EQ, x == y); | ||||
| DEF_KERN(dt_bool, LT, x < y); | |||||
| DEF_KERN(dt_bool, LEQ, x <= y); | |||||
| DEF_KERN(dt_bool, EQ, x == y); | |||||
| DEF_KERN_INT(FLOOR_DIV, x / y); | DEF_KERN_INT(FLOOR_DIV, x / y); | ||||
| DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | ||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -812,6 +812,9 @@ TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !) | |||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) | TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) | ||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) | TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) | ||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) | TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) | ||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(LT, <) | |||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(LEQ, <=) | |||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(EQ, ==) | |||||
| TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { | TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { | ||||
| using Checker = AutoOprChecker<3, 1>; | using Checker = AutoOprChecker<3, 1>; | ||||
| @@ -27,6 +27,13 @@ DEF_TRAIT(OR, x || y) | |||||
| DEF_TRAIT(XOR, x ^ y) | DEF_TRAIT(XOR, x ^ y) | ||||
| #undef _ALLOW_INT | #undef _ALLOW_INT | ||||
| #undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
| #define _ALLOW_INT true | |||||
| #define _ALLOW_FLOAT true | |||||
| DEF_TRAIT(EQ, x == y) | |||||
| DEF_TRAIT(LEQ, x <= y) | |||||
| DEF_TRAIT(LT, x < y) | |||||
| #undef _ALLOW_BOOL | #undef _ALLOW_BOOL | ||||
| #define _ALLOW_BOOL false | #define _ALLOW_BOOL false | ||||
| @@ -44,10 +51,6 @@ DEF_TRAIT(SUB, x - y) | |||||
| DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) | DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0) | ||||
| DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) | DEF_TRAIT(TANH_GRAD, (1 - x * x) * y) | ||||
| DEF_TRAIT(EQ, x == y) | |||||
| DEF_TRAIT(LEQ, x <= y) | |||||
| DEF_TRAIT(LT, x < y) | |||||
| DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0)) | DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0)) | ||||
| #undef _ALLOW_INT | #undef _ALLOW_INT | ||||