| @@ -6,6 +6,7 @@ dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | |||||
| dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary | dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary | ||||
| dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | ||||
| dnn/src/cuda/convolution/backward_data/int8/kimpl/* binary | dnn/src/cuda/convolution/backward_data/int8/kimpl/* binary | ||||
| dnn/src/cuda/elemwise_multi_type/kimpl/* binary | |||||
| tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text | tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text | ||||
| imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -text | imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -text | ||||
| ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text | ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text | ||||
| @@ -382,6 +382,9 @@ struct TensorLayout : public TensorShape { | |||||
| //! get lowest and highest offset reachable from this layout | //! get lowest and highest offset reachable from this layout | ||||
| Span span() const; | Span span() const; | ||||
    //! total number of bytes accessed, rounding low-bit tensors up to whole bytes
| size_t access_bytes() const; | |||||
| }; | }; | ||||
| /** | /** | ||||
| @@ -308,6 +308,8 @@ class dt_qulowbit { | |||||
| return _; | return _; | ||||
| } | } | ||||
| MEGDNN_DEVICE uint8_t as_storage() const { return _; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {} | MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {} | ||||
| #ifdef MEGDNN_CC_HOST | #ifdef MEGDNN_CC_HOST | ||||
| explicit operator uint8_t() { return _; } | explicit operator uint8_t() { return _; } | ||||
| @@ -332,6 +334,8 @@ class dt_qlowbit { | |||||
| return _; | return _; | ||||
| } | } | ||||
| MEGDNN_DEVICE int8_t as_storage() const { return _; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {} | MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {} | ||||
| #ifdef MEGDNN_CC_HOST | #ifdef MEGDNN_CC_HOST | ||||
| explicit operator int8_t() { return _; } | explicit operator int8_t() { return _; } | ||||
| @@ -1,6 +1,10 @@ | |||||
# CUDA currently does not support quint8, so we just ignore it.
| SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | ||||
| SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32')] | |||||
| SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), | |||||
| ('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')] | |||||
| SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] | |||||
| SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] | |||||
| MODES = { | MODES = { | ||||
| 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | ||||
| @@ -16,6 +20,15 @@ MODES = { | |||||
| 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
| } | } | ||||
| QINT4_MODES = { | |||||
| 1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | |||||
| 'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | |||||
| 2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
| 'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
| 'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | |||||
| 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
| } | |||||
| QINT32_MODES = { | QINT32_MODES = { | ||||
| 1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | 1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | ||||
| 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | ||||
| @@ -212,7 +212,7 @@ TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, | |||||
| TensorLayout::TensorLayout(const TensorShape& shape, | TensorLayout::TensorLayout(const TensorShape& shape, | ||||
| const std::vector<ptrdiff_t>& stride, DType dtype) | const std::vector<ptrdiff_t>& stride, DType dtype) | ||||
| : TensorLayout(shape, stride, dtype, DefaultTensorFormat::make()) {} | |||||
| : TensorLayout(shape, stride, dtype, Format(dtype)) {} | |||||
| TensorLayout::TensorLayout(const TensorShape& shape, | TensorLayout::TensorLayout(const TensorShape& shape, | ||||
| const std::vector<ptrdiff_t>& stride, DType dtype, | const std::vector<ptrdiff_t>& stride, DType dtype, | ||||
| @@ -412,6 +412,27 @@ TensorLayout::Span TensorLayout::span() const { | |||||
| return format.impl()->span_spec(*this); | return format.impl()->span_spec(*this); | ||||
| } | } | ||||
| size_t TensorLayout::access_bytes() const { | |||||
| megdnn_assert(dtype.valid()); | |||||
| auto contig = collapse_contiguous(); | |||||
| size_t ret = 0; | |||||
| if (dtype.is_low_bit()) { | |||||
| ret = 1; | |||||
| int align_size_in_elements = 8 / dtype.low_bit(); | |||||
| for (size_t i = 0; i < contig.ndim; ++i) { | |||||
| if (contig.stride[i] == 1) { | |||||
| ret *= round_up((int)contig.shape[i], align_size_in_elements); | |||||
| } else { | |||||
| ret *= contig.shape[i]; | |||||
| } | |||||
| } | |||||
| ret /= align_size_in_elements; | |||||
| } else { | |||||
| ret = dtype.size(total_nr_elems()); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
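The rounding above charges a low-bit tensor a whole number of bytes per contiguous innermost row. A minimal standalone sketch of the same arithmetic (hypothetical helper names, not the MegDNN API):

#include <cstddef>
#include <cstdio>

static size_t round_up(size_t x, size_t align) {
    return (x + align - 1) / align * align;
}

// bytes accessed by a contiguous 2-D 4-bit tensor: the innermost dimension is
// padded to an even element count, and two elements share one byte
static size_t access_bytes_4bit(size_t n, size_t c) {
    const size_t align_elems = 8 / 4;  // elements per byte for 4-bit storage
    return n * round_up(c, align_elems) / align_elems;
}

int main() {
    // a (2, 3) int4 tensor: each row is padded to 4 nibbles -> 2 bytes per row
    std::printf("%zu\n", access_bytes_4bit(2, 3));  // prints 4
}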
| TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | ||||
| megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error, | megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error, | ||||
| "broadcast involves empty tensor"); | "broadcast involves empty tensor"); | ||||
| @@ -236,33 +236,66 @@ INST(dt_qint8); | |||||
| INST(dt_quint8); | INST(dt_quint8); | ||||
| #undef dt_ibyte | #undef dt_ibyte | ||||
| template <int ndim> | |||||
| void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( | |||||
| const TensorND& rv, int /*grid_size*/, int /*block_size*/) { | |||||
| m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr); | |||||
| for (size_t i = 0; i < rv.layout.ndim; ++i) { | |||||
| m_stride[i] = rv.layout.stride[i]; | |||||
| m_shape[i] = rv.layout.shape[i]; | |||||
| if (i + 1 < rv.layout.ndim) { | |||||
| m_shape_highdim[i] = rv.layout.shape[i + 1]; | |||||
| if (rv.layout.stride[i + 1] == 1) | |||||
| m_align_shape_highdim[i] = | |||||
| (uint32_t)round_up((int)rv.layout.shape[i + 1], 2); | |||||
| else | |||||
| m_align_shape_highdim[i] = rv.layout.shape[i + 1]; | |||||
| } | |||||
| } | |||||
| for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) { | |||||
| m_shape_highdim[i] = 1; | |||||
| m_align_shape_highdim[i] = 1; | |||||
| } | |||||
| for (size_t i = rv.layout.ndim; i < ndim; ++i) { | |||||
| m_stride[i] = 0; | |||||
| m_shape[i] = 1; | |||||
| } | |||||
| m_is_physical_contiguous = rv.layout.is_physical_contiguous(); | |||||
| } | |||||
| #define ndim_cb(_ndim) \ | |||||
| template class ParamElemVisitor4bitBase<_ndim, BCAST_OTHER>; | |||||
| MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) | |||||
| #undef ndim_cb | |||||
| } // namespace elemwise_intl | } // namespace elemwise_intl | ||||
void elemwise_intl::get_launch_spec(const void* kern, size_t size,
                                    int* grid_size, int* block_size) {
    safe_size_in_kern(size);
    auto config = query_launch_config_for_kernel(kern);
    *block_size = config.block_size;
    int a = size / (config.block_size * 2),
        b = (size - 1) / (config.block_size * 3) + 1;
    if (current_device_prop().major <= 3) {
        // for Kepler, less blocks (more work per thread) is faster
        *grid_size = b;
    } else {
        *grid_size = std::max(a, b);
    }
    if (!*grid_size) {
        *block_size = std::min<int>(std::max<int>(size / 64, 1) * 32, 1024);
        *grid_size = std::max<int>(size / *block_size, 1);
    }
    // because we unroll 3 times in the kernel
    megdnn_assert(static_cast<size_t>(*block_size) * *grid_size * 3 >= size);
}

void elemwise_intl::on_bad_ndim(int ndim) {
    megdnn_throw(ssprintf("invalid ndim: %d", ndim));
    MEGDNN_MARK_USED_VAR(ndim);
}
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -115,6 +115,34 @@ INST(dt_qint32, int4); | |||||
| #undef as_raw | #undef as_raw | ||||
| #undef INST | #undef INST | ||||
| struct int4bx2 { | |||||
| int8_t x; | |||||
| }; | |||||
| struct uint4bx2 { | |||||
| uint8_t x; | |||||
| }; | |||||
| #define INST(_ctype, _Storage, _vect_type) \ | |||||
| template <> \ | |||||
| class VectTypeTrait<_ctype> { \ | |||||
| public: \ | |||||
| using Storage = _Storage; \ | |||||
| static const Storage kMask = 0xf; \ | |||||
| static const Storage kBits = 4; \ | |||||
| using vect_type = _vect_type; \ | |||||
| static const size_t packed_size = 2; \ | |||||
| static __device__ __forceinline__ vect_type make_vector(Storage x, \ | |||||
| Storage y) { \ | |||||
| vect_type t; \ | |||||
| t.x = (x & kMask) | (y << kBits); \ | |||||
| return t; \ | |||||
| } \ | |||||
| } | |||||
| INST(dt_qint4, int8_t, int4bx2); | |||||
| INST(dt_quint4, uint8_t, uint4bx2); | |||||
| #undef INST | |||||
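make_vector above packs two quantized values into one byte: the first argument occupies the low nibble and the second the high nibble. A standalone sketch of that bit layout, with the signed unpacking written out under the assumption that unpack_integer_4bits sign-extends the selected nibble (plain C++, for illustration only):

#include <cstdint>
#include <cstdio>

static uint8_t pack_q4(int8_t x, int8_t y) {
    // low nibble = x, high nibble = y, matching (x & kMask) | (y << kBits)
    return static_cast<uint8_t>((x & 0xf) | (y << 4));
}

static int8_t unpack_q4_signed(uint8_t v, int bit_offset) {
    // shift the selected nibble into the top bits, then arithmetic-shift back
    return static_cast<int8_t>(static_cast<int8_t>(v << (4 - bit_offset)) >> 4);
}

int main() {
    uint8_t packed = pack_q4(-3, 5);  // two dt_qint4 values in one byte (0x5d)
    std::printf("%d %d\n", unpack_q4_signed(packed, 0),  // -3
                unpack_q4_signed(packed, 4));            // 5
}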
| /*! | /*! | ||||
 * \brief visitor to access an element in a tensor at given logical index
| * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait) | * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait) | ||||
| @@ -217,6 +245,7 @@ template <int ndim, typename ctype> | |||||
| class ParamElemVisitor<ndim, ctype, BCAST_OTHER> | class ParamElemVisitor<ndim, ctype, BCAST_OTHER> | ||||
| : public ParamVisitorBase<ndim, ctype, BCAST_OTHER> { | : public ParamVisitorBase<ndim, ctype, BCAST_OTHER> { | ||||
| public: | public: | ||||
| using CType = ctype; | |||||
| PARAM_ELEM_VISITOR_COMMON_HOST | PARAM_ELEM_VISITOR_COMMON_HOST | ||||
| void host_init(const TensorND& rv, int grid_size, int block_size) { | void host_init(const TensorND& rv, int grid_size, int block_size) { | ||||
| @@ -500,6 +529,177 @@ public: | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| template <int ndim, BcastType brd_type> | |||||
| class ParamElemVisitor4bitBase; | |||||
| template <int ndim> | |||||
| class ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { | |||||
| using Storage = int8_t; | |||||
| protected: | |||||
| Storage* __restrict m_ptr; | |||||
| int m_stride[ndim]; | |||||
| int m_shape[ndim]; | |||||
| bool m_is_physical_contiguous; | |||||
| //! m_shape_highdim[i] = original_shape[i + 1] | |||||
| #ifdef _MSC_VER | |||||
| Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1]; | |||||
| Uint32Fastdiv m_align_shape_highdim[ndim > 1 ? ndim - 1 : 1]; | |||||
| #else | |||||
| Uint32Fastdiv m_shape_highdim[ndim]; | |||||
| Uint32Fastdiv m_align_shape_highdim[ndim]; | |||||
| #endif | |||||
| public: | |||||
| static const Storage kMask = 0xf; | |||||
| static const Storage kBits = 4; | |||||
| static const int NDIM = ndim; | |||||
| void host_init(const TensorND& rv, int grid_size, int block_size); | |||||
| #if MEGDNN_CC_CUDA | |||||
| devfunc void thread_init(uint32_t) {} | |||||
| devfunc void next() {} | |||||
| devfunc void get_shape_from_access(uint32_t access_idx, | |||||
| int (&shape_idx)[ndim]) { | |||||
| #pragma unroll | |||||
| for (int i = ndim - 1; i >= 1; --i) { | |||||
| Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1]; | |||||
| uint32_t access_idx_div = access_idx / align_shp; | |||||
| shape_idx[i] = access_idx - access_idx_div * align_shp.divisor(); | |||||
| access_idx = access_idx_div; | |||||
| } | |||||
| shape_idx[0] = access_idx; | |||||
| } | |||||
| devfunc int offset(uint32_t idx) { | |||||
| int offset = 0; | |||||
| #pragma unroll | |||||
| for (int i = ndim - 1; i >= 1; --i) { | |||||
| Uint32Fastdiv& shp = m_shape_highdim[i - 1]; | |||||
| uint32_t idx_div = idx / shp; | |||||
| offset += (idx - idx_div * shp.divisor()) * m_stride[i]; | |||||
| idx = idx_div; | |||||
| } | |||||
| offset += idx * m_stride[0]; | |||||
| return offset; | |||||
| } | |||||
| devfunc int idx(uint32_t access_idx) { | |||||
| int idx = 0; | |||||
| if (m_is_physical_contiguous) { | |||||
| idx = access_idx; | |||||
| } else { | |||||
| int shape_idx[ndim]; | |||||
| bool valid = true; | |||||
| get_shape_from_access(access_idx, shape_idx); | |||||
| #pragma unroll | |||||
| for (int i = 0; i < ndim; ++i) { | |||||
| valid &= (shape_idx[i] < m_shape[i]); | |||||
| } | |||||
| #pragma unroll | |||||
| for (int i = 0; i < ndim - 1; ++i) { | |||||
| idx = (idx + shape_idx[i]) * m_shape[i + 1]; | |||||
| } | |||||
| idx = valid ? idx + shape_idx[ndim - 1] : -1; | |||||
| } | |||||
| return idx; | |||||
| } | |||||
| devfunc Storage* ptr() { return m_ptr; } | |||||
| #endif | |||||
| }; | |||||
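The visitor distinguishes the byte-aligned access space (m_align_shape_highdim, with the contiguous innermost dimension rounded up to an even number of slots) from the logical index space (m_shape), and idx() returns -1 for padded slots. A host-side walk-through of that mapping for a hypothetical (2, 3) tensor padded to (2, 4) access slots:

#include <cstdio>

int main() {
    const int shape[2]       = {2, 3};  // logical shape
    const int align_shape[2] = {2, 4};  // innermost dim rounded up to even
    for (int access_idx = 0; access_idx < align_shape[0] * align_shape[1];
         ++access_idx) {
        int i = access_idx / align_shape[1];  // what get_shape_from_access does
        int j = access_idx % align_shape[1];
        bool valid = j < shape[1];            // padded slot -> invalid
        int idx = valid ? i * shape[1] + j : -1;
        std::printf("access %d -> logical %d\n", access_idx, idx);
    }
}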
| template <int ndim> | |||||
| class ParamElemVisitor<ndim, dt_qint4, BCAST_OTHER> | |||||
| : public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { | |||||
| using CType = dt_qint4; | |||||
| using Storage = int8_t; | |||||
| public: | |||||
| static const int packed_size = 1; | |||||
| using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; | |||||
| void host_init(const TensorND& rv, int grid_size, int block_size) { | |||||
| Super::host_init(rv, grid_size, block_size); | |||||
| } | |||||
| #if MEGDNN_CC_CUDA | |||||
    // cannot be used as an l-value; only reads are supported
| devfunc dt_qint4 at(uint32_t idx) { | |||||
| int offset_ = Super::offset(idx); | |||||
| int vec_idx = offset_ >> 1; | |||||
| int lane_idx = offset_ & 0x1; | |||||
| Storage item = Storage(unpack_integer_4bits<true>( | |||||
| *(Storage*)&Super::m_ptr[vec_idx], lane_idx * 4)); | |||||
| dt_qint4 result(item); | |||||
| return result; | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <int ndim> | |||||
| class ParamElemVisitor<ndim, dt_quint4, BCAST_OTHER> | |||||
| : public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { | |||||
| using CType = dt_quint4; | |||||
| using Storage = uint8_t; | |||||
| using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; | |||||
| public: | |||||
| static const int packed_size = 1; | |||||
| void host_init(const TensorND& rv, int grid_size, int block_size) { | |||||
| Super::host_init(rv, grid_size, block_size); | |||||
| } | |||||
| #if MEGDNN_CC_CUDA | |||||
    // cannot be used as an l-value; only reads are supported
| devfunc dt_quint4 at(uint32_t idx) { | |||||
| int offset_ = Super::offset(idx); | |||||
| int vec_idx = offset_ >> 1; | |||||
| int lane_idx = offset_ & 0x1; | |||||
| Storage item = Storage(unpack_integer_4bits<false>( | |||||
| *(Storage*)&Super::m_ptr[vec_idx], lane_idx * 4)); | |||||
| dt_quint4 result(item); | |||||
| return result; | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| #if MEGDNN_CC_CUDA | |||||
| #define DEVICE_WRAPPER(x) x | |||||
| #else | |||||
| #define DEVICE_WRAPPER(x) | |||||
| #endif | |||||
| #define INST_DT_IBYTE(ctype) \ | |||||
| template <int ndim> \ | |||||
| class ParamVectVisitor<ndim, ctype, BCAST_OTHER> \ | |||||
| : public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { \ | |||||
| public: \ | |||||
| using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; \ | |||||
| void host_init(const TensorND& rv, int grid_size, int block_size) { \ | |||||
| Super::host_init(rv, grid_size, block_size); \ | |||||
| } \ | |||||
| using rwtype = typename VectTypeTrait<ctype>::vect_type; \ | |||||
| static const int packed_size = VectTypeTrait<ctype>::packed_size; \ | |||||
| DEVICE_WRAPPER(devfunc rwtype& at(uint32_t access_idx) { \ | |||||
| return *(rwtype*)(&Super::m_ptr[access_idx]); \ | |||||
| }) \ | |||||
| }; | |||||
| INST_DT_IBYTE(dt_qint4); | |||||
| INST_DT_IBYTE(dt_quint4); | |||||
| #undef DEVICE_WRAPPER | |||||
| #undef INST_DT_IBYTE | |||||
| /* f}}} */ | /* f}}} */ | ||||
| #if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
| @@ -507,7 +707,8 @@ public: | |||||
| /* f{{{ user operator callers */ | /* f{{{ user operator callers */ | ||||
| /* | /* | ||||
| * OpCaller is used to invoke user operator with loaded element arguments. | |||||
| * OpCaller is used to invoke user operator with loaded element | |||||
| * arguments. | |||||
| * | * | ||||
| * device interface: | * device interface: | ||||
| * void thread_init(uint32_t idx); | * void thread_init(uint32_t idx); | ||||
| @@ -518,8 +719,8 @@ public: | |||||
| */ | */ | ||||
| /*! | /*! | ||||
| * \brief call user op directly without visiting any params (i.e. arity == | |||||
| * 0) | |||||
| * \brief call user op directly without visiting any params (i.e. arity | |||||
| * == 0) | |||||
| */ | */ | ||||
| template <class Op> | template <class Op> | ||||
| struct OpCallerNull { | struct OpCallerNull { | ||||
| @@ -1151,6 +1352,20 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| #define INST_DT_TYPE(ctype) \ | |||||
| template <class Op> \ | |||||
| class UserOpInvoker<Op, ctype, 2> \ | |||||
| : public UserOpInvokerToSameNdim<Op, ctype, 2> { \ | |||||
| public: \ | |||||
| UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, \ | |||||
| const Op& op) \ | |||||
| : UserOpInvokerToSameNdim<Op, ctype, 2>(param, stream, op) {} \ | |||||
| } | |||||
| INST_DT_TYPE(dt_qint4); | |||||
| INST_DT_TYPE(dt_quint4); | |||||
| #undef INST_DT_TYPE | |||||
| #define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, \ | #define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, \ | ||||
| _stride) \ | _stride) \ | ||||
| DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ | DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ | ||||
| @@ -1404,7 +1619,6 @@ void run_elemwise(const ElemwiseOpParamN<arity>& param, cudaStream_t stream, | |||||
| #define INST_RUN_ELEMWISE(Op, ctype, arity) \ | #define INST_RUN_ELEMWISE(Op, ctype, arity) \ | ||||
| template void run_elemwise<Op, ctype, arity>( \ | template void run_elemwise<Op, ctype, arity>( \ | ||||
| const ElemwiseOpParamN<arity>&, cudaStream_t, const Op&) | const ElemwiseOpParamN<arity>&, cudaStream_t, const Op&) | ||||
| #endif | #endif | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -0,0 +1,256 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise_helper_q4.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/cuda/elemwise_helper.cuh" | |||||
| /* | |||||
 * please note that all arithmetic on the GPU is 32-bit for best performance;
 * this limits the maximum possible size
| */ | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| template <typename ctype> | |||||
| struct IsNotTypeQ4 { | |||||
| static constexpr bool value = !(std::is_same<ctype, dt_qint4>::value || | |||||
| std::is_same<ctype, dt_quint4>::value); | |||||
| }; | |||||
| template <typename ctype> | |||||
| struct IsTypeQ4 { | |||||
| static constexpr bool value = (std::is_same<ctype, dt_qint4>::value || | |||||
| std::is_same<ctype, dt_quint4>::value); | |||||
| }; | |||||
| //! internals for element-wise | |||||
| namespace elemwise_intl { | |||||
| #define devfunc __device__ __forceinline__ | |||||
| #if MEGDNN_CC_CUDA | |||||
| /*! | |||||
 * \brief call an operator whose params are each promoted to the same ndim and
| * brdcast_mask | |||||
| * \tparam PVis ParamElemVisitor class | |||||
| */ | |||||
| template <class Op, int arity, class PVisSrc, class PVisDst, bool BetweenQ4> | |||||
| struct OpCallerToQ4; | |||||
| //! specialization for arity == 1 | |||||
| template <class Op, class PVisSrc, class PVisDst> | |||||
| struct OpCallerToQ4<Op, 1, PVisSrc, PVisDst, false> { | |||||
| Op op; | |||||
| PVisSrc par_src[1]; | |||||
| PVisDst par_dst[1]; | |||||
| using src_ctype = typename PVisSrc::CType; | |||||
| devfunc void on(uint32_t access_idx) { | |||||
| int32_t idx0 = par_dst[0].idx(access_idx * 2); | |||||
| int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1); | |||||
| src_ctype src0 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0; | |||||
| src_ctype src1 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0; | |||||
| op(access_idx, src0, src1); | |||||
| } | |||||
| }; | |||||
| //! specialization for arity == 2 | |||||
| template <class Op, class PVisSrc, class PVisDst> | |||||
| struct OpCallerToQ4<Op, 2, PVisSrc, PVisDst, false> { | |||||
| Op op; | |||||
| PVisSrc par_src[2]; | |||||
| PVisDst par_dst[1]; | |||||
| using src_ctype = typename PVisSrc::CType; | |||||
| devfunc void on(uint32_t access_idx) { | |||||
| int32_t idx0 = par_dst[0].idx(access_idx * 2); | |||||
| int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1); | |||||
| src_ctype src00 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0; | |||||
| src_ctype src10 = (idx0 >= 0) ? par_src[1].at(idx0) : (src_ctype)0; | |||||
        src_ctype src01 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0;
        src_ctype src11 = (idx1 >= 0) ? par_src[1].at(idx1) : (src_ctype)0;
| op(access_idx, src00, src10, src01, src11); | |||||
| } | |||||
| }; | |||||
| template <class Op, class PVisSrc, class PVisDst> | |||||
| struct OpCallerToQ4<Op, 3, PVisSrc, PVisDst, false> { | |||||
| Op op; | |||||
| PVisSrc par_src[3]; | |||||
| PVisDst par_dst[1]; | |||||
| using src_ctype = typename PVisSrc::CType; | |||||
| devfunc void on(uint32_t access_idx) { | |||||
| int32_t idx0 = par_dst[0].idx(access_idx * 2); | |||||
| int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1); | |||||
| src_ctype src00 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0; | |||||
| src_ctype src10 = (idx0 >= 0) ? par_src[1].at(idx0) : (src_ctype)0; | |||||
| src_ctype src20 = (idx0 >= 0) ? par_src[2].at(idx0) : (src_ctype)0; | |||||
        src_ctype src01 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0;
        src_ctype src11 = (idx1 >= 0) ? par_src[1].at(idx1) : (src_ctype)0;
        src_ctype src21 = (idx1 >= 0) ? par_src[2].at(idx1) : (src_ctype)0;
| op(access_idx, src00, src10, src20, src01, src11, src21); | |||||
| } | |||||
| }; | |||||
| //! specialization for arity == 1 | |||||
| template <class Op, class PVisSrc, class PVisDst> | |||||
| struct OpCallerToQ4<Op, 1, PVisSrc, PVisDst, true> { | |||||
| Op op; | |||||
| PVisSrc par_src[1]; | |||||
| PVisDst par_dst[1]; | |||||
| devfunc void on(uint32_t access_idx) { | |||||
| op(access_idx, par_src[0].at(access_idx)); | |||||
| } | |||||
| }; | |||||
| //! specialization for arity == 2 | |||||
| template <class Op, class PVisSrc, class PVisDst> | |||||
| struct OpCallerToQ4<Op, 2, PVisSrc, PVisDst, true> { | |||||
| Op op; | |||||
| PVisSrc par_src[2]; | |||||
| PVisDst par_dst[1]; | |||||
| devfunc void on(uint32_t access_idx) { | |||||
| op(access_idx, par_src[0].at(access_idx), par_src[1].at(access_idx)); | |||||
| } | |||||
| }; | |||||
| template <class Op, class PVisSrc, class PVisDst> | |||||
| struct OpCallerToQ4<Op, 3, PVisSrc, PVisDst, true> { | |||||
| Op op; | |||||
| PVisSrc par_src[3]; | |||||
| PVisDst par_dst[1]; | |||||
| devfunc void on(uint32_t access_idx) { | |||||
| op(access_idx, par_src[0].at(access_idx), par_src[1].at(access_idx), | |||||
| par_src[2].at(access_idx)); | |||||
| } | |||||
| }; | |||||
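In the non-Q4-source callers above, one access covers one packed destination byte, i.e. two logical output elements, so the caller resolves the index pair (2 * access_idx, 2 * access_idx + 1) through the destination visitor and substitutes a zero source value whenever a slot is padding (idx() returned -1). A small sketch of that pairing for a hypothetical 5-element output:

#include <cstdio>

int main() {
    const int total = 5;  // logical output elements; the last byte is half padding
    for (int access_idx = 0; access_idx < (total + 1) / 2; ++access_idx) {
        int idx0 = access_idx * 2;
        int idx1 = (access_idx * 2 + 1 < total) ? access_idx * 2 + 1 : -1;
        std::printf("byte %d <- elements %d, %d\n", access_idx, idx0, idx1);
    }
}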
| /* f}}} */ | |||||
| template <class OpCaller> | |||||
| __global__ void cuda_kern_q4(OpCaller op_caller, uint32_t size) { | |||||
| uint32_t access_idx = blockIdx.x * blockDim.x + threadIdx.x, | |||||
| delta = blockDim.x * gridDim.x; | |||||
| if (access_idx < size) { | |||||
| op_caller.on(access_idx); | |||||
| access_idx += delta; | |||||
| if (access_idx < size) { | |||||
| op_caller.on(access_idx); | |||||
| access_idx += delta; | |||||
| if (access_idx < size) { | |||||
| op_caller.on(access_idx); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
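cuda_kern_q4 lets each thread perform up to three accesses spaced by gridDim.x * blockDim.x, which matches the unroll-by-3 assumption that get_launch_spec() asserts (block_size * grid_size * 3 >= size). A host-side sketch of that sizing rule with an assumed block size:

#include <cassert>
#include <cstdio>

int main() {
    const int size = 10000, block_size = 256;           // block_size is an assumption
    int grid_size = (size - 1) / (block_size * 3) + 1;  // the "b" branch of get_launch_spec
    assert(static_cast<long>(block_size) * grid_size * 3 >= size);
    std::printf("grid=%d block=%d covers %d accesses\n", grid_size, block_size,
                block_size * grid_size * 3);
}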
| /* f{{{ UserOpInvoker specializations */ | |||||
| //! run op by promoting all params to same ndim | |||||
| template <class Op, typename src_ctype, typename dst_ctype, int arity, | |||||
| bool BetweenQ4> | |||||
| class UserOpInvokerQ4 { | |||||
| const ElemwiseOpParamN<arity>& m_src_param; | |||||
| const ElemwiseOpParamN<1>& m_dst_param; | |||||
| cudaStream_t m_stream; | |||||
| const Op& m_op; | |||||
| void dispatch0() { | |||||
| switch (m_dst_param.max_ndim) { | |||||
| #define cb(ndim) \ | |||||
| case ndim: \ | |||||
| return dispatch1<ndim>(); | |||||
| MEGDNN_FOREACH_TENSOR_NDIM(cb) | |||||
| #undef cb | |||||
| } | |||||
| on_bad_ndim(m_dst_param.max_ndim); | |||||
| } | |||||
| template <int ndim> | |||||
| void dispatch1() { | |||||
| using PVisSrc = typename std::conditional< | |||||
| BetweenQ4, ParamVectVisitor<ndim, src_ctype, BCAST_OTHER>, | |||||
| ParamElemVisitor<ndim, src_ctype, BCAST_OTHER>>::type; | |||||
| typedef OpCallerToQ4<Op, arity, PVisSrc, | |||||
| ParamVectVisitor<ndim, dst_ctype, BCAST_OTHER>, | |||||
| BetweenQ4> | |||||
| Caller; | |||||
| size_t size = m_dst_param[0].layout.access_bytes(); | |||||
| int grid_size, block_size; | |||||
| void (*fptr)(Caller, uint32_t) = cuda_kern_q4<Caller>; | |||||
| get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size, | |||||
| &block_size); | |||||
| Caller caller; | |||||
| caller.op = m_op; | |||||
| for (int i = 0; i < arity; ++i) | |||||
| caller.par_src[i].host_init(m_src_param[i], grid_size, block_size); | |||||
| caller.par_dst[0].host_init(m_dst_param[0], grid_size, block_size); | |||||
| (*fptr)<<<grid_size, block_size, 0, m_stream>>>(caller, size); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| public: | |||||
| UserOpInvokerQ4(const ElemwiseOpParamN<arity>& src_param, | |||||
| const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, | |||||
| const Op& op) | |||||
| : m_src_param(src_param), | |||||
| m_dst_param(dst_param), | |||||
| m_stream(stream), | |||||
| m_op(op) { | |||||
| dispatch0(); | |||||
| } | |||||
| }; | |||||
| #endif | |||||
| /* f}}} */ | |||||
| #undef devfunc | |||||
| } // namespace elemwise_intl | |||||
| template <class Op, typename src_ctype, typename dst_ctype, int arity> | |||||
| void run_elemwise(const ElemwiseOpParamN<arity>& src_param, | |||||
| const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, | |||||
| const Op& op = Op()); | |||||
| #if MEGDNN_CC_CUDA | |||||
| template <class Op, typename src_ctype, typename dst_ctype, int arity> | |||||
| void run_elemwise(const ElemwiseOpParamN<arity>& src_param, | |||||
| const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, | |||||
| const Op& op) { | |||||
| src_param.assert_initialized(); | |||||
| dst_param.assert_initialized(); | |||||
| // TODO: Maybe 2bit? | |||||
| megdnn_assert(dst_param[0].layout.dtype.is_low_bit()); | |||||
| megdnn_assert(dst_param[0].layout.is_contiguous()); | |||||
| elemwise_intl::UserOpInvokerQ4<Op, src_ctype, dst_ctype, arity, | |||||
| IsTypeQ4<src_ctype>::value>( | |||||
| src_param, dst_param, stream, op); | |||||
| } | |||||
| #define INST_RUN_ELEMWISE_LOWBIT(Op, src_ctype, dst_ctype, arity) \ | |||||
| template void run_elemwise<Op, src_ctype, dst_ctype, arity>( \ | |||||
| const ElemwiseOpParamN<arity>&, const ElemwiseOpParamN<1>&, \ | |||||
| cudaStream_t, const Op&) | |||||
| #endif | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,39 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #ifndef KERN_IMPL_MODE | |||||
| #error "KERN_IMPL_MODE, KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE must be defined" | |||||
| #endif | |||||
| #include "src/cuda/elemwise_multi_type/kern_ops.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| #define cb(_m) \ | |||||
| typedef ElemwiseKern<megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, \ | |||||
| float> \ | |||||
| KernImpl; \ | |||||
| typedef kern_ops_quantized::QuantizedMultiTypeOp< \ | |||||
| KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ | |||||
| Op; \ | |||||
| INST_RUN_ELEMWISE_LOWBIT(Op, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, \ | |||||
| KERN_IMPL_ARITY); | |||||
| KERN_IMPL_MODE(cb) | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -6,11 +6,13 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/cuda/elemwise_helper.cuh" | #include "src/cuda/elemwise_helper.cuh" | ||||
| #include "src/cuda/elemwise_helper_q4.cuh" | |||||
| #include "src/cuda/elemwise_multi_type/kern.cuh" | #include "src/cuda/elemwise_multi_type/kern.cuh" | ||||
| #include "src/cuda/utils.cuh" | #include "src/cuda/utils.cuh" | ||||
| @@ -127,10 +129,10 @@ struct QuantizedMultiTypeOp; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | template <typename ctype_src, typename ctype_dst, typename KernImpl> | ||||
| struct QuantizedMultiTypeOp< | struct QuantizedMultiTypeOp< | ||||
| 1, ctype_src, ctype_dst, KernImpl, | 1, ctype_src, ctype_dst, KernImpl, | ||||
| typename std::enable_if< | |||||
| std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
| typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
| ctype_dst* dst; | ctype_dst* dst; | ||||
| CudaDTypeParam<ctype_dst> dst_param; | CudaDTypeParam<ctype_dst> dst_param; | ||||
| CudaDTypeParam<ctype_src> param_a; | CudaDTypeParam<ctype_src> param_a; | ||||
| @@ -173,10 +175,10 @@ struct QuantizedMultiTypeOp< | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | template <typename ctype_src, typename ctype_dst, typename KernImpl> | ||||
| struct QuantizedMultiTypeOp< | struct QuantizedMultiTypeOp< | ||||
| 2, ctype_src, ctype_dst, KernImpl, | 2, ctype_src, ctype_dst, KernImpl, | ||||
| typename std::enable_if< | |||||
| std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
| typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
| ctype_dst* dst; | ctype_dst* dst; | ||||
| CudaDTypeParam<ctype_dst> dst_param; | CudaDTypeParam<ctype_dst> dst_param; | ||||
| CudaDTypeParam<ctype_src> param_a, param_b; | CudaDTypeParam<ctype_src> param_a, param_b; | ||||
| @@ -224,10 +226,10 @@ struct QuantizedMultiTypeOp< | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | template <typename ctype_src, typename ctype_dst, typename KernImpl> | ||||
| struct QuantizedMultiTypeOp< | struct QuantizedMultiTypeOp< | ||||
| 3, ctype_src, ctype_dst, KernImpl, | 3, ctype_src, ctype_dst, KernImpl, | ||||
| typename std::enable_if< | |||||
| std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
| typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
| ctype_dst* dst; | ctype_dst* dst; | ||||
| CudaDTypeParam<ctype_dst> dst_param; | CudaDTypeParam<ctype_dst> dst_param; | ||||
| CudaDTypeParam<ctype_src> param_a, param_b, param_c; | CudaDTypeParam<ctype_src> param_a, param_b, param_c; | ||||
| @@ -277,6 +279,367 @@ struct QuantizedMultiTypeOp< | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 1, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
| IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
| ctype_dst* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| ctype_dst* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ ctype_dst apply(ctype_src v1) { | |||||
| float fv1 = param_a.dequantize(v1); | |||||
| float rv = KernImpl::apply(fv1); | |||||
| return dst_param.quantize(rv); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a) { | |||||
| dst[idx] = dst_param.quantize(KernImpl::apply(param_a.dequantize(a))); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 2, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
| IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
| ctype_dst* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a, param_b; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| ctype_dst* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| param_b = src_params[1]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2) { | |||||
| float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); | |||||
| float rv = KernImpl::apply(fv1, fv2); | |||||
| return dst_param.quantize(rv); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, | |||||
| ctype_src b) { | |||||
| dst[idx] = dst_param.quantize( | |||||
| KernImpl::apply(param_a.dequantize(a), param_b.dequantize(b))); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 1, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
| IsTypeQ4<ctype_dst>::value>::type> { | |||||
| using src_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; | |||||
| using dst_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
| dst_storage* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a; | |||||
| static constexpr bool src_signedness = | |||||
| std::is_same<ctype_src, dt_qint4>::value; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type | |||||
| src_vect_type; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
| dst_vect_type; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ dst_storage apply(src_storage v1) { | |||||
| float fv1 = param_a.dequantize(v1); | |||||
| float rv = KernImpl::apply(fv1); | |||||
| return dst_param.quantize(rv).as_storage(); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) { | |||||
| dst_storage x = apply( | |||||
| src_storage(unpack_integer_4bits<src_signedness>(a.x, 0))); | |||||
| dst_storage y = apply( | |||||
| src_storage(unpack_integer_4bits<src_signedness>(a.x, 4))); | |||||
| *(dst_vect_type*)(&dst[idx]) = | |||||
| elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 1, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| IsTypeQ4<ctype_dst>::value>::type> { | |||||
| using dst_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
| dst_storage* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
| dst_vect_type; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ dst_storage apply(ctype_src v1) { | |||||
| float fv1 = param_a.dequantize(v1); | |||||
| float rv = KernImpl::apply(fv1); | |||||
| return dst_param.quantize(rv).as_storage(); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, | |||||
| ctype_src a_y) { | |||||
| dst_storage x = apply(a_x), y = apply(a_y); | |||||
| *(dst_vect_type*)(&dst[idx]) = | |||||
| elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 2, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
| IsTypeQ4<ctype_dst>::value>::type> { | |||||
| using src_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; | |||||
| using dst_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
| dst_storage* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a, param_b; | |||||
| static constexpr bool src_signedness = | |||||
| std::is_same<ctype_src, dt_qint4>::value; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type | |||||
| src_vect_type; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
| dst_vect_type; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| param_b = src_params[1]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ dst_storage apply(src_storage v1, | |||||
| src_storage v2) { | |||||
| float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); | |||||
| float rv = KernImpl::apply(fv1, fv2); | |||||
| return dst_param.quantize(rv).as_storage(); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, | |||||
| src_vect_type b) { | |||||
| src_storage a_x = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(a.x, 0)); | |||||
| src_storage a_y = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(a.x, 4)); | |||||
| src_storage b_x = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(b.x, 0)); | |||||
| src_storage b_y = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(b.x, 4)); | |||||
| dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y); | |||||
| *(dst_vect_type*)(&dst[idx]) = | |||||
| elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 2, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| IsTypeQ4<ctype_dst>::value>::type> { | |||||
| using dst_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
| dst_storage* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a, param_b; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
| dst_vect_type; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| param_b = src_params[1]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ dst_storage apply(ctype_src v1, ctype_src v2) { | |||||
| float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); | |||||
| float rv = KernImpl::apply(fv1, fv2); | |||||
| return dst_param.quantize(rv).as_storage(); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, | |||||
| ctype_src b_x, ctype_src a_y, | |||||
| ctype_src b_y) { | |||||
| dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y); | |||||
| *(dst_vect_type*)(&dst[idx]) = | |||||
| elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 3, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
| IsTypeQ4<ctype_dst>::value>::type> { | |||||
| using src_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; | |||||
| using dst_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
| dst_storage* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a, param_b, param_c; | |||||
| static constexpr bool src_signedness = | |||||
| std::is_same<ctype_src, dt_qint4>::value; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type | |||||
| src_vect_type; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
| dst_vect_type; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| param_b = src_params[1]; | |||||
| param_c = src_params[2]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ dst_storage apply(src_storage v1, src_storage v2, | |||||
| src_storage v3) { | |||||
| float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), | |||||
| fv3 = param_c.dequantize(v3); | |||||
| float rv = KernImpl::apply(fv1, fv2, fv3); | |||||
| return dst_param.quantize(rv).as_storage(); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, | |||||
| src_vect_type b, | |||||
| src_vect_type c) { | |||||
| src_storage a_x = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(a.x, 0)); | |||||
| src_storage a_y = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(a.x, 4)); | |||||
| src_storage b_x = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(b.x, 0)); | |||||
| src_storage b_y = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(b.x, 4)); | |||||
| src_storage c_x = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(c.x, 0)); | |||||
| src_storage c_y = | |||||
| src_storage(unpack_integer_4bits<src_signedness>(c.x, 4)); | |||||
| dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y); | |||||
| *(dst_vect_type*)(&dst[idx]) = | |||||
| elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
| struct QuantizedMultiTypeOp< | |||||
| 3, ctype_src, ctype_dst, KernImpl, | |||||
| typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
| std::is_same<ctype_src, dt_qint32>::value || | |||||
| std::is_same<ctype_src, dt_quint8>::value) && | |||||
| IsTypeQ4<ctype_dst>::value>::type> { | |||||
| using dst_storage = | |||||
| typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
| dst_storage* dst; | |||||
| CudaDTypeParam<ctype_dst> dst_param; | |||||
| CudaDTypeParam<ctype_src> param_a, param_b, param_c; | |||||
| typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
| dst_vect_type; | |||||
| #if !MEGDNN_CC_CUDA | |||||
| QuantizedMultiTypeOp( | |||||
| const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
| dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
| : dst{dst}, dst_param{dst_param} { | |||||
| param_a = src_params[0]; | |||||
| param_b = src_params[1]; | |||||
| param_c = src_params[2]; | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_CC_CUDA | |||||
| __device__ __forceinline__ dst_storage apply(ctype_src v1, ctype_src v2, | |||||
| ctype_src v3) { | |||||
| float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), | |||||
| fv3 = param_c.dequantize(v3); | |||||
| float rv = KernImpl::apply(fv1, fv2, fv3); | |||||
| return dst_param.quantize(rv).as_storage(); | |||||
| } | |||||
| __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, | |||||
| ctype_src b_x, ctype_src c_x, | |||||
| ctype_src a_y, ctype_src b_y, | |||||
| ctype_src c_y) { | |||||
| dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y); | |||||
| *(dst_vect_type*)(&dst[idx]) = | |||||
| elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
| } | |||||
| #endif | |||||
| }; | |||||
| } // namespace kern_ops_quantized | } // namespace kern_ops_quantized | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||||
| #define KERN_IMPL_ARITY 3 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) | |||||
| #define KERN_IMPL_ARITY 2 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint32 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_qint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_qint32 | |||||
| #include "../kern_impl.inl" | |||||
| @@ -0,0 +1,6 @@ | |||||
| // generated by gen_elemwise_multi_type_kern_impls.py | |||||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
| #define KERN_IMPL_ARITY 1 | |||||
| #define KERN_IMPL_STYPE dt_quint4 | |||||
| #define KERN_IMPL_DTYPE dt_quint4 | |||||
| #include "../kern_impl_q4.inl" | |||||
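The instantiation files above all follow the same six-line template and differ only in the elemwise mode, arity, and source/destination dtypes; every file whose destination dtype is a 4-bit type pulls in kern_impl_q4.inl, while the q4-to-dt_qint32 files fall back to the plain kern_impl.inl. Below is a minimal sketch of how one such file could be emitted; the emit_kern_impl helper, the output file naming, and the example combinations are illustrative assumptions, not the actual logic of gen_elemwise_multi_type_kern_impls.py.

# Illustrative sketch only: write one kernel-instantiation file per
# (mode, arity, stype, dtype) combination, mirroring the six-line
# pattern of the generated sources above.  Helper name, file naming,
# and the sample combinations are assumptions for illustration.
import os

TEMPLATE = """\
// generated by gen_elemwise_multi_type_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE({mode}, cb)
#define KERN_IMPL_ARITY {arity}
#define KERN_IMPL_STYPE {stype}
#define KERN_IMPL_DTYPE {dtype}
#include "../{inl}"
"""

def emit_kern_impl(out_dir, mode, arity, stype, dtype):
    # 4-bit destinations use the packed-output path (kern_impl_q4.inl);
    # q4 -> dt_qint32 instantiations reuse the ordinary kern_impl.inl.
    inl = ('kern_impl_q4.inl'
           if dtype in ('dt_qint4', 'dt_quint4')
           else 'kern_impl.inl')
    fname = '{}_{}_{}.cu'.format(mode, stype, dtype)
    with open(os.path.join(out_dir, fname), 'w') as f:
        f.write(TEMPLATE.format(mode=mode, arity=arity,
                                stype=stype, dtype=dtype, inl=inl))

if __name__ == '__main__':
    # e.g. the TANH instantiations listed at the end of this hunk
    for stype, dtype in [('dt_qint32', 'dt_qint4'),
                         ('dt_qint4', 'dt_qint4'),
                         ('dt_quint4', 'dt_qint32')]:
        emit_kern_impl('.', 'TANH', 1, stype, dtype)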