| @@ -52,6 +52,7 @@ namespace megdnn { | |||||
| MEGDNN_INC_FLOAT16(cb(Float16)) \ | MEGDNN_INC_FLOAT16(cb(Float16)) \ | ||||
| MEGDNN_INC_FLOAT16(cb(BFloat16)) \ | MEGDNN_INC_FLOAT16(cb(BFloat16)) \ | ||||
| cb(UintB4) \ | cb(UintB4) \ | ||||
| cb(Bool) \ | |||||
| /*! | /*! | ||||
| * \brief iterate through each full byte dtype | * \brief iterate through each full byte dtype | ||||
| @@ -65,6 +66,7 @@ namespace megdnn { | |||||
| cb(Byte) \ | cb(Byte) \ | ||||
| MEGDNN_INC_FLOAT16(cb(Float16)) \ | MEGDNN_INC_FLOAT16(cb(Float16)) \ | ||||
| MEGDNN_INC_FLOAT16(cb(BFloat16)) \ | MEGDNN_INC_FLOAT16(cb(BFloat16)) \ | ||||
| cb(Bool) \ | |||||
| /*! | /*! | ||||
| * \brief iterate through each fractional byte dtype | * \brief iterate through each fractional byte dtype | ||||
| @@ -122,7 +124,7 @@ namespace megdnn { | |||||
| */ | */ | ||||
| #define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \ | #define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \ | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \ | |||||
| //! In order to avoid an unnecessary increase in binary size, we just | //! In order to avoid an unnecessary increase in binary size, we just | ||||
| //! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add | //! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add | ||||
| @@ -348,6 +350,7 @@ typedef int32_t dt_int32; | |||||
| typedef int16_t dt_int16; | typedef int16_t dt_int16; | ||||
| typedef int8_t dt_int8; | typedef int8_t dt_int8; | ||||
| typedef uint8_t dt_uint8; | typedef uint8_t dt_uint8; | ||||
| typedef bool dt_bool; | |||||
| MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) | MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) | ||||
| MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) | MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) | ||||
| @@ -375,7 +378,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| BFloat16 = 11, | BFloat16 = 11, | ||||
| #endif | #endif | ||||
| Bool = 12, | |||||
| #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, | #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, | ||||
| #define D(_name) _name, | #define D(_name) _name, | ||||
| MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) | MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) | ||||
| @@ -392,7 +395,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) | |||||
| #if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
| //! dtype numeric category fo | //! dtype numeric category fo | ||||
| enum class DTypeCategory: int { | enum class DTypeCategory: int { | ||||
| OTHER, FLOAT, INT, LOWBIT, QUANTIZED | |||||
| OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL | |||||
| }; | }; | ||||
| //! dtype signedness | //! dtype signedness | ||||
| enum class DTypeSignedness: int { | enum class DTypeSignedness: int { | ||||
| @@ -401,7 +404,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) | |||||
| #else | #else | ||||
| struct DTypeCategory { | struct DTypeCategory { | ||||
| enum Ev { | enum Ev { | ||||
| OTHER, FLOAT, INT, LOWBIT, QUANTIZED | |||||
| OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL | |||||
| }; | }; | ||||
| int ev; | int ev; | ||||
| }; | }; | ||||
| @@ -707,6 +710,7 @@ MEGDNN_DEF_DT(Int32, dt_int32, INT, SIGNED, INT32_MIN, INT32_MAX); | |||||
| MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX); | MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX); | ||||
| MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); | MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); | ||||
| MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); | MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); | ||||
| MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true); | |||||
| MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, | MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, | ||||
| std::numeric_limits<dt_float16>::lowest(), | std::numeric_limits<dt_float16>::lowest(), | ||||
| std::numeric_limits<dt_float16>::max())); | std::numeric_limits<dt_float16>::max())); | ||||
| @@ -39,11 +39,12 @@ class ElemwiseForward: public OperatorBase { | |||||
| bool commutable; //!< whether arity == 2 and inputs commutable | bool commutable; //!< whether arity == 2 and inputs commutable | ||||
| bool allow_int; //!< whether int inputs allowed | bool allow_int; //!< whether int inputs allowed | ||||
| bool allow_float; //!< whether float inputs allowed | bool allow_float; //!< whether float inputs allowed | ||||
| bool allow_bool; //!< whether bool inputs allowed | |||||
| const char* name; //!< name of the mode | const char* name; //!< name of the mode | ||||
| ModeTrait(): | ModeTrait(): | ||||
| arity(0), commutable(0), allow_int(0), allow_float(0), | |||||
| arity(0), commutable(0), allow_int(0), allow_float(0), allow_bool(0), | |||||
| name(NULL) | name(NULL) | ||||
| {} | {} | ||||
| @@ -5,6 +5,7 @@ DTYPES = {'dt_int32': ('Int32', 'INT'), | |||||
| 'dt_uint8': ('Uint8', 'INT'), | 'dt_uint8': ('Uint8', 'INT'), | ||||
| 'dt_int8': ('Int8', 'INT'), | 'dt_int8': ('Int8', 'INT'), | ||||
| 'dt_int16': ('Int16', 'INT'), | 'dt_int16': ('Int16', 'INT'), | ||||
| 'dt_bool': ('Bool', 'BOOL'), | |||||
| 'dt_float32': ('Float32', 'FLOAT'), | 'dt_float32': ('Float32', 'FLOAT'), | ||||
| 'dt_float16': ('Float16', 'FLOAT'), | 'dt_float16': ('Float16', 'FLOAT'), | ||||
| 'dt_bfloat16': ('BFloat16', 'FLOAT') | 'dt_bfloat16': ('BFloat16', 'FLOAT') | ||||
| @@ -28,4 +29,7 @@ MODES = { | |||||
| 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | ||||
| 'FUSE_ADD_H_SWISH'], | 'FUSE_ADD_H_SWISH'], | ||||
| (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | (3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
| (1, 'BOOL'): ['NOT'], | |||||
| (2, 'BOOL'): ['AND', 'OR', 'XOR'], | |||||
| (3, 'BOOL'): [] | |||||
| } | } | ||||
| @@ -314,7 +314,12 @@ pdef('Elemwise').add_enum( | |||||
| Doc('ERFCINV', 'unary: inverse function of erfc(x)'), | Doc('ERFCINV', 'unary: inverse function of erfc(x)'), | ||||
| Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'), | Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'), | ||||
| Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'), | Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'), | ||||
| Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)') | |||||
| Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)'), | |||||
| Doc('NOT', 'unary: !x'), | |||||
| Doc('AND', 'binary: x && y'), | |||||
| Doc('OR', 'binary: x || y'), | |||||
| Doc('XOR', 'binary: x ^ y') | |||||
| ) | ) | ||||
| pdef('ElemwiseMultiType').add_enum( | pdef('ElemwiseMultiType').add_enum( | ||||
| @@ -68,6 +68,7 @@ namespace cond_take { | |||||
| #define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype) | #define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype) | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f) | MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f) | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i) | MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i) | ||||
| inst_eq_i(::megdnn::dtype::Bool) | |||||
| #undef inst_eq_f | #undef inst_eq_f | ||||
| #undef inst_eq_i | #undef inst_eq_i | ||||
| @@ -9,6 +9,9 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| // generated by gen_elemwise_each_mode.py | // generated by gen_elemwise_each_mode.py | ||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) \ | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ | ||||
| @@ -38,6 +41,11 @@ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ | ||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ | |||||
| MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ | |||||
| #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ | #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | ||||
| MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ | MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ | ||||
| @@ -139,6 +139,7 @@ namespace megdnn { | |||||
| DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); | DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); | ||||
| // int only | // int only | ||||
| DEF_KERN(dt_bool, NOT, x ^ 1); | |||||
| #undef KERN_SIG | #undef KERN_SIG | ||||
| @@ -156,6 +157,9 @@ namespace megdnn { | |||||
| DEF_KERN_ALL(MAX, x > y ? x : y); | DEF_KERN_ALL(MAX, x > y ? x : y); | ||||
| DEF_KERN_ALL(MIN, x < y ? x : y); | DEF_KERN_ALL(MIN, x < y ? x : y); | ||||
| DEF_KERN_ALL(MUL, x* y); | DEF_KERN_ALL(MUL, x* y); | ||||
| DEF_KERN(dt_bool, AND, x && y); | |||||
| DEF_KERN(dt_bool, OR, x || y); | |||||
| DEF_KERN(dt_bool, XOR, x ^ y); | |||||
| DEF_KERN_INT(RMULH, round_mulh_saturate(x, y)); | DEF_KERN_INT(RMULH, round_mulh_saturate(x, y)); | ||||
| DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y); | DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y); | ||||
| DEF_KERN_ALL(SUB, x - y); | DEF_KERN_ALL(SUB, x - y); | ||||
| @@ -72,6 +72,15 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | ||||
| #undef cb | #undef cb | ||||
| #define cb(_m) \ | |||||
| MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \ | |||||
| get(Mode::_m).allow_bool = true; \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); | |||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); | |||||
| #undef cb | |||||
| #define cb(_m) \ | #define cb(_m) \ | ||||
| MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \ | MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \ | ||||
| auto&& t = get(Mode::_m); \ | auto&& t = get(Mode::_m); \ | ||||
| @@ -82,10 +91,12 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| #define _a 1 | #define _a 1 | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); | |||||
| #undef _a | #undef _a | ||||
| #define _a 2 | #define _a 2 | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb); | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb); | |||||
| #undef _a | #undef _a | ||||
| #define _a 3 | #define _a 3 | ||||
| MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb); | ||||
| @@ -98,6 +109,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| auto&& t = get(Mode::_m); \ | auto&& t = get(Mode::_m); \ | ||||
| t.allow_int = true; \ | t.allow_int = true; \ | ||||
| t.allow_float = true; \ | t.allow_float = true; \ | ||||
| t.allow_bool = true; \ | |||||
| t.arity = _arity; \ | t.arity = _arity; \ | ||||
| t.name = megdnn_mangle(#_m); \ | t.name = megdnn_mangle(#_m); \ | ||||
| } \ | } \ | ||||
| @@ -129,7 +141,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||||
| #if MEGDNN_ELEMWISE_MODE_ENABLE_ALL | #if MEGDNN_ELEMWISE_MODE_ENABLE_ALL | ||||
| for (auto&& i : traits) { | for (auto&& i : traits) { | ||||
| megdnn_assert(i.arity && (i.allow_int || i.allow_float) && | |||||
| megdnn_assert(i.arity && (i.allow_int || i.allow_float || i.allow_bool) && | |||||
| (!i.commutable || i.arity == 2)); | (!i.commutable || i.arity == 2)); | ||||
| } | } | ||||
| #else | #else | ||||
| @@ -282,6 +294,10 @@ void ElemwiseForward::check_dtype(DType dtype) { | |||||
| megdnn_assert(trait.allow_int, "unsupport mode %s for int\n", | megdnn_assert(trait.allow_int, "unsupport mode %s for int\n", | ||||
| trait.name); | trait.name); | ||||
| break; | break; | ||||
| case DTypeCategory::BOOL: | |||||
| megdnn_assert(trait.allow_bool, "unsupport mode %s for bool\n", | |||||
| trait.name); | |||||
| break; | |||||
| default: | default: | ||||
| megdnn_throw("bad dtype"); | megdnn_throw("bad dtype"); | ||||
| } | } | ||||
| @@ -15,6 +15,15 @@ | |||||
| template<int arity> | template<int arity> | ||||
| void ElemwiseForwardImpl::on_arity_dispatched() { | void ElemwiseForwardImpl::on_arity_dispatched() { | ||||
| auto src = make_elemwise_op_param<arity>(); | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) | |||||
| on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool) | |||||
| megdnn_throw("bad dtype"); | |||||
| } | |||||
| template<int arity> | |||||
| void ElemwiseForwardImpl::on_arity_dispatched_no_bool() { | |||||
| auto src = make_elemwise_op_param<arity>(); | auto src = make_elemwise_op_param<arity>(); | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) | MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) | MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) | ||||
| @@ -45,6 +54,14 @@ IMPL_MODE_DISPATCHER(2, DTypeCategory::FLOAT); | |||||
| IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT); | IMPL_MODE_DISPATCHER(3, DTypeCategory::FLOAT); | ||||
| #undef FOREACH | #undef FOREACH | ||||
| #define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL | |||||
| IMPL_MODE_DISPATCHER(1, DTypeCategory::BOOL); | |||||
| #undef FOREACH | |||||
| #define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL | |||||
| IMPL_MODE_DISPATCHER(2, DTypeCategory::BOOL); | |||||
| #undef FOREACH | |||||
| void ElemwiseForwardImpl::exec( | void ElemwiseForwardImpl::exec( | ||||
| const TensorNDArray &src, | const TensorNDArray &src, | ||||
| _megdnn_tensor_out dst) { | _megdnn_tensor_out dst) { | ||||
| @@ -97,8 +114,8 @@ void ElemwiseForwardImpl::exec( | |||||
| #define D(_n) case _n: return on_arity_dispatched<_n>() | #define D(_n) case _n: return on_arity_dispatched<_n>() | ||||
| D(1); | D(1); | ||||
| D(2); | D(2); | ||||
| D(3); | |||||
| #undef D | #undef D | ||||
| case 3: return on_arity_dispatched_no_bool<3>(); | |||||
| default: | default: | ||||
| megdnn_throw("bad size of input tensors"); | megdnn_throw("bad size of input tensors"); | ||||
| } | } | ||||
| @@ -13,6 +13,9 @@ | |||||
| template<int arity> | template<int arity> | ||||
| void on_arity_dispatched(); | void on_arity_dispatched(); | ||||
| template<int arity> | |||||
| void on_arity_dispatched_no_bool(); | |||||
| template<int arity, DTypeCategory dtype_cat, typename ctype> | template<int arity, DTypeCategory dtype_cat, typename ctype> | ||||
| struct ModeDispatcher; | struct ModeDispatcher; | ||||
| @@ -19,10 +19,12 @@ void TypeCvt::check_exec(const TensorLayout &src, const TensorLayout &dst) { | |||||
| megdnn_assert_eq_shape(src, dst); | megdnn_assert_eq_shape(src, dst); | ||||
| auto cat = src.dtype.category(); | auto cat = src.dtype.category(); | ||||
| megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || | megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || | ||||
| cat == DTypeCategory::QUANTIZED); | |||||
| cat == DTypeCategory::QUANTIZED || | |||||
| cat == DTypeCategory::BOOL); | |||||
| cat = dst.dtype.category(); | cat = dst.dtype.category(); | ||||
| megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || | megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || | ||||
| cat == DTypeCategory::QUANTIZED); | |||||
| cat == DTypeCategory::QUANTIZED || | |||||
| cat == DTypeCategory::BOOL); | |||||
| } | } | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/cond_take/kimpl/dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_cond_take_kern_impls.py | |||||
| #include "../kern.inl" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace cond_take { | |||||
| inst_genidx(::megdnn::dtype::Bool) | |||||
| #undef inst_genidx | |||||
| inst_copy(::megdnn::dtype::Bool) | |||||
| #undef inst_copy | |||||
| #undef inst_copy_ | |||||
| } // cond_take | |||||
| } // cuda | |||||
| } // megdnn | |||||
| @@ -25,8 +25,9 @@ namespace cuda { | |||||
| 1, KernImpl, | 1, KernImpl, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| !std::is_same<typename KernImpl::ctype, dt_int8>::value && | !std::is_same<typename KernImpl::ctype, dt_int8>::value && | ||||
| !std::is_same<typename KernImpl::ctype, | |||||
| dt_uint8>::value>::type> { | |||||
| !std::is_same<typename KernImpl::ctype, dt_uint8>::value && | |||||
| !std::is_same<typename KernImpl::ctype, | |||||
| dt_bool>::value>::type> { | |||||
| typedef typename KernImpl::ctype ctype; | typedef typename KernImpl::ctype ctype; | ||||
| ctype* dst; | ctype* dst; | ||||
| @@ -41,8 +42,9 @@ namespace cuda { | |||||
| 2, KernImpl, | 2, KernImpl, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| !std::is_same<typename KernImpl::ctype, dt_int8>::value && | !std::is_same<typename KernImpl::ctype, dt_int8>::value && | ||||
| !std::is_same<typename KernImpl::ctype, | |||||
| dt_uint8>::value>::type> { | |||||
| !std::is_same<typename KernImpl::ctype, dt_uint8>::value && | |||||
| !std::is_same<typename KernImpl::ctype, | |||||
| dt_bool>::value>::type> { | |||||
| typedef typename KernImpl::ctype ctype; | typedef typename KernImpl::ctype ctype; | ||||
| ctype* dst; | ctype* dst; | ||||
| @@ -57,8 +59,9 @@ namespace cuda { | |||||
| 3, KernImpl, | 3, KernImpl, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| !std::is_same<typename KernImpl::ctype, dt_int8>::value && | !std::is_same<typename KernImpl::ctype, dt_int8>::value && | ||||
| !std::is_same<typename KernImpl::ctype, | |||||
| dt_uint8>::value>::type> { | |||||
| !std::is_same<typename KernImpl::ctype, dt_uint8>::value && | |||||
| !std::is_same<typename KernImpl::ctype, | |||||
| dt_bool>::value>::type> { | |||||
| typedef typename KernImpl::ctype ctype; | typedef typename KernImpl::ctype ctype; | ||||
| ctype* dst; | ctype* dst; | ||||
| @@ -74,8 +77,9 @@ namespace cuda { | |||||
| 1, KernImpl, | 1, KernImpl, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<typename KernImpl::ctype, dt_int8>::value || | std::is_same<typename KernImpl::ctype, dt_int8>::value || | ||||
| std::is_same<typename KernImpl::ctype, | |||||
| dt_uint8>::value>::type> { | |||||
| std::is_same<typename KernImpl::ctype, dt_uint8>::value || | |||||
| std::is_same<typename KernImpl::ctype, | |||||
| dt_bool>::value>::type> { | |||||
| typedef typename KernImpl::ctype ctype; | typedef typename KernImpl::ctype ctype; | ||||
| using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>; | using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>; | ||||
| typedef typename VectTypeTrait::vect_type vect_type; | typedef typename VectTypeTrait::vect_type vect_type; | ||||
| @@ -99,8 +103,9 @@ namespace cuda { | |||||
| 2, KernImpl, | 2, KernImpl, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<typename KernImpl::ctype, dt_int8>::value || | std::is_same<typename KernImpl::ctype, dt_int8>::value || | ||||
| std::is_same<typename KernImpl::ctype, | |||||
| dt_uint8>::value>::type> { | |||||
| std::is_same<typename KernImpl::ctype, dt_uint8>::value || | |||||
| std::is_same<typename KernImpl::ctype, | |||||
| dt_bool>::value>::type> { | |||||
| typedef typename KernImpl::ctype ctype; | typedef typename KernImpl::ctype ctype; | ||||
| using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>; | using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>; | ||||
| typedef typename VectTypeTrait::vect_type vect_type; | typedef typename VectTypeTrait::vect_type vect_type; | ||||
| @@ -126,8 +131,9 @@ namespace cuda { | |||||
| 3, KernImpl, | 3, KernImpl, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<typename KernImpl::ctype, dt_int8>::value || | std::is_same<typename KernImpl::ctype, dt_int8>::value || | ||||
| std::is_same<typename KernImpl::ctype, | |||||
| dt_uint8>::value>::type> { | |||||
| std::is_same<typename KernImpl::ctype, dt_uint8>::value || | |||||
| std::is_same<typename KernImpl::ctype, | |||||
| dt_bool>::value>::type> { | |||||
| typedef typename KernImpl::ctype ctype; | typedef typename KernImpl::ctype ctype; | ||||
| using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>; | using VectTypeTrait = elemwise_intl::VectTypeTrait<ctype>; | ||||
| typedef typename VectTypeTrait::vect_type vect_type; | typedef typename VectTypeTrait::vect_type vect_type; | ||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -169,6 +169,9 @@ INST_FOR_CTYPE | |||||
| #define ct dt_qint32 | #define ct dt_qint32 | ||||
| INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
| #undef ct | #undef ct | ||||
| #define ct dt_bool | |||||
| INST_FOR_CTYPE | |||||
| #undef ct | |||||
| #undef INST_FOR_CTYPE | #undef INST_FOR_CTYPE | ||||
| #undef INST | #undef INST | ||||
| @@ -216,6 +219,9 @@ INST_FOR_CTYPE | |||||
| #define ct dt_qint32 | #define ct dt_qint32 | ||||
| INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
| #undef ct | #undef ct | ||||
| #define ct dt_bool | |||||
| INST_FOR_CTYPE | |||||
| #undef ct | |||||
| #undef ndim_cb | #undef ndim_cb | ||||
| @@ -225,6 +231,7 @@ INST_FOR_CTYPE | |||||
| #define INST(dt_ibyte) template class ParamVectVisitor<4, dt_ibyte, BCAST_1010> | #define INST(dt_ibyte) template class ParamVectVisitor<4, dt_ibyte, BCAST_1010> | ||||
| INST(dt_int8); | INST(dt_int8); | ||||
| INST(dt_uint8); | INST(dt_uint8); | ||||
| INST(dt_bool); | |||||
| INST(dt_qint8); | INST(dt_qint8); | ||||
| INST(dt_quint8); | INST(dt_quint8); | ||||
| #undef dt_ibyte | #undef dt_ibyte | ||||
| @@ -102,6 +102,7 @@ INST(dt_float16, half4); | |||||
| INST(dt_bfloat16, bhalf4); | INST(dt_bfloat16, bhalf4); | ||||
| INST(dt_int32, int4); | INST(dt_int32, int4); | ||||
| INST(dt_int16, short4); | INST(dt_int16, short4); | ||||
| INST(dt_bool, uchar4); | |||||
| #undef as_raw | #undef as_raw | ||||
| #define as_raw(x) x.as_int8() | #define as_raw(x) x.as_int8() | ||||
| INST(dt_qint8, char4); | INST(dt_qint8, char4); | ||||
| @@ -454,6 +455,7 @@ INST_DT_IBYTE(dt_int8); | |||||
| INST_DT_IBYTE(dt_uint8); | INST_DT_IBYTE(dt_uint8); | ||||
| INST_DT_IBYTE(dt_qint8); | INST_DT_IBYTE(dt_qint8); | ||||
| INST_DT_IBYTE(dt_quint8); | INST_DT_IBYTE(dt_quint8); | ||||
| INST_DT_IBYTE(dt_bool); | |||||
| #undef INST_DT_IBYTE | #undef INST_DT_IBYTE | ||||
| #undef DEVICE_WRAPPER | #undef DEVICE_WRAPPER | ||||
| #undef INST_PARAM_VECT_VISITOR | #undef INST_PARAM_VECT_VISITOR | ||||
| @@ -913,6 +915,7 @@ INST_DT_IBYTE(dt_int8); | |||||
| INST_DT_IBYTE(dt_uint8); | INST_DT_IBYTE(dt_uint8); | ||||
| INST_DT_IBYTE(dt_qint8); | INST_DT_IBYTE(dt_qint8); | ||||
| INST_DT_IBYTE(dt_quint8); | INST_DT_IBYTE(dt_quint8); | ||||
| INST_DT_IBYTE(dt_bool); | |||||
| #undef INST_DT_IBYTE | #undef INST_DT_IBYTE | ||||
| //! implement general case by UserOpInvokerToSameNdim | //! implement general case by UserOpInvokerToSameNdim | ||||
| @@ -1259,6 +1262,7 @@ INST_DT_IBYTE(dt_int8); | |||||
| INST_DT_IBYTE(dt_uint8); | INST_DT_IBYTE(dt_uint8); | ||||
| INST_DT_IBYTE(dt_qint8); | INST_DT_IBYTE(dt_qint8); | ||||
| INST_DT_IBYTE(dt_quint8); | INST_DT_IBYTE(dt_quint8); | ||||
| INST_DT_IBYTE(dt_bool); | |||||
| #undef INST_DT_IBYTE | #undef INST_DT_IBYTE | ||||
| #endif | #endif | ||||
| @@ -62,7 +62,8 @@ template <typename ctype_dest, typename ctype_src> | |||||
| struct TypeCvtOp<ctype_dest, ctype_src, | struct TypeCvtOp<ctype_dest, ctype_src, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<ctype_src, dt_int8>::value || | std::is_same<ctype_src, dt_int8>::value || | ||||
| std::is_same<ctype_src, dt_uint8>::value>::type> { | |||||
| std::is_same<ctype_src, dt_uint8>::value || | |||||
| std::is_same<ctype_src, dt_bool>::value>::type> { | |||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | ||||
| using dst_vect_type = typename VectTypeTrait<ctype_dest>::vect_type; | using dst_vect_type = typename VectTypeTrait<ctype_dest>::vect_type; | ||||
| @@ -85,7 +86,8 @@ struct TypeCvtOpToQuantized< | |||||
| ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<ctype_src, dt_int8>::value || | std::is_same<ctype_src, dt_int8>::value || | ||||
| std::is_same<ctype_src, dt_uint8>::value>::type> { | |||||
| std::is_same<ctype_src, dt_uint8>::value || | |||||
| std::is_same<ctype_src, dt_bool>::value>::type> { | |||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| CudaDTypeParam<ctype_dest> param; | CudaDTypeParam<ctype_dest> param; | ||||
| using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | ||||
| @@ -109,7 +111,8 @@ struct TypeCvtOpFromQuantized< | |||||
| ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<ctype_src, dt_qint8>::value || | std::is_same<ctype_src, dt_qint8>::value || | ||||
| std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
| std::is_same<ctype_src, dt_quint8>::value || | |||||
| std::is_same<ctype_src, dt_bool>::value>::type> { | |||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| CudaDTypeParam<ctype_src> param; | CudaDTypeParam<ctype_src> param; | ||||
| using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | ||||
| @@ -137,7 +140,8 @@ struct TypeCvtOpBetweenQuantized< | |||||
| ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
| typename std::enable_if< | typename std::enable_if< | ||||
| std::is_same<ctype_src, dt_qint8>::value || | std::is_same<ctype_src, dt_qint8>::value || | ||||
| std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
| std::is_same<ctype_src, dt_quint8>::value || | |||||
| std::is_same<ctype_src, dt_bool>::value>::type> { | |||||
| ctype_dest* dest; | ctype_dest* dest; | ||||
| CudaDTypeParam<ctype_src> src_param; | CudaDTypeParam<ctype_src> src_param; | ||||
| CudaDTypeParam<ctype_dest> dst_param; | CudaDTypeParam<ctype_dest> dst_param; | ||||
| @@ -243,6 +247,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
| cb(dtype_src, dt_float32) \ | cb(dtype_src, dt_float32) \ | ||||
| cb(dtype_src, dt_float16) \ | cb(dtype_src, dt_float16) \ | ||||
| cb(dtype_src, dt_bfloat16) \ | cb(dtype_src, dt_bfloat16) \ | ||||
| cb(dtype_src, dt_bool) \ | |||||
| #define MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ | #define MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ | ||||
| cb(dtype_src, dt_quint8) \ | cb(dtype_src, dt_quint8) \ | ||||
| @@ -265,6 +270,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
| cb(dt_float32) \ | cb(dt_float32) \ | ||||
| cb(dt_float16) \ | cb(dt_float16) \ | ||||
| cb(dt_bfloat16) \ | cb(dt_bfloat16) \ | ||||
| cb(dt_bool) \ | |||||
| #define MEGDNN_FOREACH_QUANTIZED_CTYPE(cb) \ | #define MEGDNN_FOREACH_QUANTIZED_CTYPE(cb) \ | ||||
| cb(dt_quint8) \ | cb(dt_quint8) \ | ||||
| @@ -138,7 +138,8 @@ void do_cvt_s8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| dctype* __restrict dptr = dst.ptr<dctype>(); | dctype* __restrict dptr = dst.ptr<dctype>(); | ||||
| float scale = src.layout.dtype.param<dtype::QuantizedS8>().scale; | float scale = src.layout.dtype.param<dtype::QuantizedS8>().scale; | ||||
| for (size_t i = 0; i < n; ++i) { | for (size_t i = 0; i < n; ++i) { | ||||
| dptr[i] = static_cast<dctype>(sptr[i] * scale); | |||||
| auto val = sptr[i] * scale; | |||||
| dptr[i] = static_cast<dctype>(val); | |||||
| } | } | ||||
| } | } | ||||
| @@ -150,7 +151,8 @@ void do_cvt_s32_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| dctype* __restrict dptr = dst.ptr<dctype>(); | dctype* __restrict dptr = dst.ptr<dctype>(); | ||||
| float scale = src.layout.dtype.param<dtype::QuantizedS32>().scale; | float scale = src.layout.dtype.param<dtype::QuantizedS32>().scale; | ||||
| for (size_t i = 0; i < n; ++i) { | for (size_t i = 0; i < n; ++i) { | ||||
| dptr[i] = static_cast<dctype>(sptr[i] * scale); | |||||
| auto val = sptr[i] * scale; | |||||
| dptr[i] = static_cast<dctype>(val); | |||||
| } | } | ||||
| } | } | ||||
| @@ -163,7 +165,8 @@ void do_cvt_asymm8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| float scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale; | float scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale; | ||||
| uint8_t zp = src.layout.dtype.param<dtype::Quantized8Asymm>().zero_point; | uint8_t zp = src.layout.dtype.param<dtype::Quantized8Asymm>().zero_point; | ||||
| for (size_t i = 0; i < n; ++i) { | for (size_t i = 0; i < n; ++i) { | ||||
| dptr[i] = static_cast<dctype>((sptr[i] - zp) * scale); | |||||
| auto val = (sptr[i] - zp) * scale; | |||||
| dptr[i] = static_cast<dctype>(val); | |||||
| } | } | ||||
| } | } | ||||
| @@ -310,6 +313,7 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| break; \ | break; \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| case DTypeEnum::QuantizedS8: | case DTypeEnum::QuantizedS8: | ||||
| MIDOUT_BEGIN(megdnn_fb_typecvt_src_dtype, | MIDOUT_BEGIN(megdnn_fb_typecvt_src_dtype, | ||||
| midout_iv(DTypeEnum::QuantizedS8)) { | midout_iv(DTypeEnum::QuantizedS8)) { | ||||
| @@ -467,6 +471,7 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| case DTypeEnum::QuantizedS8: | case DTypeEnum::QuantizedS8: | ||||
| MIDOUT_BEGIN(megdnn_fb_typecvt_dst_dtype, | MIDOUT_BEGIN(megdnn_fb_typecvt_dst_dtype, | ||||
| midout_iv(DTypeEnum::QuantizedS8)) { | midout_iv(DTypeEnum::QuantizedS8)) { | ||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,15 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| // generated by gen_elemwise_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_CTYPE dt_bool | |||||
| #include "../kern_impl.inl" | |||||
| @@ -82,6 +82,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| #undef cb | #undef cb | ||||
| default: | default: | ||||
| megdnn_throw("bad dtype"); | megdnn_throw("bad dtype"); | ||||
| @@ -103,6 +104,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| #undef cb | #undef cb | ||||
| default: | default: | ||||
| megdnn_throw("bad dtype"); | megdnn_throw("bad dtype"); | ||||
| @@ -942,6 +942,8 @@ TEST(TEST_ELEMWISE, MODE_TRAIT) { | |||||
| ASSERT_TRUE(T::from_mode(M::RMULH).commutable); | ASSERT_TRUE(T::from_mode(M::RMULH).commutable); | ||||
| ASSERT_FALSE(T::from_mode(M::RMULH).allow_float); | ASSERT_FALSE(T::from_mode(M::RMULH).allow_float); | ||||
| ASSERT_TRUE(T::from_mode(M::XOR).allow_bool); | |||||
| } | } | ||||
| } // namespace elemwise | } // namespace elemwise | ||||
| @@ -916,6 +916,7 @@ SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) { | |||||
| case DTypeEnum::QuantizedS4: | case DTypeEnum::QuantizedS4: | ||||
| case DTypeEnum::Byte: | case DTypeEnum::Byte: | ||||
| case DTypeEnum::QuantizedS16: | case DTypeEnum::QuantizedS16: | ||||
| case DTypeEnum::Bool: | |||||
| break; | break; | ||||
| #define cb(low_bit, size) \ | #define cb(low_bit, size) \ | ||||
| case DTypeEnum::low_bit##size: \ | case DTypeEnum::low_bit##size: \ | ||||
| @@ -27,6 +27,7 @@ using ::megdnn::dt_int32; | |||||
| using ::megdnn::dt_quint8; | using ::megdnn::dt_quint8; | ||||
| using ::megdnn::dt_qint8; | using ::megdnn::dt_qint8; | ||||
| using ::megdnn::dt_qint32; | using ::megdnn::dt_qint32; | ||||
| using ::megdnn::dt_bool; | |||||
| using ::megdnn::DType; | using ::megdnn::DType; | ||||
| using ::megdnn::DTypeEnum; | using ::megdnn::DTypeEnum; | ||||
| using ::megdnn::DTypeTrait; | using ::megdnn::DTypeTrait; | ||||
| @@ -145,9 +145,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { | |||||
| 0.f}) / | 0.f}) / | ||||
| 6.f), | 6.f), | ||||
| }; | }; | ||||
| mgb_assert(map.size() + 8 == opr::Elemwise::Param::MODE_NR_MEMBER); | |||||
| mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER); | |||||
| // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, | // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, | ||||
| // ERFINV, ERFCINV | |||||
| // ERFINV, ERFCINV, NOT, AND, OR, XOR | |||||
| return map; | return map; | ||||
| #undef ADD_OPR | #undef ADD_OPR | ||||
| } | } | ||||
| @@ -193,6 +193,14 @@ Halide::Expr dispatch_elemwise_mode( | |||||
| return Halide::round(inp(0)); | return Halide::round(inp(0)); | ||||
| case Mode::RMULH: | case Mode::RMULH: | ||||
| return (inp(0) * inp(1)) >> Halide::popcount(inp(0)); | return (inp(0) * inp(1)) >> Halide::popcount(inp(0)); | ||||
| case Mode::NOT: | |||||
| return cv(1) - cv(inp(0) != cv(0)); | |||||
| case Mode::AND: | |||||
| return cv(inp(0) != cv(0)) * cv(inp(1) != cv(0)); | |||||
| case Mode::OR: | |||||
| return cv(cv(inp(0) != cv(0)) + cv(inp(1) != cv(0)) > cv(0)); | |||||
| case Mode::XOR: | |||||
| return cv(cv(inp(0) != cv(0)) + cv(inp(1) != cv(0)) == cv(1)); | |||||
| default: | default: | ||||
| mgb_throw(InternalError, "unsupported Elemwise mode(%d)", | mgb_throw(InternalError, "unsupported Elemwise mode(%d)", | ||||
| static_cast<int>(mode)); | static_cast<int>(mode)); | ||||
| @@ -631,6 +631,8 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
| RET(EL2(H_SWISH_GRAD, i0, og)); | RET(EL2(H_SWISH_GRAD, i0, og)); | ||||
| case Mode::FUSE_ADD_H_SWISH: | case Mode::FUSE_ADD_H_SWISH: | ||||
| RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); | RET(EL2(H_SWISH_GRAD, (i0 + i1), og)); | ||||
| case Mode::NOT: | |||||
| return nullptr; | |||||
| // binary | // binary | ||||
| case Mode::ABS_GRAD: | case Mode::ABS_GRAD: | ||||
| @@ -693,6 +695,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
| return nullptr; | return nullptr; | ||||
| case Mode::EQ: | case Mode::EQ: | ||||
| RET_INVALID(); | RET_INVALID(); | ||||
| case Mode::OR: | |||||
| case Mode::XOR: | |||||
| case Mode::AND: | |||||
| return nullptr; | |||||
| // ternary | // ternary | ||||
| case Mode::COND_LEQ_MOV: | case Mode::COND_LEQ_MOV: | ||||
| @@ -408,6 +408,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { | |||||
| break; | break; | ||||
| case DTypeEnum::UintB4: | case DTypeEnum::UintB4: | ||||
| break; | break; | ||||
| case DTypeEnum::Bool: | |||||
| break; | |||||
| #define cb(x) case DTypeEnum::x: break; | #define cb(x) case DTypeEnum::x: break; | ||||
| MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) | MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) | ||||
| @@ -247,6 +247,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr, | |||||
| break; | break; | ||||
| case DTypeEnum::UintB4: | case DTypeEnum::UintB4: | ||||
| break; | break; | ||||
| case DTypeEnum::Bool: | |||||
| break; | |||||
| #define cb(_dt) \ | #define cb(_dt) \ | ||||
| case DTypeEnum::_dt: \ | case DTypeEnum::_dt: \ | ||||
| break; | break; | ||||
| @@ -32,6 +32,7 @@ namespace opr { | |||||
| EL1(exp, EXP) | EL1(exp, EXP) | ||||
| EL1(log, LOG) | EL1(log, LOG) | ||||
| EL1(abs, ABS) | EL1(abs, ABS) | ||||
| EL1(not_, NOT) | |||||
| #undef EL1 | #undef EL1 | ||||
| @@ -53,6 +54,9 @@ namespace opr { | |||||
| EL2(min, MIN) | EL2(min, MIN) | ||||
| EL2(switch_gt0, SWITCH_GT0) | EL2(switch_gt0, SWITCH_GT0) | ||||
| EL2(eq, EQ) | EL2(eq, EQ) | ||||
| EL2(and_, AND) | |||||
| EL2(or_, OR) | |||||
| EL2(xor_, XOR) | |||||
| #undef EL2 | #undef EL2 | ||||
| @@ -206,6 +206,7 @@ namespace { | |||||
| static constexpr Mode MODE = Mode::_mode; \ | static constexpr Mode MODE = Mode::_mode; \ | ||||
| static constexpr bool ALLOW_INT = _ALLOW_INT; \ | static constexpr bool ALLOW_INT = _ALLOW_INT; \ | ||||
| static constexpr bool ALLOW_FLOAT = _ALLOW_FLOAT; \ | static constexpr bool ALLOW_FLOAT = _ALLOW_FLOAT; \ | ||||
| static constexpr bool ALLOW_BOOL = _ALLOW_BOOL; \ | |||||
| static constexpr const char* NAME = #_mode; \ | static constexpr const char* NAME = #_mode; \ | ||||
| template<typename ctype> \ | template<typename ctype> \ | ||||
| static inline ctype apply( \ | static inline ctype apply( \ | ||||
| @@ -588,6 +589,14 @@ namespace { | |||||
| struct enable_for_dtype_impl<dtype::Int32, void> { | struct enable_for_dtype_impl<dtype::Int32, void> { | ||||
| static constexpr bool value = false; | static constexpr bool value = false; | ||||
| }; | }; | ||||
| template<class Trait> /* Bool: the test is enabled exactly when the mode trait declares ALLOW_BOOL */ | |||||
| struct enable_for_dtype_impl<dtype::Bool, Trait> { | |||||
| static constexpr bool value = Trait::ALLOW_BOOL; | |||||
| }; | |||||
| template<> /* Bool with a void trait: always disabled */ | |||||
| struct enable_for_dtype_impl<dtype::Bool, void> { | |||||
| static constexpr bool value = false; | |||||
| }; | |||||
| } | } | ||||
| //! whether to enable test for specific dtype and Trait | //! whether to enable test for specific dtype and Trait | ||||
| @@ -749,8 +758,60 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) { | |||||
| TEST(TestOprBasicArithElemwise, CheckAllModeTested) { | TEST(TestOprBasicArithElemwise, CheckAllModeTested) { | ||||
| size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER; | size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER; | ||||
| ASSERT_EQ(nr_member, tested_mode.size()); | |||||
| ASSERT_EQ(nr_member, tested_mode.size() + 4); | |||||
| // Not using TestRunner: NOT, AND, OR, XOR | |||||
| } | } | ||||
| // Defines a test case named `_mode` that applies the unary Elemwise mode | |||||
| // `_mode` to a {2, 1} Bool tensor holding {false, true} and checks every | |||||
| // output element against the host-side expression `_op input`. | |||||
| // The last line of the macro deliberately has no trailing backslash, so the | |||||
| // invocation that follows can never be folded into the macro body. | |||||
| #define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \ | |||||
| TEST(TestOprBasicArithElemwise, _mode) { \ | |||||
| HostTensorGenerator<dtype::Bool> gen; \ | |||||
| auto host_x = gen({2, 1}); \ | |||||
| auto ptr = host_x->ptr<dt_bool>(); \ | |||||
| /* overwrite the generated data with a known pattern: false, true */ \ | |||||
| for (size_t i = 0; i < 2; ++i) { \ | |||||
| ptr[i] = (i & 1); \ | |||||
| } \ | |||||
| auto graph = ComputingGraph::make(); \ | |||||
| using Mode = opr::Elemwise::Mode; \ | |||||
| auto x = opr::Host2DeviceCopy::make(*graph, host_x), \ | |||||
| y = opr::Elemwise::make({x}, Mode::_mode); \ | |||||
| HostTensorND host_y; \ | |||||
| auto func = graph->compile({make_callback_copy(y, host_y)}); \ | |||||
| func->execute(); \ | |||||
| ASSERT_EQ(TensorShape({2, 1}), host_y.shape()); \ | |||||
| auto ptry = host_y.ptr<dt_bool>(); \ | |||||
| /* element-wise comparison against the host-side reference */ \ | |||||
| for (size_t i = 0; i < 2; ++i) { \ | |||||
| ASSERT_EQ(_op ptr[i], ptry[i]); \ | |||||
| } \ | |||||
| } | |||||
| TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !) | |||||
| // Defines a test case named `_mode` that applies the binary Elemwise mode | |||||
| // `_mode` to two {2, 2} Bool tensors and checks every output element against | |||||
| // the host-side expression `lhs _op rhs`. | |||||
| // The last line of the macro deliberately has no trailing backslash, so the | |||||
| // invocations that follow can never be folded into the macro body. | |||||
| #define TEST_OPR_BASIC_ARITH_BINARY_BOOL(_mode, _op) \ | |||||
| TEST(TestOprBasicArithElemwise, _mode) { \ | |||||
| HostTensorGenerator<dtype::Bool> gen; \ | |||||
| auto host_x1 = gen({2, 2}), host_x2 = gen({2, 2}); \ | |||||
| auto ptr1 = host_x1->ptr<dt_bool>(), ptr2 = host_x2->ptr<dt_bool>(); \ | |||||
| /* x1 = {T, T, F, F}, x2 = {F, T, F, T}: covers all four input pairs */ \ | |||||
| for (size_t i = 0; i < 4; ++i) { \ | |||||
| ptr1[i] = (i < 2); \ | |||||
| ptr2[i] = (i & 1); \ | |||||
| } \ | |||||
| auto graph = ComputingGraph::make(); \ | |||||
| using Mode = opr::Elemwise::Mode; \ | |||||
| auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1), \ | |||||
| x2 = opr::Host2DeviceCopy::make(*graph, host_x2), \ | |||||
| y = opr::Elemwise::make({x1, x2}, Mode::_mode); \ | |||||
| HostTensorND host_y; \ | |||||
| auto func = graph->compile({make_callback_copy(y, host_y)}); \ | |||||
| func->execute(); \ | |||||
| ASSERT_EQ(TensorShape({2, 2}), host_y.shape()); \ | |||||
| auto ptry = host_y.ptr<dt_bool>(); \ | |||||
| /* element-wise comparison against the host-side reference */ \ | |||||
| for (size_t i = 0; i < 4; ++i) { \ | |||||
| ASSERT_EQ(ptr1[i] _op ptr2[i], ptry[i]); \ | |||||
| } \ | |||||
| } | |||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&) | |||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||) | |||||
| TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^) | |||||
| TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { | TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) { | ||||
| using Checker = AutoOprChecker<3, 1>; | using Checker = AutoOprChecker<3, 1>; | ||||
| @@ -19,6 +19,17 @@ | |||||
| ctype x = inp[0][idx]; \ | ctype x = inp[0][idx]; \ | ||||
| ctype y = inp[1][idx] | ctype y = inp[1][idx] | ||||
| #define _ALLOW_BOOL true | |||||
| #define _ALLOW_FLOAT false | |||||
| #define _ALLOW_INT false | |||||
| DEF_TRAIT(AND, x && y) | |||||
| DEF_TRAIT(OR, x || y) | |||||
| DEF_TRAIT(XOR, x ^ y) | |||||
| #undef _ALLOW_INT | |||||
| #undef _ALLOW_FLOAT | |||||
| #undef _ALLOW_BOOL | |||||
| #define _ALLOW_BOOL false | |||||
| #define _ALLOW_FLOAT true | #define _ALLOW_FLOAT true | ||||
| #define _ALLOW_INT true | #define _ALLOW_INT true | ||||
| DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) | DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) | ||||
| @@ -60,6 +71,7 @@ DEF_TRAIT(SHR, do_shr(x, y)) | |||||
| DEF_TRAIT(RMULH, do_round_mulh_saturate(x, y)) | DEF_TRAIT(RMULH, do_round_mulh_saturate(x, y)) | ||||
| #undef _ALLOW_INT | #undef _ALLOW_INT | ||||
| #undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
| #undef _ALLOW_BOOL | |||||
| #undef _CUR_ARITY | #undef _CUR_ARITY | ||||
| #undef _EXPAND_PARAMS | #undef _EXPAND_PARAMS | ||||
| @@ -20,6 +20,7 @@ | |||||
| ctype y = inp[1][idx]; \ | ctype y = inp[1][idx]; \ | ||||
| ctype z = inp[2][idx] | ctype z = inp[2][idx] | ||||
| #define _ALLOW_BOOL false | |||||
| #define _ALLOW_FLOAT true | #define _ALLOW_FLOAT true | ||||
| #define _ALLOW_INT true | #define _ALLOW_INT true | ||||
| DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) | DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) | ||||
| @@ -46,5 +47,6 @@ DEF_TRAIT(FUSE_MUL_ADD4, i0 * i1 + i2 * i3) | |||||
| #undef _CUR_ARITY | #undef _CUR_ARITY | ||||
| #undef _EXPAND_PARAMS | #undef _EXPAND_PARAMS | ||||
| #undef _ALLOW_BOOL | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -18,6 +18,15 @@ | |||||
| #define _EXPAND_PARAMS \ | #define _EXPAND_PARAMS \ | ||||
| ctype x = inp[0][idx] | ctype x = inp[0][idx] | ||||
| #define _ALLOW_BOOL true | |||||
| #define _ALLOW_FLOAT false | |||||
| #define _ALLOW_INT false | |||||
| DEF_TRAIT(NOT, !x) | |||||
| #undef _ALLOW_INT | |||||
| #undef _ALLOW_FLOAT | |||||
| #undef _ALLOW_BOOL | |||||
| #define _ALLOW_BOOL false | |||||
| #define _ALLOW_FLOAT true | #define _ALLOW_FLOAT true | ||||
| @@ -51,6 +60,8 @@ DEF_TRAIT(H_SWISH, do_h_swish(x)) | |||||
| #undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
| #undef _ALLOW_BOOL | |||||
| #undef _CUR_ARITY | #undef _CUR_ARITY | ||||
| #undef _EXPAND_PARAMS | #undef _EXPAND_PARAMS | ||||
| @@ -21,6 +21,7 @@ enum DTypeEnum : byte { | |||||
| QuantizedS4, | QuantizedS4, | ||||
| QuantizedS16, | QuantizedS16, | ||||
| BFloat16, | BFloat16, | ||||
| Bool, | |||||
| } | } | ||||
| table LinearQuantizationParam { | table LinearQuantizationParam { | ||||
| @@ -140,6 +140,21 @@ namespace mgb { | |||||
| dtype::Int32, RandomDistribution::UNIFORM>; | dtype::Int32, RandomDistribution::UNIFORM>; | ||||
| template class HostTensorGenerator< | template class HostTensorGenerator< | ||||
| dtype::Int32, RandomDistribution::CONSTANT>; | dtype::Int32, RandomDistribution::CONSTANT>; | ||||
| /* Bool specialization of the UNIFORM generator: allocates a Bool host tensor of `shape` on `cn` and fills it. */ std::shared_ptr<HostTensorND> | |||||
| HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM>:: | |||||
| operator()(const TensorShape& shape, CompNode cn) { | |||||
| if (!cn.valid()) /* caller did not supply a usable comp node */ | |||||
| cn = CompNode::load("xpu0"); /* fall back to the comp node named "xpu0" */ | |||||
| auto dtype = dtype::Bool(); | |||||
| std::shared_ptr<HostTensorND> ret = | |||||
| std::make_shared<HostTensorND>(cn, shape, dtype); | |||||
| auto ptr = ret->ptr<dt_bool>(); | |||||
| for (size_t i = 0, it = shape.total_nr_elems(); i < it; ++i) { | |||||
| ptr[i] = (i % 2 == 1); /* fixed pattern: false at even flat indices, true at odd ones; no random value is drawn in this function */ | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| std::shared_ptr<HostTensorND> | std::shared_ptr<HostTensorND> | ||||
| HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM>:: | HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM>:: | ||||
| operator()(const TensorShape& shape, CompNode cn) { | operator()(const TensorShape& shape, CompNode cn) { | ||||
| @@ -202,6 +202,10 @@ struct RandomDistributionDTypeDefault<dtype::Int32> { | |||||
| static constexpr auto dist = RandomDistribution::UNIFORM; | static constexpr auto dist = RandomDistribution::UNIFORM; | ||||
| }; | }; | ||||
| template<> | template<> | ||||
| struct RandomDistributionDTypeDefault<dtype::Bool> { /* default distribution tag used for Bool tensors */ | |||||
| static constexpr auto dist = RandomDistribution::UNIFORM; | |||||
| }; | |||||
| template<> | |||||
| struct RandomDistributionDTypeDefault<dtype::QuantizedS8> { | struct RandomDistributionDTypeDefault<dtype::QuantizedS8> { | ||||
| static constexpr auto dist = RandomDistribution::UNIFORM; | static constexpr auto dist = RandomDistribution::UNIFORM; | ||||
| }; | }; | ||||
| @@ -251,6 +255,10 @@ struct UniformRNGDefaultRange<dtype::Uint8> { | |||||
| static constexpr dt_uint8 LO = 0, HI = 255; | static constexpr dt_uint8 LO = 0, HI = 255; | ||||
| }; | }; | ||||
| template<> | template<> | ||||
| struct UniformRNGDefaultRange<dtype::Bool> { /* default bounds for Bool: both possible values */ | |||||
| static constexpr dt_bool LO = false, HI = true; | |||||
| }; | |||||
| template<> | |||||
| struct UniformRNGDefaultRange<dtype::Int16> { | struct UniformRNGDefaultRange<dtype::Int16> { | ||||
| static constexpr dt_int16 LO = -32767, HI = 32767; | static constexpr dt_int16 LO = -32767, HI = 32767; | ||||
| }; | }; | ||||
| @@ -341,6 +349,20 @@ class HostTensorGenerator<dtype, RandomDistribution::CONSTANT> final: | |||||
| private: | private: | ||||
| ctype m_default_val; | ctype m_default_val; | ||||
| }; | }; | ||||
| template <> /* explicit specialization of the host tensor generator for Bool under the UNIFORM distribution tag */ | |||||
| class HostTensorGenerator<dtype::Bool, RandomDistribution::UNIFORM> final | |||||
| : public HostTensorGeneratorBase { | |||||
| public: | |||||
| using ctype = typename DTypeTrait<dtype::Bool>::ctype; | |||||
| HostTensorGenerator(uint64_t seed = next_rand_seed()) /* the seed is only forwarded to the base class here */ | |||||
| : HostTensorGeneratorBase{seed} {} | |||||
| std::shared_ptr<HostTensorND> operator()(const TensorShape& shape, | |||||
| CompNode cn = {}) override; /* declaration only; the body is defined outside the class */ | |||||
| using HostTensorGeneratorBase::operator(); /* re-expose the base-class operator() overloads that the override above would otherwise hide */ | |||||
| }; | |||||
| template <> | template <> | ||||
| class HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM> final | class HostTensorGenerator<dtype::QuantizedS8, RandomDistribution::UNIFORM> final | ||||