diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h index 05ff43cb..f2999c98 100644 --- a/dnn/src/arm_common/conv_bias/postprocess_helper.h +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -44,30 +44,29 @@ namespace { break; #define FOR_NONLINEAR_UNARY(_op) \ - megdnn::arm_common::OpCallerUnary<_op, megdnn::arm_common::VEC>::run( \ + megdnn::arm_common::OpCallerUnary<_op, megdnn::VEC>::run( \ static_cast(conv_dst_ptr), reinterpret_cast(dst_ptr), \ bias_type, dst_type, N* OC* OH* OW* pack_oc_size); -#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ - OC, OH* OW); - -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101xX>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ - N, OC, OH* OW, pack_oc_size); - -#define FOR_NONLINEAR_BINARY(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101>::run( \ static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ + OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ + OH* OW, pack_oc_size); + +#define FOR_NONLINEAR_BINARY(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ N* OC* OH* OW* pack_oc_size); #define FOR_BIAS(_mode) \ @@ -168,36 +167,33 @@ struct PostProcess { #undef FOR_BIAS #undef HANDLE_IDENTITY -#define FOR_NONLINEAR_UNARY(_op) \ +#define FOR_NONLINEAR_UNARY(_op) \ + megdnn::arm_common::OpCallerUnary<_op, megdnn::VEC>::run( \ + static_cast(conv_dst_ptr), reinterpret_cast(dst_ptr), \ + bias_type, dst_type, N* OC* OH* OW* pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ megdnn::arm_common:: \ - OpCallerUnary<_op, megdnn::arm_common::VEC>::run( \ + OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ static_cast(conv_dst_ptr), \ - reinterpret_cast(dst_ptr), bias_type, dst_type, \ - N* OC* OH* OW* pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary< \ - _op, megdnn::arm_common::VEC_BCAST101>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ - N, OC, OH* OW); + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW, pack_oc_size); -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ - 
megdnn::arm_common::OpCallerBinary< \ - _op, megdnn::arm_common::VEC_BCAST101xX>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ - N, OC, OH* OW, pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ - megdnn::arm_common::OpCallerBinary< \ - _op, megdnn::arm_common::VEC_BCAST101xX>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ - N, OC, OH* OW, pack_oc_size); +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, \ + dst_type, N, OC, OH* OW, pack_oc_size); #define HANDLE_IDENTITY(_caller, _op) \ case megdnn::NonlineMode::IDENTITY: \ @@ -271,26 +267,25 @@ struct PostProcess { #undef FOR_NONLINEAR #undef FOR_BIAS -#define FOR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ - OC, OH* OW); - -#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101xX>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ - N, OC, OH* OW, pack_oc_size); - -#define FOR_BINARY(_op) \ - megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ +#define FOR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ + OH* OW); + +#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_BCAST101xX>::run( \ static_cast(conv_dst_ptr), \ reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ + OH* OW, pack_oc_size); + +#define FOR_BINARY(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ N* OC* OH* OW* pack_oc_size); #define FOR_BIAS(_bias_mode, OH, OW) \ diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp index e144834a..c4800d49 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise/opr_impl.cpp @@ -89,163 +89,4 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { fallback::ElemwiseImpl::exec(srcs, dst); } -ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { - KernParam kern_param; - kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE; - kern_param.mode = opr->param().mode; - kern_param.handle = opr->handle(); - - auto is_legal_layout_for_nhwc = [](const TensorLayout& l) { - if (is_vector(l)) - return true; - if (l.ndim == 2 && l.stride[1] == 1) - return true; - return false; - }; - - if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { - kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); - bool c_is_scalar; - opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar); 
- auto &src0 = kern_param.ternary_elparam[0], - &src1 = kern_param.ternary_elparam[1], - &src2 = kern_param.ternary_elparam[2]; - BroadcastChannelInfo binfo; - - if (is_vector(src0.layout) && is_vector(src1.layout) && - is_vector(src2.layout)) { - kern_param.broad_cast_type = BcastType::VEC_VEC_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) { - kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR; - return kern_param; - } - - if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) && - src0.layout.eq_layout(src2.layout)) { - kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101; - return kern_param; - } - - if (is_vector(src1.layout) && - (is_broadcastedx_channel_like<4>(src0.layout, binfo) || - is_broadcastedx_channel_like<8>(src0.layout, binfo)) && - src0.layout.eq_layout(src2.layout)) { - kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX; - return kern_param; - } - - if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && - is_broadcasted_channel_like(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; - return kern_param; - } - - if (is_legal_layout_for_nhwc(src1.layout) && - is_NHWC_broadcasted_channel_like(src0.layout, binfo) && - src0.layout.eq_layout(src2.layout)) { - kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C; - return kern_param; - } - - if (is_legal_layout_for_nhwc(src0.layout) && - src2.layout.eq_layout(src0.layout) && - is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && - (is_broadcastedx_channel_like<4>(src1.layout, binfo) || - is_broadcastedx_channel_like<8>(src1.layout, binfo))) { - kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && is_vector(src2.layout) && - is_broadcasted_scalar(src1.layout)) { - kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) && - is_broadcasted_scalar(src2.layout)) { - kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR; - return kern_param; - } - } else if (opr->m_src->size() == 2) { - kern_param.binary_elparam = opr->make_elemwise_op_param<2>(); - auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1]; - BroadcastChannelInfo binfo; - if (is_vector(src0.layout) && is_vector(src1.layout)) { - kern_param.broad_cast_type = BcastType::VEC_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) { - kern_param.broad_cast_type = BcastType::VEC_SCALAR; - return kern_param; - } - - if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) { - kern_param.broad_cast_type = BcastType::SCALAR_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCAST101; - return kern_param; - } - - if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) { - kern_param.broad_cast_type = BcastType::BCAST101_VEC; - return kern_param; - } - - if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCASTX0X; - return kern_param; - } - - if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) { - 
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC; - return kern_param; - } - - if (is_legal_layout_for_nhwc(src1.layout) && - is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { - kern_param.broad_cast_type = BcastType::BCAST111C_VEC; - return kern_param; - } - - if (is_legal_layout_for_nhwc(src0.layout) && - is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { - kern_param.broad_cast_type = BcastType::VEC_BCAST111C; - return kern_param; - } - - if (is_vector(src0.layout) && - (is_broadcastedx_channel_like<4>(src1.layout, binfo) || - is_broadcastedx_channel_like<8>(src1.layout, binfo))) { - kern_param.broad_cast_type = BcastType::VEC_BCAST101xX; - return kern_param; - } - - if (is_vector(src1.layout) && - (is_broadcastedx_channel_like<4>(src0.layout, binfo) || - is_broadcastedx_channel_like<8>(src0.layout, binfo))) { - kern_param.broad_cast_type = BcastType::BCAST101xX_VEC; - return kern_param; - } - } else if (opr->m_src->size() == 1) { - kern_param.broad_cast_type = BcastType::VEC; - kern_param.unary_elparam = opr->make_elemwise_op_param<1>(); - return kern_param; - } - - return kern_param; -} - // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index 8f528a4d..7d0cc9cf 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -18,22 +18,12 @@ namespace megdnn { namespace arm_common { class ElemwiseImpl final : public fallback::ElemwiseImpl { public: + using fallback::ElemwiseImpl::AlgoBase; using fallback::ElemwiseImpl::ElemwiseImpl; void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override; const char* get_algorithm_set_name() const { return "ARM COMMON ELEMWISE"; } private: - struct KernParam { - BcastType broad_cast_type; - Mode mode; - const TensorND* m_dst; - Handle* handle; - ElemwiseOpParamN<3> ternary_elparam; - ElemwiseOpParamN<2> binary_elparam; - ElemwiseOpParamN<1> unary_elparam; - }; - KernParam make_kern_param(ElemwiseImpl* opr); - class AlgoBase; class AlgoUnary; class AlgoBinaryVecVec; class AlgoBinaryVecScalar; @@ -54,19 +44,6 @@ private: class AlgoPack; }; -/*! - * - * \brief base class for Elemwise algo - * - */ -class ElemwiseImpl::AlgoBase : public detail::Algorithm { -public: - virtual bool is_available(const KernParam&) const = 0; - virtual void exec(const KernParam&) const = 0; - virtual ~AlgoBase() = default; - uint32_t type() const override { return INVALID_ALGO_TYPE; }; -}; - #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #define DISPATCH_TYPE(_case) \ if (src0.layout.dtype == dtype::Float32{}) { \ diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h index 4a3333cb..92ebcc29 100644 --- a/dnn/src/arm_common/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_op.h @@ -15,10 +15,13 @@ #include "src/arm_common/elemwise_helper/op_binary.h" #include "src/arm_common/elemwise_helper/op_ternary.h" #include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/fallback/elemwise_helper/op_common.h" namespace megdnn { namespace arm_common { +using BcastType = megdnn::BcastType; + ///////////////////////////////// ParamElemVistor /////////////////////////// template struct ParamElemVisitor; @@ -99,36 +102,6 @@ cb(__fp16, __fp16, float16x8_t, f16); #endif #undef cb -/*! 
- * \brief broadcast type - * BCAST_x[0]x[1]...: x[i] == !stride[i] - */ -enum BcastType { - VEC, - VEC_VEC, - VEC_BCAST101, - VEC_BCASTX0X, - VEC_BCAST111C, - VEC_BCAST101xX, - VEC_SCALAR, - SCALAR_VEC, - BCAST101_VEC, - BCASTX0X_VEC, - BCAST111C_VEC, - BCAST101xX_VEC, - VEC_VEC_VEC, - VEC_VEC_SCALAR, - BCAST101_VEC_BCAST101, - BCAST111C_VEC_BCAST111C, - BCAST101xX_VEC_BCAST101xX, - VEC_BCAST101_VEC, - VEC_BCAST111C_VEC, - VEC_BCAST101xX_VEC, - VEC_SCALAR_VEC, - VEC_SCALAR_SCALAR, - UNKNOWN_BCAST_TYPE -}; - ///////////////////////////////// OpCaller ///////////////////////////// template struct OpCallerUnary; diff --git a/dnn/src/fallback/elemwise/opr_binary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp similarity index 95% rename from dnn/src/fallback/elemwise/opr_binary_impl.cpp rename to dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp index 9acda94e..e102d400 100644 --- a/dnn/src/fallback/elemwise/opr_binary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp @@ -1,14 +1,7 @@ /** - * \file dnn/src/fallback/elemwise/opr_binary_impl.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * \file dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp */ -#include "./opr_impl.h" +#include "src/fallback/elemwise/opr_impl.h" #include "src/common/elemwise/kern_defs.cuh" #include "src/common/utils.h" diff --git a/dnn/src/fallback/elemwise/opr_unary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp similarity index 87% rename from dnn/src/fallback/elemwise/opr_unary_impl.cpp rename to dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp index af829358..a7b64f59 100644 --- a/dnn/src/fallback/elemwise/opr_unary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp @@ -1,14 +1,7 @@ /** - * \file dnn/src/fallback/elemwise/opr_unary_impl.cpp - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * \file dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp */ -#include "./opr_impl.h" +#include "src/fallback/elemwise/opr_impl.h" #include "src/common/elemwise/kern_defs.cuh" #include "src/common/utils.h" diff --git a/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp b/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp new file mode 100644 index 00000000..341b3809 --- /dev/null +++ b/dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp @@ -0,0 +1,535 @@ +/** + * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp + */ +#include "src/fallback/elemwise/gi_impl/binary/algo.h" +#include "src/fallback/elemwise_op.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_fallback_elemwise_binary) + +using namespace megdnn; +using namespace fallback; + +namespace { +static inline bool is_available_common(Elemwise::Mode mode) { + /** + * Fused sigmoid & tanh may be slower than the naive algo, because the + * time used by neon function `exp_ps_f32` is decided by the input. + */ + if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID || + mode == Elemwise::Mode::FUSE_ADD_TANH) { + return false; + } + + return true; +} +} // anonymous namespace + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \ + mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \ + mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU || \ + mode == Mode::FUSE_ADD_H_SWISH) \ + return true; + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \ + mode == Mode::SUB || mode == Mode::MUL || mode == Mode::FUSE_ADD_RELU) \ + return true; + +bool ElemwiseImpl::AlgoBinaryVecVec::is_available(const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + (BcastType::VEC_VEC != kern_param.broad_cast_type)) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + //! 
exactly match [x, y] + [x, y] + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecVec::is_available"_hash); + + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecScalar::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_SCALAR != kern_param.broad_cast_type) && + (BcastType::SCALAR_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::is_available"_hash); + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) && + (BcastType::BCAST101_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::is_available"_hash); + + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) && + (BcastType::BCASTX0X_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::is_available"_hash); + + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) && + (BcastType::BCAST111C_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::is_available"_hash); + + return false; +} + +bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( + const KernParam& kern_param) const { + if (!is_available_common(kern_param.mode) || + ((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) && + (BcastType::BCAST101xX_VEC != kern_param.broad_cast_type))) + return false; + + auto& elparam = kern_param.binary_elparam; + auto& src0 = elparam[0]; + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::is_available"_hash); + + return false; +} + +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ + DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ + DISPATCH_BINARY( \ + FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + 
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + +void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + + //! exactly match [x, y] + [x, y] +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecVec::exec"_hash); + +#undef DISPATCH_BINARY + + return; +} + +void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + + // Case 2: vector + scalar +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr())[0], \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) { + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::exec_vec_sca"_hash); + } +#undef DISPATCH_BINARY + + // scalar + vector +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr())[0], \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src1.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) { + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::exec_sca_vec"_hash); + } +#undef DISPATCH_BINARY + + return; +} + +void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // Case 3: BcastType::VEC + BCAST_101 + if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type && + is_broadcasted_channel_like(src1.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, 
_type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_101 + BcastType::VEC + if (BcastType::BCAST101_VEC == kern_param.broad_cast_type && + is_broadcasted_channel_like(src0.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} + +void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // Case: BcastType::VEC + BCAST_X0X + if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type && + is_broadcasted_3dim_like(src1.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_X0X + BcastType::VEC + if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type && + is_broadcasted_3dim_like(src0.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} + +void 
ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // Case extra: BcastType::VEC + BCAST_111C + if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type && + is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_111C + BcastType::VEC + if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type && + is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} + +void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.binary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1]; + auto&& dst = *(kern_param.m_dst); + BroadcastChannelInfo binfo; + + // BcastType::VEC + BCAST_101X + if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) { + megdnn_assert( + is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo), + "only nchw44 and nchw88 supported"); +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCAST101xX>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \ + binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::exec_vec_b"_hash); + +#undef DISPATCH_BINARY + } + + // BCAST_101x + BcastType::VEC + if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) { + megdnn_assert( + is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo), + 
"only nchw44 and nchw88 supported"); +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCAST101xX_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \ + binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + + DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::exec_b_vec"_hash); + +#undef DISPATCH_BINARY + } + return; +} + +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/gi_impl/binary/algo.h b/dnn/src/fallback/elemwise/gi_impl/binary/algo.h new file mode 100644 index 00000000..10fb2cda --- /dev/null +++ b/dnn/src/fallback/elemwise/gi_impl/binary/algo.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.h + */ + +#pragma once +#include "src/fallback/elemwise/opr_impl.h" + +namespace megdnn { +namespace fallback { + +#define DECL_CB(case) \ + class ElemwiseImpl::AlgoBinary##case final : public ElemwiseImpl::AlgoBase { \ + mutable std::string m_name; \ + AlgoAttribute attribute() const override { \ + return AlgoAttribute::REPRODUCIBLE; \ + } \ + const char* name() const override { \ + if (m_name.empty()) { \ + m_name = ssprintf("Elemwise::AlgoBinaryCase" #case); \ + } \ + return m_name.c_str(); \ + } \ + bool is_available(const KernParam&) const override; \ + void exec(const KernParam&) const override; \ + }; + +DECL_CB(VecVec); +DECL_CB(VecScalar); +DECL_CB(VecBcast101); +DECL_CB(VecBcastX0X); +DECL_CB(VecBcast111C); +DECL_CB(VecBcast101xX); +#undef DECL_CB +} // namespace fallback +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/gi_impl/gi_mathfun.cpp b/dnn/src/fallback/elemwise/gi_impl/gi_mathfun.cpp new file mode 100644 index 00000000..6c30af27 --- /dev/null +++ b/dnn/src/fallback/elemwise/gi_impl/gi_mathfun.cpp @@ -0,0 +1,383 @@ +/** + * \file dnn/src/fallback/elemwise/gi_mathfun.cpp + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. + * + */ + +/* NEON implementation of sin, cos, exp and log + + Inspired by Intel Approximate Math library, and based on the + corresponding algorithms of the cephes math library +*/ + +/* Copyright (C) 2011 Julien Pommier + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. 
This notice may not be removed or altered from any source distribution. + + (this is the zlib license) +*/ + +#include "./gi_mathfun.h" + +namespace megdnn { +namespace fallback { + +#define c_inv_mant_mask ~0x7f800000u +#define c_cephes_SQRTHF 0.707106781186547524 +#define c_cephes_log_p0 7.0376836292E-2 +#define c_cephes_log_p1 -1.1514610310E-1 +#define c_cephes_log_p2 1.1676998740E-1 +#define c_cephes_log_p3 -1.2420140846E-1 +#define c_cephes_log_p4 +1.4249322787E-1 +#define c_cephes_log_p5 -1.6668057665E-1 +#define c_cephes_log_p6 +2.0000714765E-1 +#define c_cephes_log_p7 -2.4999993993E-1 +#define c_cephes_log_p8 +3.3333331174E-1 +#define c_cephes_log_q1 -2.12194440e-4 +#define c_cephes_log_q2 0.693359375 + +/** + * natural logarithm computed for 4 simultaneous float return NaN for x <= 0 + */ +v4sf GiLogPsFloat32(v4sf x) { + v4sf one = GiBroadcastFloat32(1); + + x = GiMaximumFloat32( + x, GiBroadcastFloat32(0)); /* force flush to zero on denormal values */ + v4su invalid_mask = GiLessThanEqFloat32(x, GiBroadcastFloat32(0)); + + v4si ux = GiReinterpretAsInt32(x); + + v4si emm0 = GiShiftRight23Int32(ux); + + /* keep only the fractional part */ + ux = GiAndInt32(ux, GiBroadcastInt32(c_inv_mant_mask)); + ux = GiOrInt32(ux, GiReinterpretAsInt32(GiBroadcastFloat32(0.5f))); + x = GiReintInt32ToFloat32(ux); + + emm0 = GiSubtractInt32(emm0, GiBroadcastInt32(0x7f)); + v4sf e = GiCastToFloat32(emm0); + + e = GiAddFloat32(e, one); + + /* part2: + * if( x < SQRTHF ) { + * e -= 1; + * x = x + x - 1.0; + * } else { x = x - 1.0; } + */ + v4su mask = GiLessThanFloat32(x, GiBroadcastFloat32(c_cephes_SQRTHF)); + v4sf tmp = GiAndFloat32(x, GiReintUint32ToFloat32(mask)); + x = GiSubtractFloat32(x, one); + e = GiSubtractFloat32(e, GiAndFloat32(one, GiReintUint32ToFloat32(mask))); + x = GiAddFloat32(x, tmp); + + v4sf z = GiMultiplyFloat32(x, x); + + v4sf y = GiBroadcastFloat32(c_cephes_log_p0); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p1), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p2), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p3), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p4), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p5), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p6), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p7), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p8), y, x); + y = GiMultiplyFloat32(y, x); + + y = GiMultiplyFloat32(y, z); + + y = GiMultiplyAddFloat32(y, e, GiBroadcastFloat32(c_cephes_log_q1)); + + y = GiMultiplySubFloat32(y, z, GiBroadcastFloat32(0.5f)); + + x = GiAddFloat32(x, y); + x = GiMultiplyAddFloat32(x, e, GiBroadcastFloat32(c_cephes_log_q2)); + x = GiOrFloat32( + x, GiReintUint32ToFloat32(invalid_mask)); // negative arg will be NAN + return x; +} + +#define c_exp_hi 88.3762626647949f +#define c_exp_lo -88.3762626647949f + +#define c_cephes_LOG2EF 1.44269504088896341 +#define c_cephes_exp_C1 0.693359375 +#define c_cephes_exp_C2 -2.12194440e-4 + +#define c_cephes_exp_p0 1.9875691500E-4 +#define c_cephes_exp_p1 1.3981999507E-3 +#define c_cephes_exp_p2 8.3334519073E-3 +#define c_cephes_exp_p3 4.1665795894E-2 +#define c_cephes_exp_p4 1.6666665459E-1 +#define c_cephes_exp_p5 5.0000001201E-1 + +/* exp() computed for 4 float at once */ +v4sf GiExpPsFloat32(v4sf x) { + v4sf tmp, fx; + + v4sf one = GiBroadcastFloat32(1); + x = GiMinimumFloat32(x, GiBroadcastFloat32(c_exp_hi)); + x = GiMaximumFloat32(x, 
GiBroadcastFloat32(c_exp_lo)); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = GiMultiplyAddFloat32( + GiBroadcastFloat32(0.5f), x, GiBroadcastFloat32(c_cephes_LOG2EF)); + + /* perform a floorf */ + tmp = GiCastToFloat32(GiCastToInt32(fx)); + + /* if greater, subtract 1 */ + v4su mask = GiGreaterThanFloat32(tmp, fx); + v4sf mask_float = GiAndFloat32(GiReintUint32ToFloat32(mask), one); + + fx = GiSubtractFloat32(tmp, mask_float); + + tmp = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C1)); + v4sf z = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C2)); + x = GiSubtractFloat32(x, tmp); + x = GiSubtractFloat32(x, z); + + z = GiMultiplyFloat32(x, x); + + v4sf y = GiBroadcastFloat32(c_cephes_exp_p0); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p1), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p2), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p3), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p4), y, x); + y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p5), y, x); + + y = GiMultiplyAddFloat32(x, y, z); + y = GiAddFloat32(y, one); + + /* build 2^n */ + v4si mm; + mm = GiCastToInt32(fx); + mm = GiAddInt32(mm, GiBroadcastInt32(0x7f)); + mm = GiShiftLeft23Int32(mm); + v4sf pow2n = GiReintInt32ToFloat32(mm); + + y = GiMultiplyFloat32(y, pow2n); + return y; +} + +#define c_minus_cephes_DP1 -0.78515625 +#define c_minus_cephes_DP2 -2.4187564849853515625e-4 +#define c_minus_cephes_DP3 -3.77489497744594108e-8 +#define c_sincof_p0 -1.9515295891E-4 +#define c_sincof_p1 8.3321608736E-3 +#define c_sincof_p2 -1.6666654611E-1 +#define c_coscof_p0 2.443315711809948E-005 +#define c_coscof_p1 -1.388731625493765E-003 +#define c_coscof_p2 4.166664568298827E-002 +#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI + +/* evaluation of 4 sines & cosines at once. + + The code is the exact rewriting of the cephes sinf function. + Precision is excellent as long as x < 8192 (I did not bother to + take into account the special handling they have for greater values + -- it does not return garbage for arguments over 8192, though, but + the extra precision is missing). + + Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + surprising but correct result. + + Note also that when you compute sin(x), cos(x) is available at + almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of + sincos_ps_f32.. 
+ */ +void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos) { + // any x + v4sf y; + + v4su emm2; + + v4su sign_mask_sin, sign_mask_cos; + sign_mask_sin = GiLessThanFloat32(x, GiBroadcastFloat32(0)); + x = GiAbsFloat32(x); + + /* scale by 4/Pi */ + y = GiMultiplyFloat32(x, GiBroadcastFloat32(c_cephes_FOPI)); + + /* store the integer part of y in mm0 */ + emm2 = GiReinterpretAsUint32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + emm2 = GiAddUint32(emm2, GiBroadcastUint32(1)); + emm2 = GiAddUint32(emm2, GiBroadcastUint32(~1)); + y = GiReintUint32ToFloat32(emm2); + + /* get the polynom selection mask + * there is one polynom for 0 <= x <= Pi/4 + * and another one for Pi/4(kern_param.mode))); \ + } +#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT +void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 1: shape of (src0, src2) and src1 are exactly match +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecVecVec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 2: (src2 is a scalar) && (src0 and src1 has the same shape) +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast(src2.raw_ptr())[0], \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecVecScalar::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 3: shape of src0 and src2 is {1, C, 1, 1} + BroadcastChannelInfo binfo; + is_broadcasted_channel_like(src0.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + 
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ + auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ + binfo, dst, run](size_t task_id, size_t) { \ + size_t offset = task_id * nr_channels_per_thread; \ + size_t nr_channels_thread = \ + std::min(nr_channels - offset, nr_channels_per_thread); \ + run(static_cast(src0.raw_ptr()) + offset, \ + static_cast(src1.raw_ptr()) + offset * binfo.z, \ + static_cast(src2.raw_ptr()) + offset, \ + static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ + binfo.y * binfo.z); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(kern_param.handle), nr_threads, \ + kernel); \ + } \ + MIDOUT_END(); \ + return + + size_t nr_threads = static_cast(kern_param.handle) + ->megcore_dispatcher() + ->nr_threads(); + + size_t nr_channels = binfo.y; + size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 3: shape of src0 and src2 is {1, 1, 1, C} + BroadcastChannelInfo binfo; + is_NHWC_broadcasted_channel_like(src0.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, \ + BcastType::BCAST111C_VEC_BCAST111C>::run; \ + auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ + binfo, dst, run](size_t task_id, size_t) { \ + size_t offset = task_id * nr_channels_per_thread; \ + size_t nr_channels_thread = \ + std::min(nr_channels - offset, nr_channels_per_thread); \ + size_t src1_offset = \ + is_vector(src1.layout) ? 
0 : src1.layout.stride[0] - binfo.z; \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()) + \ + offset * (binfo.z + src1_offset), \ + src1_offset, static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ + src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ + dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ + binfo.y * binfo.z); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(kern_param.handle), nr_threads, \ + kernel); \ + } \ + MIDOUT_END(); \ + return + + size_t nr_threads = static_cast(kern_param.handle) + ->megcore_dispatcher() + ->nr_threads(); + + size_t nr_channels = binfo.y; + size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + BroadcastChannelInfo binfo; + megdnn_assert( + is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo), + "only nchw44 and nchw88 supported"); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, \ + BcastType::BCAST101xX_VEC_BCAST101xX>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + batch_size, binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + BroadcastChannelInfo binfo; + megdnn_assert( + is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo), + "only nchw44 and nchw88 supported"); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + batch_size, binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast101xXVec::exec"_hash); +#undef 
DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 4: shape of src1 is {1, C, 1, 1}, and src0 and src2 are contig + BroadcastChannelInfo binfo; + is_broadcasted_channel_like(src1.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr()), \ + static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast101Vec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig + BroadcastChannelInfo binfo; + is_NHWC_broadcasted_channel_like(src1.layout, binfo); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \ + static_cast(src1.raw_ptr()), \ + static_cast(src2.raw_ptr()), \ + is_vector(src2.layout) ? 
0 : src2.layout.stride[0] - binfo.z, \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast111CVec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} + +void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 5: (src1 is a scalar) && (src0 and src2 has the same shape) +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr())[0], \ + static_cast(src2.raw_ptr()), \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecScalarVec::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec( + const KernParam& kern_param) const { + auto& elparam = kern_param.ternary_elparam; + auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; + + // Case 6: (src1 and src2 is scalar) && (src0 is vector) +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr()), \ + static_cast(src1.raw_ptr())[0], \ + static_cast(src2.raw_ptr())[0], \ + static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ + return + + auto&& dst = *(kern_param.m_dst); + DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecScalarScalar::exec"_hash); +#undef DISPATCH_TERNARY + + return; +} +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/gi_impl/ternary/algo.h b/dnn/src/fallback/elemwise/gi_impl/ternary/algo.h new file mode 100644 index 00000000..4d96483f --- /dev/null +++ b/dnn/src/fallback/elemwise/gi_impl/ternary/algo.h @@ -0,0 +1,39 @@ +/** + * \file dnn/src/fallback/elemwise/gi_impl/ternary/algo.h + */ +#pragma once +#include "src/fallback/elemwise/opr_impl.h" + +namespace megdnn { +namespace fallback { + +#define DECL_CB(case) \ + class ElemwiseImpl::AlgoTernaryFma3##case final : public ElemwiseImpl::AlgoBase { \ + mutable std::string m_name; \ + AlgoAttribute attribute() const override { \ + return AlgoAttribute::REPRODUCIBLE; \ + } \ + const char* name() const override { \ + if (m_name.empty()) { \ + m_name = ssprintf("Elemwise::AlgoTernaryFma3" #case); \ + } \ + return m_name.c_str(); \ + } \ + bool is_available(const KernParam&) const override; \ + void exec(const 
KernParam&) const override; \ + }; + +DECL_CB(VecVecVec); +DECL_CB(VecVecScalar); +DECL_CB(Bcast101VecBcast101); +DECL_CB(Bcast111CVecBcast111C); +DECL_CB(Bcast101xXVecBcast101xX); +DECL_CB(VecBcast101Vec); +DECL_CB(VecBcast111CVec); +DECL_CB(VecBcast101xXVec); +DECL_CB(VecScalarVec); +DECL_CB(VecScalarScalar); +#undef DECL_CB +} // namespace fallback +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp b/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp new file mode 100644 index 00000000..ffbac861 --- /dev/null +++ b/dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp @@ -0,0 +1,125 @@ +/** + * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp + */ +#include "src/fallback/elemwise/gi_impl/unary/algo.h" +#include "src/fallback/elemwise_op.h" + +#include "src/common/utils.h" +#include "src/naive/handle.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_fallback_elemwise_unary) + +using namespace megdnn; +using namespace fallback; + +bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { + if (BcastType::VEC != kern_param.broad_cast_type) + return false; + + if (kern_param.m_dst->layout.dtype.category() != DTypeCategory::FLOAT && + (kern_param.mode == Mode::EXP || kern_param.mode == Mode::SIGMOID || + kern_param.mode == Mode::TANH || kern_param.mode == Mode::FAST_TANH || + kern_param.mode == Mode::H_SWISH)) { + return false; + } + //! As `NEGATE` mode is so simple, that the code generate by compiler is + //! vectorized optimized, while other mode such as `ABS` has branch, the + //! compiler may not generate code as good as user intrinsic. + if (kern_param.mode == Mode::NEGATE) { + return false; + } + + auto& elparam = kern_param.unary_elparam; + if (!elparam[0].layout.is_contiguous()) + return false; + megdnn_assert(elparam[0].layout.ndim == 1); + auto& src0 = elparam[0]; + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::RELU || mode == Mode::ABS || mode == Mode::SIGMOID || \ + mode == Mode::EXP || mode == Mode::TANH || mode == Mode::FAST_TANH || \ + mode == Mode::H_SWISH) \ + return true; + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + auto mode = kern_param.mode; \ + if (mode == Mode::RELU || mode == Mode::ABS) \ + return true; + + DISPATCH_TYPE_FALLBACK("AlgoUnary::is_available"_hash); + return false; +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT +} + +void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const { +#define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_fallback_elemwise_unary, midout_iv(_case), \ + midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \ + thin_function run = \ + OpCallerUnary<_op<_type, _type>, BcastType::VEC>::run; \ + auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, run]( \ + size_t task_id, size_t) { \ + size_t offset = task_id * nr_elems_per_thread; \ + size_t nr_elems_thread = \ + std::min(nr_elems - offset, nr_elems_per_thread); \ + run(static_cast(src0.raw_ptr()) + offset, \ + static_cast<_type*>(dst_tensor.raw_ptr()) + offset, \ + src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(kern_param.handle), nr_threads, \ + kernel); \ + } \ + MIDOUT_END(); \ + return + + auto& elparam = kern_param.unary_elparam; + megdnn_assert(elparam[0].layout.ndim == 1); + auto& src0 = elparam[0]; + auto& dst_tensor = 
*(kern_param.m_dst); + + size_t nr_threads = static_cast(kern_param.handle) + ->megcore_dispatcher() + ->nr_threads(); + + size_t nr_elems = src0.layout.total_nr_elems(); + size_t nr_elems_per_thread = (nr_elems + nr_threads - 1) / nr_threads; + +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ + DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ + DISPATCH_UNARY(SIGMOID, _case, _type, _type_midout_id, SigmoidOp); \ + DISPATCH_UNARY(EXP, _case, _type, _type_midout_id, ExpOp); \ + DISPATCH_UNARY(TANH, _case, _type, _type_midout_id, TanhOp); \ + DISPATCH_UNARY(FAST_TANH, _case, _type, _type_midout_id, FastTanhOp); \ + DISPATCH_UNARY(H_SWISH, _case, _type, _type_midout_id, HSwishOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ + DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ + } + + DISPATCH_TYPE_FALLBACK("AlgoUnary::exec"_hash); +#undef DISPATCH_MODE_FLOAT +#undef DISPATCH_MODE_INT +#undef DISPATCH_UNARY +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/gi_impl/unary/algo.h b/dnn/src/fallback/elemwise/gi_impl/unary/algo.h new file mode 100644 index 00000000..78c94ab1 --- /dev/null +++ b/dnn/src/fallback/elemwise/gi_impl/unary/algo.h @@ -0,0 +1,25 @@ +/** + * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.h + */ +#pragma once +#include "src/fallback/elemwise/opr_impl.h" +namespace megdnn { +namespace fallback { +class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase { + mutable std::string m_name; + + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { + if (m_name.empty()) { + m_name = ssprintf("Elemwise::AlgoUnary"); + } + return m_name.c_str(); + } + + bool is_available(const KernParam&) const override; + void exec(const KernParam&) const override; +}; + +} // namespace fallback +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/opr_impl.cpp b/dnn/src/fallback/elemwise/opr_impl.cpp index eb4b2d9c..98e22374 100644 --- a/dnn/src/fallback/elemwise/opr_impl.cpp +++ b/dnn/src/fallback/elemwise/opr_impl.cpp @@ -12,6 +12,9 @@ #include "src/common/elemwise/kern_defs.cuh" #include "src/common/utils.h" +#include "src/fallback//elemwise/gi_impl/unary/algo.h" +#include "src/fallback/elemwise/gi_impl/binary/algo.h" +#include "src/fallback/elemwise/gi_impl/ternary/algo.h" #include "src/naive/handle.h" #include "midout.h" @@ -21,13 +24,22 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT) MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) -namespace megdnn { -namespace fallback { +using namespace megdnn; +using namespace fallback; void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { if (!dst.layout.is_contiguous()) { return naive::ElemwiseForwardImpl::exec(srcs, dst); } + if (!exec_gi_intrinsic(srcs, dst)) { + return exec_fallback(srcs, dst); + } +} + +void ElemwiseImpl::exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst) { + if (!dst.layout.is_contiguous()) { + return 
naive::ElemwiseForwardImpl::exec(srcs, dst); + } m_src = &srcs; m_dst = &dst; @@ -82,7 +94,229 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { naive::ElemwiseForwardImpl::exec(srcs, dst); } -} // namespace fallback -} // namespace megdnn +class ElemwiseImpl::AlgoPack { +#if !(MEGDNN_AARCH64 || MEGDNN_ARMV7) + AlgoUnary algo_unary; + AlgoBinaryVecVec algo_binary_vec_vec; + AlgoBinaryVecScalar algo_binary_vec_sca; + AlgoBinaryVecBcast101 algo_binary_vec_bcast101; + AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X; + AlgoBinaryVecBcast111C algo_binary_vec_bcast110; + AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX; + AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; + AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; + AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; + AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110; + AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX; + AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; + AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec; + AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; + AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; + AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; +#endif + +public: + AlgoPack() { +#if !(MEGDNN_AARCH64 || MEGDNN_ARMV7) + all_algos.emplace_back(&algo_unary); + all_algos.emplace_back(&algo_binary_vec_vec); + all_algos.emplace_back(&algo_binary_vec_sca); + all_algos.emplace_back(&algo_binary_vec_bcast101); + all_algos.emplace_back(&algo_binary_vec_bcastX0X); + all_algos.emplace_back(&algo_binary_vec_bcast110); + all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); + all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); + all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); + all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110); + all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX); + all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); + all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); +#endif + } + SmallVector all_algos; +}; + +bool ElemwiseImpl::exec_gi_intrinsic( + const TensorNDArray& srcs, _megdnn_tensor_out dst) { + m_src = &srcs; + m_dst = &dst; + + if (m_dst->layout.dtype == dtype::Float32() || + m_dst->layout.dtype == dtype::Int32() || + m_dst->layout.dtype == dtype::Int16() || m_dst->layout.dtype == dtype::Int8()) { + auto kern_param = make_kern_param(this); + kern_param.m_dst = &dst; + static AlgoPack m_algo_pack; + for (auto& m_algo : m_algo_pack.all_algos) { + if (m_algo->is_available(kern_param)) { + m_algo->exec(kern_param); + return true; + } + } + } + return false; +} + +ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { + KernParam kern_param; + kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE; + kern_param.mode = opr->param().mode; + kern_param.handle = opr->handle(); + auto is_legal_layout_for_nhwc = [](const TensorLayout& l) { + if (is_vector(l)) + return true; + if (l.ndim == 2 && l.stride[1] == 1) + return true; + return false; + }; + + if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { + kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); + 
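+        //! prepare_fma3 fills `c_is_scalar` for the third operand; the layout
+        //! checks below then map the three FUSE_MUL_ADD3 operands onto a ternary
+        //! BcastType, e.g. contiguous src0/src2 with src1 shaped {1, C, 1, 1}
+        //! selects BcastType::VEC_BCAST101_VEC.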
bool c_is_scalar; + opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar); + auto &src0 = kern_param.ternary_elparam[0], + &src1 = kern_param.ternary_elparam[1], + &src2 = kern_param.ternary_elparam[2]; + BroadcastChannelInfo binfo; + + if (is_vector(src0.layout) && is_vector(src1.layout) && + is_vector(src2.layout)) { + kern_param.broad_cast_type = BcastType::VEC_VEC_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) { + kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR; + return kern_param; + } + + if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) && + src0.layout.eq_layout(src2.layout)) { + kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101; + return kern_param; + } + + if (is_vector(src1.layout) && + (is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo)) && + src0.layout.eq_layout(src2.layout)) { + kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX; + return kern_param; + } + + if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && + is_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; + return kern_param; + } + + if (is_legal_layout_for_nhwc(src1.layout) && + is_NHWC_broadcasted_channel_like(src0.layout, binfo) && + src0.layout.eq_layout(src2.layout)) { + kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C; + return kern_param; + } + + if (is_legal_layout_for_nhwc(src0.layout) && + src2.layout.eq_layout(src0.layout) && + is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && + (is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo))) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_vector(src2.layout) && + is_broadcasted_scalar(src1.layout)) { + kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) && + is_broadcasted_scalar(src2.layout)) { + kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR; + return kern_param; + } + } else if (opr->m_src->size() == 2) { + kern_param.binary_elparam = opr->make_elemwise_op_param<2>(); + auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1]; + BroadcastChannelInfo binfo; + if (is_vector(src0.layout) && is_vector(src1.layout)) { + kern_param.broad_cast_type = BcastType::VEC_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) { + kern_param.broad_cast_type = BcastType::VEC_SCALAR; + return kern_param; + } + + if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) { + kern_param.broad_cast_type = BcastType::SCALAR_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101; + return kern_param; + } + + if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) { + kern_param.broad_cast_type = BcastType::BCAST101_VEC; + return kern_param; + } + + if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCASTX0X; + return kern_param; + } + + 
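+        //! mirrored case of the VEC_BCASTX0X check above: here src1 is the
+        //! contiguous operand and src0 carries the 3-dim broadcast.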
if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) { + kern_param.broad_cast_type = BcastType::BCASTX0X_VEC; + return kern_param; + } + + if (is_legal_layout_for_nhwc(src1.layout) && + is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { + kern_param.broad_cast_type = BcastType::BCAST111C_VEC; + return kern_param; + } + + if (is_legal_layout_for_nhwc(src0.layout) && + is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { + kern_param.broad_cast_type = BcastType::VEC_BCAST111C; + return kern_param; + } + + if (is_vector(src0.layout) && + (is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo))) { + kern_param.broad_cast_type = BcastType::VEC_BCAST101xX; + return kern_param; + } + + if (is_vector(src1.layout) && + (is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo))) { + kern_param.broad_cast_type = BcastType::BCAST101xX_VEC; + return kern_param; + } + } else if (opr->m_src->size() == 1) { + kern_param.broad_cast_type = BcastType::VEC; + kern_param.unary_elparam = opr->make_elemwise_op_param<1>(); + return kern_param; + } + + return kern_param; +} // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise/opr_impl.h b/dnn/src/fallback/elemwise/opr_impl.h index 57aa97bd..c7285ecb 100644 --- a/dnn/src/fallback/elemwise/opr_impl.h +++ b/dnn/src/fallback/elemwise/opr_impl.h @@ -9,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once +#include "src/fallback/elemwise_op.h" #include "src/naive/elemwise/opr_impl.h" namespace megdnn { @@ -33,13 +34,69 @@ class ElemwiseImpl : public naive::ElemwiseForwardImpl { template void exec_BINARY_FLOAT(); + void exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst); + bool exec_gi_intrinsic(const TensorNDArray& srcs, _megdnn_tensor_out dst); + +private: + class AlgoUnary; + class AlgoBinaryVecVec; + class AlgoBinaryVecScalar; + class AlgoBinaryVecBcast101; + class AlgoBinaryVecBcastX0X; + class AlgoBinaryVecBcast111C; + class AlgoBinaryVecBcast101xX; + class AlgoTernaryFma3VecVecVec; + class AlgoTernaryFma3VecVecScalar; + class AlgoTernaryFma3Bcast101VecBcast101; + class AlgoTernaryFma3Bcast111CVecBcast111C; + class AlgoTernaryFma3Bcast101xXVecBcast101xX; + class AlgoTernaryFma3VecBcast101Vec; + class AlgoTernaryFma3VecBcast111CVec; + class AlgoTernaryFma3VecBcast101xXVec; + class AlgoTernaryFma3VecScalarVec; + class AlgoTernaryFma3VecScalarScalar; + class AlgoPack; + public: + class AlgoBase; + struct KernParam { + BcastType broad_cast_type; + Mode mode; + const TensorND* m_dst; + Handle* handle; + ElemwiseOpParamN<3> ternary_elparam; + ElemwiseOpParamN<2> binary_elparam; + ElemwiseOpParamN<1> unary_elparam; + }; + KernParam make_kern_param(ElemwiseImpl* opr); using naive::ElemwiseForwardImpl::ElemwiseForwardImpl; void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override; + const char* get_algorithm_set_name() const { return "FALLBACK ELEMWISE"; } bool is_thread_safe() const override { return true; } }; +/*! + * \brief base class for Elemwise algo + */ +class ElemwiseImpl::AlgoBase : public detail::Algorithm { +public: + virtual bool is_available(const KernParam&) const = 0; + virtual void exec(const KernParam&) const = 0; + virtual ~AlgoBase() = default; + uint32_t type() const override { return INVALID_ALGO_TYPE; }; +}; + +//! 
fallback only support float, int32, int8 +#define DISPATCH_TYPE_FALLBACK(_case) \ + if (src0.layout.dtype == dtype::Float32{}) { \ + DISPATCH_MODE_FLOAT(_case, float, 0); \ + } else if (src0.layout.dtype == dtype::Int32{}) { \ + DISPATCH_MODE_INT(_case, int, 2); \ + } else if (src0.layout.dtype == dtype::Int8{}) { \ + DISPATCH_MODE_INT(_case, dt_int8, 4); \ + } + } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/elemwise_helper/kimpl/abs.h b/dnn/src/fallback/elemwise_helper/kimpl/abs.h new file mode 100644 index 00000000..8bec33b0 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/abs.h @@ -0,0 +1,78 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/abs.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct AbsOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : (-src); } +}; + +template +struct AbsOp; + +#define OP(_ctype, _gi_type, _func_suffix, _simd_width) \ + template <> \ + struct AbsOp<_ctype> : AbsOpBase<_ctype> { \ + using AbsOpBase::AbsOpBase; \ + using AbsOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _gi_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _gi_type operator()(const _gi_type& src) const { \ + auto vitem0 = GiAbs##_func_suffix(src.val[0]); \ + auto vitem1 = GiAbs##_func_suffix(src.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(dt_float32)) +OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32)) +OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8)) +#undef OP + +template <> +struct AbsOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint8& src, dt_qint8* dst) const { + *dst = operator()(src); + } + dt_qint8 operator()(const dt_qint8& src) const { + float fsrc = src.as_int8() * this->scale; + fsrc = fsrc > 0 ? 
fsrc : -fsrc; + return QConverter::convert(fsrc); + } +}; + +template <> +struct AbsOp : AbsOpBase { + using AbsOpBase::AbsOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using AbsOpBase::operator(); + void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { + OPERATOR_UNARY_QINT8_FALLBACK; + } + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); + auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); + vitem0 = GiAbsFloat32(vitem0); + vitem1 = GiAbsFloat32(vitem1); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/add.h b/dnn/src/fallback/elemwise_helper/kimpl/add.h new file mode 100644 index 00000000..b0acd10c --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/add.h @@ -0,0 +1,134 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/add.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 + src1; + } +}; + +template +struct AddOp; + +#define OP(_ctype, _gi_type, _gi_type2, _func_suffix, _simd_width) \ + template <> \ + struct AddOp<_ctype> : AddOpBase<_ctype> { \ + using AddOpBase::AddOpBase; \ + using AddOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _gi_type2& src0, const _gi_type2& src1, dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _gi_type2 operator()(const _gi_type2& src0, const _gi_type2& src1) const { \ + auto vitem0 = GiAdd##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = GiAdd##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _gi_type& src0, const _gi_type& src1, dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _gi_type operator()(const _gi_type& src0, const _gi_type& src1) const { \ + return GiAdd##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, + GI_SIMD_LEN_BYTE / sizeof(dt_float32)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8)) +#undef OP + +template <> +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert( + src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1); + } +}; + +template <> +struct AddOp : AddOpBase { + using AddOpBase::AddOpBase; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); + using AddOpBase::operator(); + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + + GI_INT8_t 
operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct AddOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + return QConverter::convert( + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1); + } +}; + +template <> +struct AddOp : AddOpBase { + using AddOpBase::AddOpBase; + using AddOpBase::operator(); + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); + + void operator()( + const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, + dt_qint8* dst) const { + GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/exp.h b/dnn/src/fallback/elemwise_helper/kimpl/exp.h new file mode 100644 index 00000000..4a6291dc --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/exp.h @@ -0,0 +1,49 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/exp.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct ExpOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmp = src; + return exp(tmp); + } +}; + +template +struct ExpOp; + +#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ + template <> \ + struct ExpOp<_ctype> : ExpOpBase<_ctype> { \ + using ExpOpBase::ExpOpBase; \ + using ExpOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _simd_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type operator()(const _simd_type& src) const { \ + auto vitem0 = GiExpPs##_func_suffix(src.val[0]); \ + auto vitem1 = GiExpPs##_func_suffix(src.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h b/dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h new file mode 100644 index 00000000..9fb0f37e --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h @@ -0,0 
+1,70 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +//! tanh = x * (27 + x^2) / (27 + 9 * x^2) +template +struct FastTanhOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float x = src; + return x * (27.f + x * x) / (27.f + 9.f * x * x); + } +}; + +template +struct FastTanhOp; + +#define OP(_ctype, _simd_type, _func_suffix, _fix_func_suffix, _simd_width) \ + template <> \ + struct FastTanhOp<_ctype> : FastTanhOpBase<_ctype> { \ + using FastTanhOpBase::FastTanhOpBase; \ + using FastTanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _simd_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type operator()(const _simd_type& src) const { \ + auto val_27 = GiBroadcast##_func_suffix(27.f); \ + auto val_9 = GiBroadcast##_func_suffix(9.f); \ + auto valx = src.val[0]; \ + auto valx1 = src.val[1]; \ + auto valxp2 = GiMultiply##_fix_func_suffix(valx, valx); \ + auto valx1p2 = GiMultiply##_fix_func_suffix(valx1, valx1); \ + auto denominator = GiAdd##_fix_func_suffix(valxp2, val_27); \ + auto denominator1 = GiAdd##_fix_func_suffix(valx1p2, val_27); \ + valx = GiMultiply##_fix_func_suffix(valx, denominator); \ + valx1 = GiMultiply##_fix_func_suffix(valx1, denominator1); \ + denominator = GiMultiplyAdd##_fix_func_suffix(val_27, valxp2, val_9); \ + denominator1 = GiMultiplyAdd##_fix_func_suffix(val_27, valx1p2, val_9); \ + auto r_denominator = GiRecpe##_func_suffix(denominator); \ + auto r_denominator1 = GiRecpe##_func_suffix(denominator1); \ + r_denominator = GiMultiply##_fix_func_suffix( \ + GiRecpeS##_func_suffix(denominator, r_denominator), \ + r_denominator); \ + r_denominator1 = GiMultiply##_fix_func_suffix( \ + GiRecpeS##_func_suffix(denominator1, r_denominator1), \ + r_denominator1); \ + valx = GiMultiply##_fix_func_suffix(valx, r_denominator); \ + valx1 = GiMultiply##_fix_func_suffix(valx1, r_denominator1); \ + return {{valx, valx1}}; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_V2_t, Float32, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h new file mode 100644 index 00000000..70a7360a --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h @@ -0,0 +1,118 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h" +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct FuseAddHSwishOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float tmp = src0 + src1; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + return tmp; + } +}; + +template +struct FuseAddHSwishOp; + +#define OP(_ctype, 
_simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \ + using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \ + using FuseAddHSwishOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = GiAdd##_func_suffix(val1, val3); \ + val2 = GiAdd##_func_suffix(val2, val4); \ + H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \ + return {{val1, val2}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + auto val1 = src0; \ + auto val2 = src1; \ + val1 = GiAdd##_func_suffix(val1, val2); \ + H_SWISH_KERN_N1_FALLBACK(_func_suffix, val1); \ + return val1; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +template <> +struct FuseAddHSwishOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + float tmp = + src0.as_int32() * this->scale_src0 + src1.as_int32() * this->scale_src1; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + tmp *= this->scale_dst; + return QConverter::convert(tmp); + } +}; + +template <> +struct FuseAddHSwishOp : FuseAddHSwishOpBase { + using FuseAddHSwishOpBase::FuseAddHSwishOpBase; + using FuseAddHSwishOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + void operator()( + const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, + dt_qint8* dst) const { + GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + GI_FLOAT32_t vitem0, vitem1; + + vitem0 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1)); + vitem1 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1)); + H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1); + vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst); + vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +#include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h" + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h new file mode 100644 index 00000000..87ebb540 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h @@ -0,0 +1,162 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h + */ +#pragma once + +#include 
"gi_util_impl_helper.h" +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct FuseAddReluOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + auto tmp = src0 + src1; + return tmp > 0 ? tmp : 0; + } +}; + +template +struct FuseAddReluOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \ + using FuseAddReluOpBase::FuseAddReluOpBase; \ + using FuseAddReluOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, _func_suffix); \ + return {{val1, val2}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + auto val1 = src0; \ + auto val2 = src1; \ + FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, _func_suffix); \ + return val1; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template +struct FuseAddReluOpCommon; + +template <> +struct FuseAddReluOpCommon { + inline static GI_FLOAT32_t vzero() { return GiBroadcastFloat32(0); } +}; + +template <> +struct FuseAddReluOpCommon { + inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); } +}; + +template <> +struct FuseAddReluOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert(std::max( + src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, 0.f)); + } +}; + +template <> +struct FuseAddReluOp : FuseAddReluOpBase, + FuseAddReluOpCommon { + using FuseAddReluOpBase::FuseAddReluOpBase; + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + + vitem0 = 
GiMaximumFloat32(vitem0, this->vzero()); + vitem1 = GiMaximumFloat32(vitem1, this->vzero()); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct FuseAddReluOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { + return QConverter::convert(std::max( + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, 0.f)); + } +}; + +template <> +struct FuseAddReluOp : FuseAddReluOpBase, + FuseAddReluOpCommon { + using FuseAddReluOpBase::FuseAddReluOpBase; + using FuseAddReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + void operator()( + const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, + dt_qint8* dst) const { + GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiAddFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + + vitem0 = GiMaximumFloat32(vitem0, this->vzero()); + vitem1 = GiMaximumFloat32(vitem1, this->vzero()); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h new file mode 100644 index 00000000..27924bdf --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h @@ -0,0 +1,60 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct FuseAddSigmoidOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float tmpf = src0 + src1; + tmpf = exp(-tmpf); + tmpf = 1.f / (1.f + tmpf); + return tmpf; + } +}; + +template +struct FuseAddSigmoidOp; + +#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \ + using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \ + using FuseAddSigmoidOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = GiAdd##_func_suffix(val1, val3); \ + val2 = GiAdd##_func_suffix(val2, val4); \ + val1 = GiSigmoidPs##_func_suffix(val1); \ + val2 = GiSigmoidPs##_func_suffix(val2); \ + return {{val1, val2}}; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / 
sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h new file mode 100644 index 00000000..3efe2b09 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h @@ -0,0 +1,77 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct FuseAddTanhOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float tmpf = exp(src0 + (src1)); + float tmpf2 = 1 / tmpf; + return (tmpf - tmpf2) / (tmpf + tmpf2); + } +}; + +template +struct FuseAddTanhOp; + +#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \ + using FuseAddTanhOpBase::FuseAddTanhOpBase; \ + using FuseAddTanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = GiAdd##_func_suffix(val1, val3); \ + val2 = GiAdd##_func_suffix(val2, val4); \ + auto exp1 = GiExpPs##_func_suffix(val1); \ + auto exp2 = GiExpPs##_func_suffix(val2); \ + auto rexp1 = GiRecpe##_func_suffix(exp1); \ + auto rexp2 = GiRecpe##_func_suffix(exp2); \ + rexp1 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \ + rexp2 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \ + val1 = GiSubtract##_func_suffix(exp1, rexp1); \ + val2 = GiSubtract##_func_suffix(exp2, rexp2); \ + exp1 = GiAdd##_func_suffix(exp1, rexp1); \ + exp2 = GiAdd##_func_suffix(exp2, rexp2); \ + rexp1 = GiRecpe##_func_suffix(exp1); \ + rexp2 = GiRecpe##_func_suffix(exp2); \ + rexp1 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \ + rexp2 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \ + val1 = GiMultiply##_func_suffix(val1, rexp1); \ + val2 = GiMultiply##_func_suffix(val2, rexp2); \ + return {{val1, val2}}; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h b/dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h new file mode 100644 index 00000000..348d10ab --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h @@ -0,0 +1,60 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct FuseMulAdd3OpBase : TernaryOpBase { + using TernaryOpBase::TernaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, 
const src_ctype src2, + dst_ctype* dst) const { + *dst = operator()(src0, src1, src2); + } + + dst_ctype operator()( + const src_ctype& src0, const src_ctype& src1, const src_ctype& src2) const { + return (src0 * src1) + src2; + } +}; + +template +struct FuseMulAdd3Op; + +#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \ + using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \ + using FuseMulAdd3OpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + const _simd_type& src2, dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1, src2); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + const _simd_type& src2) const { \ + auto vitem0 = GiMultiplyAdd##_func_suffix( \ + src2.val[0], src0.val[0], src1.val[0]); \ + auto vitem1 = GiMultiplyAdd##_func_suffix( \ + src2.val[1], src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/gi_util_impl_helper.h b/dnn/src/fallback/elemwise_helper/kimpl/gi_util_impl_helper.h new file mode 100644 index 00000000..53a5a38b --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/gi_util_impl_helper.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/fallback/elemwise/gi_impl/gi_util_impl_helper.h + */ + +#pragma once + +/*! + * \brief compute fuse_add_relu on two simd packs + * + * Compute + * + * val1 = fuse_add_relu(val1, val3) + * val2 = fuse_add_relu(val2, val4) + * + * This algorithm handles int overflow. + */ +#define FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, func_suffix) \ + do { \ + val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val3)); \ + val2 = GiMaximum##func_suffix(val2, GiNeg##func_suffix(val4)); \ + val1 = GiAdd##func_suffix(val1, val3); \ + val2 = GiAdd##func_suffix(val2, val4); \ + } while (0) + +#define FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, func_suffix) \ + do { \ + val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val2)); \ + val1 = GiAdd##func_suffix(val1, val2); \ + } while (0) + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/hswish.h b/dnn/src/fallback/elemwise_helper/kimpl/hswish.h new file mode 100644 index 00000000..2dfcfc2b --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/hswish.h @@ -0,0 +1,108 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/hswish.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h" +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct HSwishOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmp = src; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + return (tmp); + } +}; + +//! 
h_swish(x) = x * clip(x + 3, 0, 6) / 6 +template +struct HSwishOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \ + using HSwishOpBase::HSwishOpBase; \ + using HSwishOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _simd_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + void operator()(const _simd_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type2 operator()(const _simd_type2& src) const { \ + auto val1 = src.val[0]; \ + auto val2 = src.val[1]; \ + H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \ + return {{val1, val2}}; \ + } \ + _simd_type operator()(const _simd_type& src) const { \ + auto val_zero = GiBroadcast##_func_suffix(0.f); \ + auto val_six = GiBroadcast##_func_suffix(6.f); \ + auto val_three = GiBroadcast##_func_suffix(3.f); \ + auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \ + auto clip1 = GiMaximum##_func_suffix( \ + GiMinimum##_func_suffix( \ + GiAdd##_func_suffix(src, val_three), val_six), \ + val_zero); \ + return GiMultiply##_func_suffix( \ + GiMultiply##_func_suffix(src, clip1), val_rec_six); \ + } \ + }; + +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +template <> +struct HSwishOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_qint8* dst) const { + *dst = operator()(src); + } + + dt_qint8 operator()(const dt_qint32& src) const { + float tmp = src.as_int32() * this->scale_src; + tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; + tmp *= this->scale_dst; + return QConverter::convert(tmp); + } +}; + +template <> +struct HSwishOp : HSwishOpBase { + using HSwishOpBase::HSwishOpBase; + using HSwishOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { + GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc)); + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src); + auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src); + + H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1); + vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst); + vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst); + + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +#include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h" +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h b/dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h new file mode 100644 index 00000000..eb709b48 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h @@ -0,0 +1,7 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h + */ + +#undef H_SWISH_KERN_FALLBACK + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h b/dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h new file mode 100644 index 00000000..d39368da --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h @@ -0,0 +1,39 @@ +/** + * \file 
dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h + */ + +#define H_SWISH_KERN_FALLBACK(_func_suffix, _val1, _val2) \ + do { \ + auto val_zero = GiBroadcast##_func_suffix(0.f); \ + auto val_six = GiBroadcast##_func_suffix(6.f); \ + auto val_three = GiBroadcast##_func_suffix(3.f); \ + auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \ + auto clip1 = GiMaximum##_func_suffix( \ + GiMinimum##_func_suffix( \ + GiAdd##_func_suffix(_val1, val_three), val_six), \ + val_zero); \ + auto clip2 = GiMaximum##_func_suffix( \ + GiMinimum##_func_suffix( \ + GiAdd##_func_suffix(_val2, val_three), val_six), \ + val_zero); \ + _val1 = GiMultiply##_func_suffix( \ + GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \ + _val2 = GiMultiply##_func_suffix( \ + GiMultiply##_func_suffix(_val2, clip2), val_rec_six); \ + } while (0); + +#define H_SWISH_KERN_N1_FALLBACK(_func_suffix, _val1) \ + do { \ + auto val_zero = GiBroadcast##_func_suffix(0.f); \ + auto val_six = GiBroadcast##_func_suffix(6.f); \ + auto val_three = GiBroadcast##_func_suffix(3.f); \ + auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \ + auto clip1 = GiMaximum##_func_suffix( \ + GiMinimum##_func_suffix( \ + GiAdd##_func_suffix(_val1, val_three), val_six), \ + val_zero); \ + _val1 = GiMultiply##_func_suffix( \ + GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \ + } while (0); + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/max.h b/dnn/src/fallback/elemwise_helper/kimpl/max.h new file mode 100644 index 00000000..31b38641 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/max.h @@ -0,0 +1,102 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/max.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { +template +struct MaxOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 > src1 ? 
src0 : src1; + } +}; + +template +struct MaxOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct MaxOp<_ctype> : MaxOpBase<_ctype> { \ + using MaxOpBase::MaxOpBase; \ + using MaxOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto vitem0 = GiMaximum##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = GiMaximum##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + return GiMaximum##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template <> +struct MaxOpBase : BinaryOpBase { + using src_ctype = dt_qint8; + using dst_ctype = dt_qint8; + using BinaryOpBase::BinaryOpBase; + + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + float fsrc0 = src0.as_int8() * this->scale0; + float fsrc1 = src1.as_int8() * this->scale1; + return QConverter::convert(fsrc0 > fsrc1 ? fsrc0 : fsrc1); + } +}; + +template <> +struct MaxOp : MaxOpBase { + using MaxOpBase::MaxOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MaxOpBase::operator(); + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiMaximumFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiMaximumFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/min.h b/dnn/src/fallback/elemwise_helper/kimpl/min.h new file mode 100644 index 00000000..598fce33 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/min.h @@ -0,0 +1,99 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/min.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct MinOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 < src1 ? 
src0 : src1; + } +}; + +template +struct MinOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct MinOp<_ctype> : MinOpBase<_ctype> { \ + using MinOpBase::MinOpBase; \ + using MinOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto vitem0 = GiMinimum##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = GiMinimum##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + return GiMinimum##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template <> +struct MinOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + float fsrc0 = src0.as_int8() * this->scale0; + float fsrc1 = src1.as_int8() * this->scale1; + return QConverter::convert(fsrc0 < fsrc1 ? 
fsrc0 : fsrc1); + } +}; + +template <> +struct MinOp : MinOpBase { + using MinOpBase::MinOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MinOpBase::operator(); + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiMinimumFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiMinimumFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/mul.h b/dnn/src/fallback/elemwise_helper/kimpl/mul.h new file mode 100644 index 00000000..24da646a --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/mul.h @@ -0,0 +1,99 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/mul.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct MulOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 * src1; + } +}; + +template +struct MulOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct MulOp<_ctype> : MulOpBase<_ctype> { \ + using MulOpBase::MulOpBase; \ + using MulOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto vitem0 = GiMultiply##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = GiMultiply##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + return GiMultiply##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template <> +struct MulOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert( + src0.as_int8() * scale_src0 * src1.as_int8() * scale1); + } +}; + +template <> +struct MulOp : MulOpBase { + using MulOpBase::MulOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using MulOpBase::operator(); + + void operator()( 
+ const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiMultiplyFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiMultiplyFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/none.h b/dnn/src/fallback/elemwise_helper/kimpl/none.h new file mode 100644 index 00000000..8ece32a3 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/none.h @@ -0,0 +1,77 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/none.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + dst_ctype operator()(const src_ctype& src) const { return src; } +}; + +template +struct NoneOp; +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ + NoneOp(){}; \ + NoneOp(float, float){}; \ + using NoneOpBase::NoneOpBase; \ + using NoneOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + _simd_type2 operator()(const _simd_type2& src) const { return src; } \ + void operator()(const _simd_type2& src, _ctype* dst) const { \ + GiStore##_func_suffix(dst, src.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \ + } \ + void operator()(const _simd_type& src, _ctype* dst) const { \ + GiStore##_func_suffix(dst, src); \ + } \ + _simd_type operator()(const _simd_type& src) const { return src; } \ + }; + +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template <> +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint8& src, dt_qint8* dst) const { *dst = src; } +}; + +template <> +struct NoneOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_qint8* dst) const { + *(reinterpret_cast(dst)) = src; + } +}; + +#pragma GCC diagnostic ignored "-Waddress-of-packed-member" + +template <> +struct NoneOp : NoneOpBase { + using NoneOpBase::NoneOpBase; + using NoneOpBase::operator(); + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); + + void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { + GiStoreInt32(reinterpret_cast(dst), vsrc.val[0]); + GiStoreInt32(reinterpret_cast(dst + 16), vsrc.val[1]); + } + void operator()(const GI_INT32_t& src, dt_qint8* dst) const { + GiStoreInt32(reinterpret_cast(dst), src); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/op_base.h b/dnn/src/fallback/elemwise_helper/kimpl/op_base.h new file mode 100644 index 00000000..5affdd1a --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/op_base.h @@ -0,0 +1,450 @@ +/** + * \file 
dnn/src/fallback/elemwise_helper/kimpl/op_base.h + */ +#pragma once + +#include +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "src/common/utils.h" +#include "src/fallback/elemwise/gi_impl/gi_mathfun.h" +#include "src/fallback/quantized_converter.h" + +#include "src/fallback/general_intrinsic/gi_float.h" +#include "src/fallback/general_intrinsic/gi_int.h" + +namespace megdnn { +namespace fallback { + +////////////////////////// unary ////////////////////////// +template +struct OpBase { + using src_ctype = _src_ctype; + using dst_ctype = _dst_ctype; + OpBase() = default; +}; + +template +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + UnaryOpBase() = default; + UnaryOpBase(DType /*src_dtype*/, DType /*dst_dtype*/) {} +}; + +#define OPERATOR_UNARY_QINT8_FALLBACK \ + GI_INT16_t vsrct0 = GiMoveLowLongInt8(vsrc.val[0]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst), operator()( \ + {{GiMoveLowLongInt16(vsrct0), \ + GiMoveHighLongInt16(vsrct0)}})); \ + GI_INT16_t vsrct1 = GiMoveHighLongInt8(vsrc.val[0]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 8), \ + operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \ + GI_INT16_t vsrct2 = GiMoveHighLongInt8(vsrc.val[1]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 16), \ + operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ + GI_INT16_t vsrct3 = GiMoveHighLongInt8(vsrc.val[1]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 24), \ + operator()({{GiMoveLowLongInt16(vsrct3), GiMoveHighLongInt16(vsrct3)}})) + +//! scale_src = src.scale; scale_dst = 1.f / dst.scale (div -> mul) +//! scale = src.scale / dst.scale +template <> +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + float scale_src, scale_dst; + GI_FLOAT32_t vscale_src, vscale_dst; + float scale; + GI_FLOAT32_t vscale; + + void init(float src_scale, float dst_scale) { + scale_src = src_scale; + vscale_src = GiBroadcastFloat32(scale_src); + scale_dst = 1.f / dst_scale; + vscale_dst = GiBroadcastFloat32(scale_dst); + scale = src_scale / dst_scale; + vscale = GiBroadcastFloat32(scale); + } + + UnaryOpBase(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src_scale, dst_scale); + } + UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); } +}; + +template <> +struct UnaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint32; + using dst_ctype = dt_qint8; + float scale; + GI_FLOAT32_t vscale; + float scale_src, scale_dst; + GI_FLOAT32_t vscale_src, vscale_dst; + + void init(float src_scale, float dst_scale) { + scale_src = src_scale; + vscale_src = GiBroadcastFloat32(src_scale); + scale_dst = 1 / dst_scale; + vscale_dst = GiBroadcastFloat32(scale_dst); + scale = src_scale / dst_scale; + vscale = GiBroadcastFloat32(scale); + } + + UnaryOpBase(DType src_dtype, DType dst_dtype) { + float src_scale = src_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src_scale, dst_scale); + } + + UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); } +}; + +////////////////////////// binary ////////////////////////// +template +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + BinaryOpBase() = default; + BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*dst_dtype*/) {} +}; + +/* ================= binary op for quantized types ================== */ + +#define OPERATOR_BINARY_QINT8_FALLBACK \ + GI_INT16_t vsrct0_0 = GiMoveLowLongInt8(vsrc0.val[0]); \ + 
GI_INT16_t vsrct1_0 = GiMoveLowLongInt8(vsrc1.val[0]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0_0), GiMoveHighLongInt16(vsrct0_0)}}, \ + {{GiMoveLowLongInt16(vsrct1_0), GiMoveHighLongInt16(vsrct1_0)}})); \ + GI_INT16_t vsrct0_1 = GiMoveHighLongInt8(vsrc0.val[0]); \ + GI_INT16_t vsrct1_1 = GiMoveHighLongInt8(vsrc1.val[0]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 8), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0_1), GiMoveHighLongInt16(vsrct0_1)}}, \ + {{GiMoveLowLongInt16(vsrct1_1), GiMoveHighLongInt16(vsrct1_1)}})); \ + GI_INT16_t vsrct0_2 = GiMoveLowLongInt8(vsrc0.val[1]); \ + GI_INT16_t vsrct1_2 = GiMoveLowLongInt8(vsrc1.val[1]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 16), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0_2), GiMoveHighLongInt16(vsrct0_2)}}, \ + {{GiMoveLowLongInt16(vsrct1_2), GiMoveHighLongInt16(vsrct1_2)}})); \ + GI_INT16_t vsrct0_3 = GiMoveHighLongInt8(vsrc0.val[1]); \ + GI_INT16_t vsrct1_3 = GiMoveHighLongInt8(vsrc1.val[1]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 24), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0_3), GiMoveHighLongInt16(vsrct0_3)}}, \ + {{GiMoveLowLongInt16(vsrct1_3), GiMoveHighLongInt16(vsrct1_3)}})) + +//! scale_src0 = src0.scale; scale_src1 = src1.scale; scale_dst = 1.f / +//! dst.scale scale0 = src0.scale / dst.scale; scale1 = src1.scale / dst.scale +template <> +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint8; + using dst_ctype = dt_qint8; + float scale_src0, scale_src1, scale_dst; + GI_FLOAT32_t vscale_src0, vscale_src1, vscale_dst; + float scale0, scale1; + GI_FLOAT32_t vscale0, vscale1; + + void init(float src0_scale, float src1_scale, float dst_scale) { + scale_src0 = src0_scale; + vscale_src0 = GiBroadcastFloat32(scale_src0); + scale_src1 = src1_scale; + vscale_src1 = GiBroadcastFloat32(scale_src1); + scale_dst = 1.f / dst_scale; + vscale_dst = GiBroadcastFloat32(scale_dst); + scale0 = src0_scale / dst_scale; + vscale0 = GiBroadcastFloat32(scale0); + scale1 = src1_scale / dst_scale; + vscale1 = GiBroadcastFloat32(scale1); + } + + BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src0_scale, src1_scale, dst_scale); + } + + BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) { + init(src0_scale, src1_scale, dst_scale); + } +}; + +template <> +struct BinaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint32; + using dst_ctype = dt_qint8; + float scale0, scale1; + GI_FLOAT32_t vscale0, vscale1; + float scale_src0, scale_src1, scale_dst; + GI_FLOAT32_t vscale_src0, vscale_src1, vscale_dst; + + void init(float src0_scale, float src1_scale, float dst_scale) { + scale_src0 = src0_scale; + vscale_src0 = GiBroadcastFloat32(src0_scale); + scale_src1 = src1_scale; + vscale_src1 = GiBroadcastFloat32(src1_scale); + scale_dst = 1 / dst_scale; + vscale_dst = GiBroadcastFloat32(scale_dst); + scale0 = src0_scale / dst_scale; + vscale0 = GiBroadcastFloat32(scale0); + scale1 = src1_scale / dst_scale; + vscale1 = GiBroadcastFloat32(scale1); + } + + BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src0_scale, src1_scale, dst_scale); + } + + BinaryOpBase(float src0_scale, float src1_scale, 
float dst_scale) { + init(src0_scale, src1_scale, dst_scale); + } +}; + +////////////////////////// ternary ////////////////////////// +template +struct TernaryOpBase : OpBase { + using OpBase::OpBase; + TernaryOpBase() = default; + TernaryOpBase( + DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*src2_dtype*/, + DType /*dst_dtype*/) {} +}; + +#define OPERATOR_TERNARY_QINT8_FALLBACK \ + GI_INT16_t vsrct0 = GiMoveLowLongInt8(vsrc0.val[0]); \ + GI_INT16_t vsrct1 = GiMoveLowLongInt8(vsrc1.val[0]); \ + GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc2.val[0]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ + {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ + {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ + vsrct0 = GiMoveHighLongInt8(vsrc0.val[0]); \ + vsrct1 = GiMoveHighLongInt8(vsrc1.val[0]); \ + vsrct2 = GiMoveHighLongInt8(vsrc2.val[0]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 8), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ + {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ + {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ + vsrct0 = GiMoveLowLongInt8(vsrc0.val[1]); \ + vsrct1 = GiMoveLowLongInt8(vsrc1.val[1]); \ + vsrct2 = GiMoveLowLongInt8(vsrc2.val[1]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 16), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ + {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ + {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ + vsrct0 = GiMoveHighLongInt8(vsrc0.val[1]); \ + vsrct1 = GiMoveHighLongInt8(vsrc1.val[1]); \ + vsrct2 = GiMoveHighLongInt8(vsrc2.val[1]); \ + GiStoreLowInt8( \ + reinterpret_cast(dst + 24), \ + operator()( \ + {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ + {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ + {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})) + +/*========================= ternaty op for quanzited ====================*/ +template <> +struct TernaryOpBase : OpBase { + using OpBase::OpBase; + using src_ctype = dt_qint8; + using dst_ctype = dt_qint8; + float scale_src0, scale_src1, scale_src2, scale_dst; + GI_FLOAT32_t vscale_src0, vscale_src1, vscale_src2, vscale_dst; + float scale0, scale1, scale2; + GI_FLOAT32_t vscale0, vscale1, vscale2; + void init(float src0_scale, float src1_scale, float src2_scale, float dst_scale) { + scale_src0 = src0_scale; + scale_src1 = src1_scale; + scale_src2 = src2_scale; + scale_dst = 1.f / dst_scale; + vscale_src0 = GiBroadcastFloat32(scale_src0); + vscale_src1 = GiBroadcastFloat32(scale_src1); + vscale_src2 = GiBroadcastFloat32(scale_src2); + vscale_dst = GiBroadcastFloat32(scale_dst); + scale0 = src0_scale / dst_scale; + scale1 = src1_scale / dst_scale; + scale2 = src2_scale / dst_scale; + vscale0 = GiBroadcastFloat32(scale0); + vscale1 = GiBroadcastFloat32(scale1); + vscale2 = GiBroadcastFloat32(scale2); + } + TernaryOpBase( + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype) { + float src0_scale = src0_dtype.param().scale; + float src1_scale = src1_dtype.param().scale; + float src2_scale = src2_dtype.param().scale; + float dst_scale = dst_dtype.param().scale; + init(src0_scale, src1_scale, src2_scale, dst_scale); + } + TernaryOpBase( + float src0_scale, float src1_scale, float src2_scale, float dst_scale) { + init(src0_scale, src1_scale, src2_scale, dst_scale); + } +}; + 
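The FixupBase helper that follows turns a float requantization scale smaller than 0.5 into an integer multiplier/shift pair, so the older ARMv7 paths can requantize with a saturating rounding doubling multiply-high plus a rounding right shift instead of float math. The scalar sketch below only illustrates what that pair encodes; the function name is invented for this note, and the saturation performed by the real instructions is omitted.

// With shift = ceil(log2(0.5 / scale)) and multiplier = round(scale * 2^shift * 2^31),
// as computed by FixupBase, x * scale is approximated as follows. Note that
// shift >= 1 whenever scale < 0.5, so the rounding term below is well defined.
static inline int32_t apply_fixup_sketch(int32_t x, int32_t multiplier, int shift) {
    // round(x * multiplier / 2^31): what a saturating rounding doubling
    // multiply-high computes, with saturation left out of this sketch.
    int64_t hi = (static_cast<int64_t>(x) * multiplier + (1LL << 30)) >> 31;
    // rounding right shift by `shift`, matching the negative vshift above.
    return static_cast<int32_t>((hi + (1LL << (shift - 1))) >> shift);
}
// Example: scale = 0.1 gives shift = 3 and multiplier = round(0.8 * 2^31),
// so apply_fixup_sketch(1000, 1717986918, 3) evaluates to 100, i.e. 1000 * 0.1.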
+////////////////////////// fixup ////////////////////////// +struct FixupBase { + GI_INT32_t vmultiplier, vshift; + FixupBase(float scale) { + //! ignore Fixup if scale >= 0.5, using typecvt instead of shift & + //! multiplier, as it may introduce errors. + if (scale >= 0.5) + return; + + int shift = static_cast(::ceilf(::log2f(0.5 / scale))); + scale *= ::powf(2, shift); + //! Using double can get full precision here, but it can be ignored. + vmultiplier = GiBroadcastInt32( + std::round(static_cast(scale) * ((2LL) << 30))); + vshift = GiBroadcastInt32(-shift); + } +}; + +//////////////////////// quantization common //////////////////// +template +struct UnaryQuantizationOp; + +template +struct UnaryQuantizationOp : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_qint8& src, dt_qint8* dst) const { + *dst = operator()(src); + } + + dt_qint8 operator()(const dt_qint8& src) const { + float fsrc = src.as_int8() * this->scale_src; + fsrc = op(fsrc); + fsrc = fsrc * this->scale_dst; + return QConverter::convert(fsrc); + } + + void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { + OPERATOR_UNARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src); + auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src); + auto val = this->op({{vitem0, vitem1}}); + val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); + val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); + return QConverter::convert(val); + } +}; + +template +struct BinaryQuantizationOp; + +template +struct BinaryQuantizationOp : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + float fsrc0 = src0.as_int8() * this->scale_src0; + float fsrc1 = src1.as_int8() * this->scale_src1; + float fdst = op(fsrc0, fsrc1); + fdst = fdst * this->scale_dst; + return QConverter::convert(fdst); + } + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto val0 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0); + auto val1 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0); + auto val2 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1); + auto val3 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1); + auto val = op({{val0, val1}}, {{val2, val3}}); + val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); + val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); + return QConverter::convert(val); + } +}; + +template +struct TernaryQuantizationOp; + +template +struct TernaryQuantizationOp + : TernaryOpBase { + using TernaryOpBase::TernaryOpBase; + constexpr static size_t SIMD_WIDTH = 16; + Op op; + + void operator()( + const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2, + dt_qint8* dst) const { + *dst = operator()(src0, src1, src2); + } + + dt_qint8 operator()( + const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2) const { + float fsrc0 = src0.as_int8() * this->scale_src0; + float fsrc1 = 
src1.as_int8() * this->scale_src1; + float fsrc2 = src2.as_int8() * this->scale_src2; + float fdst = op(fsrc0, fsrc1, fsrc2); + fdst = fdst * this->scale_dst; + return QConverter::convert(fdst); + } + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, + const GI_INT8_V2_t& vsrc2, dt_qint8* dst) const { + OPERATOR_TERNARY_QINT8_FALLBACK; + } + + GI_INT8_t operator()( + const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, + const GI_INT32_V2_t& vsrc2) const { + auto val0 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0); + auto val1 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0); + auto val2 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1); + auto val3 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1); + auto val4 = GiMultiplyFloat32(GiCastToFloat32(vsrc2.val[0]), this->vscale_src2); + auto val5 = GiMultiplyFloat32(GiCastToFloat32(vsrc2.val[1]), this->vscale_src2); + auto val = op({{val0, val1}}, {{val2, val3}}, {{val4, val5}}); + val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); + val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); + return QConverter::convert(val); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/pow.h b/dnn/src/fallback/elemwise_helper/kimpl/pow.h new file mode 100644 index 00000000..3402db4f --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/pow.h @@ -0,0 +1,28 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/pow.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +/////////////////////// POW float only //////////////////////////// +template +struct PowOp : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + constexpr static size_t SIMD_WIDTH = 1; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return powf(src0, src1); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/relu.h b/dnn/src/fallback/elemwise_helper/kimpl/relu.h new file mode 100644 index 00000000..7b8365a1 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/relu.h @@ -0,0 +1,188 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/relu.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { return src > 0 ? 
src : 0; } +}; + +template +struct ReluOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct ReluOp<_ctype> : ReluOpBase<_ctype> { \ + using ReluOpBase::ReluOpBase; \ + using ReluOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _simd_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()(const _simd_type2& src) const { \ + auto vzero = GiBroadcast##_func_suffix(0); \ + auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero); \ + auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()(const _simd_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src) const { \ + auto vzero = GiBroadcast##_func_suffix(0); \ + return GiMaximum##_func_suffix(src, vzero); \ + } \ + }; + +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template <> +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint8& src, dt_qint8* dst) const { + *dst = operator()(src); + } + dt_qint8 operator()(const dt_qint8& src) const { + float fsrc = src.as_int8() * this->scale; + fsrc = std::max(fsrc, 0.f); + return QConverter::convert(fsrc); + } +}; + +template <> +struct ReluOp : ReluOpBase { + using ReluOpBase::ReluOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using ReluOpBase::operator(); + + void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { + OPERATOR_UNARY_QINT8_FALLBACK; + } + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { + auto vzero = GiBroadcastFloat32(0.f); + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); + auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); + vitem0 = GiMaximumFloat32(vitem0, vzero); + vitem1 = GiMaximumFloat32(vitem1, vzero); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +template <> +struct ReluOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const dt_qint32& src, dt_qint8* dst) const { + *dst = operator()(src); + } + + dt_qint8 operator()(const dt_qint32& src) const { + float fsrc = src.as_int32() * this->scale; + fsrc = std::max(fsrc, 0.f); + return QConverter::convert(fsrc); + } +}; + +//! 
if old armv7, special define relu with fixup +#if defined(__ARM_ARCH) && __ARM_ARCH < 8 +template <> +struct ReluOp : ReluOpBase, FixupBase { + using ReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = 4; + + ReluOp(DType src_dtype, DType dst_dtype) + : ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {} + + ReluOp(float src_scale, float dst_scale) + : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} + + void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { + vst1_s8(reinterpret_cast(dst), operator()(vsrc)); + } + + int8x8_t operator()(const int32x4x2_t& vsrc) const { + int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); + int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); + vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); + vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); + return vqmovn_s16(vcombine_s16( + vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + } + int8x8_t operator()(const float32x4_t& vsrc) const { + int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); + vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); + vitem0 = vrshlq_s32(vitem0, vshift); + int16x4_t vitem = vqmovn_s32(vitem0); + return vqmovn_s16(vcombine_s16(vitem, vitem)); + } + void operator()(const int32x4_t& src, dt_qint8* dst) const { + auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); + vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); + auto result = QConverter::convert(vitem0); + vst1_lane_s32(reinterpret_cast(dst), (int32x2_t)result, 0); + } + void operator()(const float32x4_t& src, dt_qint8* dst) const { + auto vitem0 = vmulq_f32(src, this->vscale); + vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); + auto result = QConverter::convert(vitem0); + vst1_lane_s32(reinterpret_cast(dst), (int32x2_t)result, 0); + } +}; + +#else +template <> +struct ReluOp : ReluOpBase { + using ReluOpBase::ReluOpBase; + using ReluOpBase::operator(); + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); + + void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { + GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc)); + } + void operator()(const GI_INT32_t& src, dt_qint8* dst) const { + GiStoreLane0Int32( + reinterpret_cast(dst), (GI_INT32_t)(operator()(src))); + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); + auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); + vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); + vitem1 = GiMaximumFloat32(vitem1, QConverterBase::vfzero()); + + return QConverter::convert({{vitem0, vitem1}}); + } + GI_INT8_t operator()(const GI_INT32_t& src) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale); + vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); + return QConverter::convert(vitem0); + } + GI_INT8_t operator()(const GI_FLOAT32_t& src) const { + auto vitem0 = GiMultiplyFloat32(src, this->vscale); + vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); + return QConverter::convert(vitem0); + } +}; + +#endif + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h b/dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h new file mode 100644 index 00000000..81930dae --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h @@ -0,0 +1,56 @@ +/** + * \file 
dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct SigmoidOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmpf = src; + tmpf = exp(-tmpf); + tmpf = 1.f / (1.f + tmpf); + return tmpf; + } +}; + +template +struct SigmoidOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ + using SigmoidOpBase::SigmoidOpBase; \ + using SigmoidOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _simd_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + void operator()(const _simd_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type2 operator()(const _simd_type2& src) const { \ + return {{operator()(src.val[0]), operator()(src.val[1])}}; \ + } \ + _simd_type operator()(const _simd_type& src) const { \ + return GiSigmoidPs##_func_suffix(src); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/sub.h b/dnn/src/fallback/elemwise_helper/kimpl/sub.h new file mode 100644 index 00000000..89ba3546 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/sub.h @@ -0,0 +1,97 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/sub.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct SubOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 - src1; + } +}; + +template +struct SubOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct SubOp<_ctype> : SubOpBase<_ctype> { \ + using SubOpBase::SubOpBase; \ + using SubOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto vitem0 = GiSubtract##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = GiSubtract##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + return GiSubtract##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +OP(dt_int32, 
GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) +OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) +#undef OP + +template <> +struct SubOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { + *dst = operator()(src0, src1); + } + dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { + return QConverter::convert( + src0.as_int8() * scale0 - src1.as_int8() * scale1); + } +}; + +template <> +struct SubOp : SubOpBase { + using SubOpBase::SubOpBase; + constexpr static size_t SIMD_WIDTH = 16; + using SubOpBase::operator(); + + void operator()( + const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { + OPERATOR_BINARY_QINT8_FALLBACK; + } + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { + auto vitem0 = GiSubtractFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); + auto vitem1 = GiSubtractFloat32( + GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), + GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); + return QConverter::convert({{vitem0, vitem1}}); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/tanh.h b/dnn/src/fallback/elemwise_helper/kimpl/tanh.h new file mode 100644 index 00000000..bfbb7091 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/tanh.h @@ -0,0 +1,81 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/tanh.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +template +struct TanhOpBase : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dst_ctype operator()(const src_ctype& src) const { + float tmp = src; + return tanh(tmp); + } +}; + +template +struct TanhOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct TanhOp<_ctype> : TanhOpBase<_ctype> { \ + using TanhOpBase::TanhOpBase; \ + using TanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _simd_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()(const _simd_type2& src) const { \ + auto one_val = GiBroadcast##_func_suffix(1.f); \ + auto two_val = GiBroadcast##_func_suffix(2.f); \ + auto val1 = src.val[0]; \ + auto val2 = src.val[1]; \ + val1 = GiMultiply##_func_suffix(two_val, val1); \ + val2 = GiMultiply##_func_suffix(two_val, val2); \ + val1 = GiExpPs##_func_suffix(val1); \ + val2 = GiExpPs##_func_suffix(val2); \ + val1 = GiAdd##_func_suffix(one_val, val1); \ + val2 = GiAdd##_func_suffix(one_val, val2); \ + auto rval1 = GiRecpe##_func_suffix(val1); \ + auto rval2 = GiRecpe##_func_suffix(val2); \ + rval1 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(val1, rval1), rval1); \ + rval2 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(val2, rval2), rval2); \ + val1 = GiMultiply##_func_suffix(two_val, rval1); \ + val2 = GiMultiply##_func_suffix(two_val, rval2); \ + val1 = GiSubtract##_func_suffix(one_val, val1); \ + val2 = GiSubtract##_func_suffix(one_val, val2); \ + 
return {{val1, val2}}; \ + } \ + _simd_type operator()(const _simd_type& src) const { \ + auto one_val = GiBroadcast##_func_suffix(1.f); \ + auto two_val = GiBroadcast##_func_suffix(2.f); \ + auto val1 = src; \ + val1 = GiMultiply##_func_suffix(two_val, val1); \ + val1 = GiExpPs##_func_suffix(val1); \ + val1 = GiAdd##_func_suffix(one_val, val1); \ + auto rval1 = GiRecpe##_func_suffix(val1); \ + rval1 = GiMultiply##_func_suffix( \ + GiRecpeS##_func_suffix(val1, rval1), rval1); \ + val1 = GiMultiply##_func_suffix(two_val, rval1); \ + val1 = GiSubtract##_func_suffix(one_val, val1); \ + return val1; \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/true_div.h b/dnn/src/fallback/elemwise_helper/kimpl/true_div.h new file mode 100644 index 00000000..71817845 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/true_div.h @@ -0,0 +1,68 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/true_div.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + +//! use a couple Newton-Raphson steps to refine the estimate. +//! A / B => 1. rB = vrecpeq_f32(B) 2. rB= vmulq_f32(vrecpsq_f32(B, rB), rB) +//! 3. A * rB +template +struct TrueDivOpBase : BinaryOpBase { + using BinaryOpBase::BinaryOpBase; + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { + *dst = operator()(src0, src1); + } + dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { + return src0 / src1; + } +}; + +template +struct TrueDivOp; + +#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ + template <> \ + struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \ + using TrueDivOpBase::TrueDivOpBase; \ + using TrueDivOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _simd_type2& src0, const _simd_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem.val[0]); \ + GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _simd_type2 operator()( \ + const _simd_type2& src0, const _simd_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = GiDivide##_func_suffix(val1, val3); \ + val2 = GiDivide##_func_suffix(val2, val4); \ + return {{val1, val2}}; \ + } \ + void operator()( \ + const _simd_type& src0, const _simd_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + GiStore##_func_suffix(dst, vitem); \ + } \ + _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ + return GiDivide##_func_suffix(src0, src1); \ + } \ + }; +OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) +#undef OP + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/kimpl/typecvt.h b/dnn/src/fallback/elemwise_helper/kimpl/typecvt.h new file mode 100644 index 00000000..32b9edf8 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/kimpl/typecvt.h @@ -0,0 +1,53 @@ +/** + * \file dnn/src/fallback/elemwise_helper/kimpl/typecvt.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/op_base.h" + +namespace megdnn { +namespace fallback { + 
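The TypeCvtOp specialization declared below requantizes qint32 values to qint8: each element is multiplied by scale = src_scale / dst_scale (precomputed by UnaryOpBase) and converted back with saturation, with the vector overloads covering the main loop and the scalar overload covering the tail. The helper below is a minimal scalar sketch of that per-element conversion; its name is invented for this note, and the round-half-away-from-zero step is an assumption standing in for QConverter::convert.

static inline int8_t typecvt_q32_to_q8_sketch(int32_t acc, float scale) {
    float f = acc * scale;         // rescale the int32 value into dst units
    f += f >= 0.f ? 0.5f : -0.5f;  // round half away from zero (assumed mode)
    if (f > 127.f) f = 127.f;      // saturate to the qint8 range
    if (f < -128.f) f = -128.f;
    return static_cast<int8_t>(f);
}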
+template +struct TypeCvtOp; + +template <> +struct TypeCvtOp : UnaryOpBase { + using UnaryOpBase::UnaryOpBase; + constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float); + + void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { + GiStoreLowInt8(reinterpret_cast(dst), operator()(vsrc)); + } + void operator()(const GI_INT32_t& vsrc, dt_qint8* dst) const { + GiStoreLane0Int32( + reinterpret_cast(dst), (GI_INT32_t)(operator()(vsrc))); + } + void operator()(const src_ctype& src, dst_ctype* dst) const { + *dst = operator()(src); + } + dt_qint8 operator()(const dt_qint32& src) const { + float fsrc = src.as_int32() * this->scale; + return QConverter::convert(fsrc); + } + + GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); + auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); + + return QConverter::convert({{vitem0, vitem1}}); + } + GI_INT8_t operator()(const GI_INT32_t& src) const { + auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale); + return QConverter::convert(vitem0); + } + GI_INT8_t operator()(const GI_FLOAT32_t& src) const { + auto vitem0 = GiMultiplyFloat32(src, this->vscale); + return QConverter::convert(vitem0); + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/op_binary.h b/dnn/src/fallback/elemwise_helper/op_binary.h new file mode 100644 index 00000000..6c171052 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/op_binary.h @@ -0,0 +1,39 @@ +/** + * \file dnn/src/fallback/elemwise_helper/op_binary.h + */ + +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/add.h" +#include "src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h" +#include "src/fallback/elemwise_helper/kimpl/fuse_add_relu.h" +#include "src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h" +#include "src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h" +#include "src/fallback/elemwise_helper/kimpl/max.h" +#include "src/fallback/elemwise_helper/kimpl/min.h" +#include "src/fallback/elemwise_helper/kimpl/mul.h" +#include "src/fallback/elemwise_helper/kimpl/pow.h" +#include "src/fallback/elemwise_helper/kimpl/sub.h" +#include "src/fallback/elemwise_helper/kimpl/true_div.h" + +//////////////////// quantization ////////////////////////////// +namespace megdnn { +namespace fallback { +#define cb(op) \ + template <> \ + struct op \ + : BinaryQuantizationOp> { \ + using BinaryQuantizationOp< \ + dt_qint8, dt_qint8, op>::BinaryQuantizationOp; \ + }; + +cb(TrueDivOp); +cb(FuseAddSigmoidOp); +cb(FuseAddTanhOp); +cb(FuseAddHSwishOp); + +#undef cb +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/op_common.h b/dnn/src/fallback/elemwise_helper/op_common.h new file mode 100644 index 00000000..6b24ad24 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/op_common.h @@ -0,0 +1,39 @@ +/** + * \file dnn/src/fallback/elemwise_helper/op_common.h + */ +#pragma once + +namespace megdnn { +/*! 
+ * \brief broadcast type + * BCAST_x[0]x[1]...: x[i] == !stride[i] + */ +enum BcastType { + VEC, + VEC_VEC, + VEC_BCAST101, + VEC_BCASTX0X, + VEC_BCAST111C, + VEC_BCAST101xX, + VEC_SCALAR, + SCALAR_VEC, + BCAST101_VEC, + BCASTX0X_VEC, + BCAST111C_VEC, + BCAST101xX_VEC, + VEC_VEC_VEC, + VEC_VEC_SCALAR, + BCAST101_VEC_BCAST101, + BCAST111C_VEC_BCAST111C, + BCAST101xX_VEC_BCAST101xX, + VEC_BCAST101_VEC, + VEC_BCAST111C_VEC, + VEC_BCAST101xX_VEC, + VEC_SCALAR_VEC, + VEC_SCALAR_SCALAR, + UNKNOWN_BCAST_TYPE +}; + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/op_ternary.h b/dnn/src/fallback/elemwise_helper/op_ternary.h new file mode 100644 index 00000000..9d73ee56 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/op_ternary.h @@ -0,0 +1,24 @@ +/** + * \file dnn/src/fallback/elemwise_helper/op_ternary.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h" + +//////////////////// quantization ////////////////////////////// +namespace megdnn { +namespace fallback { +#define cb(op) \ + template <> \ + struct op \ + : TernaryQuantizationOp> { \ + using TernaryQuantizationOp< \ + dt_qint8, dt_qint8, op>::TernaryQuantizationOp; \ + }; + +cb(FuseMulAdd3Op); +#undef cb +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_helper/op_unary.h b/dnn/src/fallback/elemwise_helper/op_unary.h new file mode 100644 index 00000000..3c17ec88 --- /dev/null +++ b/dnn/src/fallback/elemwise_helper/op_unary.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/fallback/elemwise_helper/op_unary.h + */ +#pragma once + +#include "src/fallback/elemwise_helper/kimpl/abs.h" +#include "src/fallback/elemwise_helper/kimpl/exp.h" +#include "src/fallback/elemwise_helper/kimpl/fast_tanh.h" +#include "src/fallback/elemwise_helper/kimpl/hswish.h" +#include "src/fallback/elemwise_helper/kimpl/none.h" +#include "src/fallback/elemwise_helper/kimpl/relu.h" +#include "src/fallback/elemwise_helper/kimpl/sigmoid.h" +#include "src/fallback/elemwise_helper/kimpl/tanh.h" +#include "src/fallback/elemwise_helper/kimpl/typecvt.h" + +//////////////////// quantization ////////////////////////////// +namespace megdnn { +namespace fallback { +#define cb(op) \ + template <> \ + struct op \ + : UnaryQuantizationOp> { \ + using UnaryQuantizationOp< \ + dt_qint8, dt_qint8, op>::UnaryQuantizationOp; \ + }; + +cb(SigmoidOp); +cb(ExpOp); +cb(TanhOp); +cb(FastTanhOp); +cb(HSwishOp); +#undef cb +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/elemwise_op.h b/dnn/src/fallback/elemwise_op.h new file mode 100644 index 00000000..2bd961aa --- /dev/null +++ b/dnn/src/fallback/elemwise_op.h @@ -0,0 +1,1432 @@ +/** + * \file dnn/src/fallback/elemwise_op.h + */ + +#pragma once + +#include "src/fallback/elemwise_helper/op_binary.h" +#include "src/fallback/elemwise_helper/op_common.h" +#include "src/fallback/elemwise_helper/op_ternary.h" +#include "src/fallback/elemwise_helper/op_unary.h" + +#include "src/fallback/general_intrinsic/gi_float.h" +#include "src/fallback/general_intrinsic/gi_int.h" + +namespace megdnn { +namespace fallback { + +///////////////////////////////// ParamElemVistor /////////////////////////// +template +struct ParamElemVisitor; + +//! 
visitor single elemwise, and dup to vector +template +struct ParamElemVisitorDup; + +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitor<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiLoad##_fun_suffix(reinterpret_cast(src)); \ + } \ + }; \ + template <> \ + struct ParamElemVisitorDup<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiBroadcast##_fun_suffix( \ + *reinterpret_cast(src)); \ + } \ + } +cb(dt_qint32, int32_t, GI_INT32_t, Int32); +cb(dt_qint8, int8_t, GI_INT8_t, Int8); + +cb(dt_float32, float, GI_FLOAT32_t, Float32); +cb(dt_int32, int32_t, GI_INT32_t, Int32); +cb(dt_int8, int8_t, GI_INT8_t, Int8); +#undef cb + +template +struct ParamElemVisitorBcast101x4; +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix, rel_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiReinter##rel_suffix##To##_fun_suffix(GiBroadcast##rel_suffix( \ + *reinterpret_cast(src))); \ + } \ + } + +cb(dt_qint8, int32_t, GI_INT8_t, Int8, Int32); +cb(dt_int8, int32_t, GI_INT8_t, Int8, Int32); +#undef cb +#define cb(_ctype, _inner_ctype, _simd_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _simd_type operator()(const _ctype* src) const { \ + return GiLoad##_fun_suffix(reinterpret_cast(src)); \ + } \ + } + +cb(dt_qint32, int32_t, GI_INT32_t, Int32); +cb(dt_float32, float, GI_FLOAT32_t, Float32); +cb(dt_int32, int32_t, GI_INT32_t, Int32); +#undef cb + +///////////////////////////////// OpCaller ///////////////////////////// +template +struct OpCallerUnary; + +template +struct OpCallerUnary { + static void run( + const typename Op::src_ctype* src, typename Op::dst_ctype* dst, + DType src_dtype, DType dst_dtype, size_t nr_elems) { + Op op(src_dtype, dst_dtype); + ParamElemVisitor vis; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis(src), vis(src + Op::SIMD_WIDTH)}}, dst); + src += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src, dst); + src++; + dst++; + } + } +}; + +template +struct OpCallerBinary; + +///////////////////////// Pow //////////////////////////////// +template +struct OpCallerBinary, VEC_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, dst); + src0++; + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary, VEC_SCALAR> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, dst); + src0++; + dst++; 
+ } + } +}; + +template +struct OpCallerBinary, VEC_BCAST101> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + dst++; + } + src1_ptr++; + } + } + } +}; + +template +struct OpCallerBinary, VEC_BCASTX0X> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_ptr = src1_ptr_base; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST111C> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + const typename Op::src_ctype* src1_ptr = src1; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, BCAST111C_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + const typename Op::src_ctype* src0_ptr = src0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary, SCALAR_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop 
vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(src0, *src1, dst); + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary, BCAST101_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src1++; + dst++; + } + src0_ptr++; + } + } + } +}; + +template +struct OpCallerBinary, BCASTX0X_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr_base = src0 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src0_ptr = src0_ptr_base; +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, dst); + src0++; + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_simd = vis1(src1_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{src1_simd, src1_simd}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + dst++; + } + src1_ptr++; + } + } + } +}; + +template +struct 
OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + size_t i = 0; + auto src1_ptr = src1_ptr_base; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0); + auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1_ptr); + auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, dst); + src0++; + src1_ptr++; + dst++; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t rest = channel_stride; + const typename Op::src_ctype* src1_ptr = src1; + while (rest >= Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0); + auto src0_simd1 = vis(src0 + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1_ptr); + auto src1_simd1 = vis(src1_ptr + Op::SIMD_WIDTH); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + dst += Op::SIMD_WIDTH * 2; + rest -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + while (rest > 0) { + op(*src0, *src1_ptr, dst); + dst++; + src0++; + src1_ptr++; + rest--; + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + for (size_t c = 0; c < channel; c++) { + size_t rest = channel_stride; + const typename Op::src_ctype* src0_ptr = src0; + while (rest >= Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0_ptr); + auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1); + auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + dst += Op::SIMD_WIDTH * 2; + rest -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + while (rest > 0) { + op(*src0_ptr, *src1, dst); + dst++; + src0_ptr++; + src1++; + rest--; + } + } + } + } +}; + +template +struct OpCallerBinary, 
BCAST101xX_VEC> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + for (size_t i = 0; i < channel_stride; i++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xDVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto channel_block_vec = vis0(src0_block_ptr); + size_t img_index = 0; + auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src1_offset <= channel_stride; + img_index += 2 * src1_offset) { + op({{channel_block_vec, channel_block_vec}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + OpCallerBinaryBcast101xDVec::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerBinaryBcast101xXVec::run( + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); + } else { + OpCallerBinaryBcast101xXVec::run( + src0, src1, dst, 
op, batch, nr_channel_blocks, channel_stride); + } + } +}; + +template +struct OpCallerBinary, VEC_BCAST101xX> { + using Op = PowOp; + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t i = 0; i < channel_stride; i++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0), *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xD { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src0_offset <= channel_stride; + img_index += 2 * src0_offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinaryVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + OpCallerBinaryVecBcast101xD::run( + src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerBinaryVecBcast101xX::run( + src0, src1, dst, op, batch, 
nr_channel_blocks, channel_stride); + } else { + OpCallerBinaryVecBcast101xX::run( + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + auto vis1_simd = vis1(&src1); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, + dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, dst); + src0++; + dst++; + } + } +}; + +//! this only for nonswap op, like SUB and DIV +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorDup vis0; + ParamElemVisitor vis1; + auto vis0_simd = vis0(&src0); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0_simd, vis0_simd}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(src0, *src1, dst); + src1++; + dst++; + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitorDup vis0; + ParamElemVisitor vis1; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + for (size_t c = 0; c < channel; c++) { + auto vis0_simd = vis0(src0_ptr); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0_simd, vis0_simd}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src1++; + dst++; + } + src0_ptr++; + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t b = 0; b < batch; b++) { + auto src0_ptr_base = src0 + b * channel_stride; + for (size_t c = 0; c < channel; c++) { + auto src0_ptr = src0_ptr_base; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0_ptr); + auto src0_simd1 = vis(src0_ptr + 
Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1); + auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, dst); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, dst); + src0_ptr++; + src1++; + dst++; + } + } + } + } +}; + +template +struct OpCallerTernary; + +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitor vis2; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, *src2, dst); + src0++; + src1++; + src2++; + dst++; + } + } +}; + +//! src0: vector, src1: vector, src2: scalar +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitorDup vis2; + auto vis2_simd = vis2(&src2); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2_simd, vis2_simd}}, + dst); + src0 += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, *src1, src2, dst); + src0++; + src1++; + dst++; + } + } +}; + +//! 
src0: 1C11, src1: vector, src2: 1C11 +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, size_t channel_stride, + size_t batch_offset) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis1; + ParamElemVisitorDup vis0; + ParamElemVisitorDup vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + auto b_offset = batch_offset; + for (size_t channel = 0; channel < channel_size; channel++) { + size_t i = 0; + auto src0_simd = vis0(src0_ptr); + auto src2_simd = vis2(src2_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{src0_simd, src0_simd}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{src2_simd, src2_simd}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + b_offset -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, *src2_ptr, dst); + src1++; + dst++; + b_offset--; + } + src0_ptr++; + src2_ptr++; + } + src1 += b_offset; + dst += b_offset; + } + } +}; + +//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + size_t src1_offset, const typename Op::src_ctype* src2, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType src2_dtype, DType dst_dtype, size_t batch_size, size_t channel_size, + size_t channel_stride, size_t batch_offset) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis; + for (size_t batch = 0; batch < batch_size; batch++) { + auto b_offset = batch_offset; + for (size_t channel = 0; channel < channel_size; channel++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + auto src0_simd0 = vis(src0_ptr); + auto src0_simd1 = vis(src0_ptr + Op::SIMD_WIDTH); + auto src1_simd0 = vis(src1); + auto src1_simd1 = vis(src1 + Op::SIMD_WIDTH); + auto src2_simd0 = vis(src2_ptr); + auto src2_simd1 = vis(src2_ptr + Op::SIMD_WIDTH); + op({{src0_simd0, src0_simd1}}, {{src1_simd0, src1_simd1}}, + {{src2_simd0, src2_simd1}}, dst); + src0_ptr += Op::SIMD_WIDTH * 2; + src1 += Op::SIMD_WIDTH * 2; + src2_ptr += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + b_offset -= Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0_ptr, *src1, *src2_ptr, dst); + src0_ptr++; + src1++; + src2_ptr++; + dst++; + b_offset--; + } + src1 += src1_offset; + } + src1 += b_offset; + dst += b_offset; + } + } +}; + +template +struct OpCallerTernaryBcast101xXVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t cb = 0; cb < 
nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto src2_block_ptr = src2_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, + *(src2_block_ptr + c_iter), dst); + src1++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerTernaryBcast101xDVecBcast101xD { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src0_ptr = src0; + auto src2_ptr = src2; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src0_block_ptr = src0_ptr + cb * channel_block_dim; + auto src2_block_ptr = src2_ptr + cb * channel_block_dim; + auto channel_block_vec0 = vis0(src0_block_ptr); + auto channel_block_vec2 = vis2(src2_block_ptr); + size_t img_index = 0; + auto src1_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src1_offset <= channel_stride; + img_index += 2 * src1_offset) { + op({{channel_block_vec0, channel_block_vec0}}, + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + {{channel_block_vec2, channel_block_vec2}}, dst); + src1 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*(src0_block_ptr + c_iter), *src1, + *(src2_block_ptr + c_iter), dst); + src1++; + dst++; + } + } + } + } + } +}; + +//! src0: CHW44, src1: vector, src2: CHW44 +template +struct OpCallerTernaryBcast101xXVecBcast101xX { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitorBcast101x4 vis0; + ParamElemVisitor vis1; + ParamElemVisitorBcast101x4 vis2; + OpCallerTernaryBcast101xDVecBcast101xD::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerTernaryBcast101xXVecBcast101xX::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerTernaryBcast101xXVecBcast101xX::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + +//! 
src1: 1C11, src0 and src2 are contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitor vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + auto src1_ptr = src1; + for (size_t channel = 0; channel < channel_size; channel++) { + size_t i = 0; + auto src1_simd = vis1(src1_ptr); + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{src1_simd, src1_simd}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, *src2, dst); + src0++; + src2++; + dst++; + } + src1_ptr++; + } + } + } +}; + +//! src1: 111C, src0 and src2 may not be contig +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, size_t src0_offset, + const typename Op::src_ctype* src1, const typename Op::src_ctype* src2, + size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype, + DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size, + size_t channel_size, size_t channel_stride) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + ParamElemVisitor vis2; + for (size_t batch = 0; batch < batch_size; batch++) { + for (size_t channel = 0; channel < channel_size; channel++) { + auto src1_ptr = src1; + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= channel_stride; + i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src1_ptr += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < channel_stride; i++) { + op(*src0, *src1_ptr, *src2, dst); + src0++; + src1_ptr++; + src2++; + dst++; + } + src0 += src0_offset; + src2 += src2_offset; + } + } + } +}; + +template +struct OpCallerTernaryVecBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), *src2, dst); + src0++; + src2++; + dst++; + } + } + } + } + } +}; + +//! 
src1: CHW44, src0 and src2 are contig +template +struct OpCallerTernaryVecBcast101xDVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * offset <= channel_stride; + img_index += 2 * offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + // TODO:all elemwise_multi_type op imp one simd mode + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), *src2, dst); + src0++; + src2++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerTernaryVecBcast101xXVec { + template + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { + ParamElemVisitor vis0; + ParamElemVisitorBcast101x4 vis1; + ParamElemVisitor vis2; + OpCallerTernaryVecBcast101xDVec::run( + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); + } +}; + +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); + + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + if (channel_block_dim == 4) { + OpCallerTernaryVecBcast101xXVec::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } else { + OpCallerTernaryVecBcast101xXVec::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); + } + } +}; + +//! 
src1: scalar, src0 and src2 has the same shape +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitor vis2; + auto vis1_simd = vis1(&src1); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, + {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); + src0 += Op::SIMD_WIDTH * 2; + src2 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, *src2, dst); + src0++; + src2++; + dst++; + } + } +}; + +//! src1, src2: scalar, src0 is vector +template +struct OpCallerTernary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + const typename Op::src_ctype src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { + Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorDup vis1; + ParamElemVisitorDup vis2; + auto vis1_simd = vis1(&src1); + auto vis2_simd = vis2(&src2); + size_t i = 0; + for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_simd, vis1_simd}}, + {{vis2_simd, vis2_simd}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } +#if MEGDNN_FIX_AARCH32_BUG +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) +#endif + for (; i < nr_elems; i++) { + op(*src0, src1, src2, dst); + src0++; + dst++; + } + } +}; + +} // namespace fallback +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/general_intrinsic/gi_common.h b/dnn/src/fallback/general_intrinsic/gi_common.h index 41dcd1d2..8c3ff8cb 100644 --- a/dnn/src/fallback/general_intrinsic/gi_common.h +++ b/dnn/src/fallback/general_intrinsic/gi_common.h @@ -19,7 +19,7 @@ #include #else #if defined(__arm__) || defined(__aarch64__) -#include +#include "src/arm_common/simd_macro/marm_neon.h" #endif #if defined(__x86_64__) || defined(__i386__) #include diff --git a/dnn/src/fallback/quantized_converter.h b/dnn/src/fallback/quantized_converter.h index b842a862..bc0b4f7b 100644 --- a/dnn/src/fallback/quantized_converter.h +++ b/dnn/src/fallback/quantized_converter.h @@ -21,13 +21,13 @@ namespace megdnn { namespace fallback { struct QConverterBase { - inline static GI_INT32 vzero() { return GiBroadcastInt32(0); } + inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); } - inline static GI_FLOAT32 vfzero() { return GiBroadcastFloat32(0.f); } + inline static GI_FLOAT32_t vfzero() { return GiBroadcastFloat32(0.f); } - inline static GI_FLOAT32 vfhalf() { return GiBroadcastFloat32(0.5f); } + inline static GI_FLOAT32_t vfhalf() { return GiBroadcastFloat32(0.5f); } - inline static GI_FLOAT32 vfneg_half() { return GiBroadcastFloat32(-0.5f); } + inline static GI_FLOAT32_t vfneg_half() { return GiBroadcastFloat32(-0.5f); } }; struct QConverter { @@ -56,23 +56,23 @@ inline dt_qint32 
QConverter::convert(const float& src) { } template <> -inline GI_FLOAT32_V2 QConverter::convert(const GI_INT16& vsrc) { - GI_INT32 vhi = GiMoveHighLongInt16(vsrc); - GI_INT32 vlo = GiMoveLowLongInt16(vsrc); +inline GI_FLOAT32_V2_t QConverter::convert(const GI_INT16_t& vsrc) { + GI_INT32_t vhi = GiMoveHighLongInt16(vsrc); + GI_INT32_t vlo = GiMoveLowLongInt16(vsrc); return {{GiCastToFloat32(vlo), GiCastToFloat32(vhi)}}; } template <> -inline GI_INT8 QConverter::convert(const GI_FLOAT32_V2& vsrc) { +inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V2_t& vsrc) { return GiCvtFromFloat32V2ToInt8(vsrc); } template <> -inline GI_INT8 QConverter::convert(const GI_FLOAT32& src) { +inline GI_INT8_t QConverter::convert(const GI_FLOAT32_t& src) { return GiCvtFromFloat32ToInt8(src); } template <> -inline GI_INT32 QConverter::round(const GI_FLOAT32& vsrc) { +inline GI_INT32_t QConverter::round(const GI_FLOAT32_t& vsrc) { return GiRoundAsInt32(vsrc); } } // namespace fallback diff --git a/dnn/test/fallback/elemwise.cpp b/dnn/test/fallback/elemwise.cpp index 0f5e3c37..ef9487ea 100644 --- a/dnn/test/fallback/elemwise.cpp +++ b/dnn/test/fallback/elemwise.cpp @@ -38,6 +38,309 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) { checker.set_rng(2, &rng); checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); } + + +TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + checker.set_param(Mode::FUSE_MUL_ADD3); + + auto run = [&] { + //! nchw44 + checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + + //! nchw44 + checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + + //! nchw88 + checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); + checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); + + //! 
nchw88 + checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); + checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); + checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); + checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); + + checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); + checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); + checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); + checker.execs({{1, 7}, {1, 7}, {1, 7}, {}}); + checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); + checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); + checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); + checker.execs({{3, 4, 5}, {1}, {1}, {}}); + checker.execs({{1}, {3, 4, 5}, {1}, {}}); + }; + + // case int + checker.set_dtype(0, dtype::Int8()); + checker.set_dtype(1, dtype::Int8()); + checker.set_dtype(2, dtype::Int8()); + run(); + + checker.set_dtype(0, dtype::Int16()); + checker.set_dtype(1, dtype::Int16()); + checker.set_dtype(2, dtype::Int16()); + run(); + + checker.set_dtype(0, dtype::Int32()); + checker.set_dtype(1, dtype::Int32()); + checker.set_dtype(2, dtype::Int32()); + run(); + + // case float + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + checker.set_dtype(2, dtype::Float32()); + run(); +} + +TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + auto run = [&]() { + // VEC_BCAST101x not PowOp + checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + // BCAST101x_VEC not PowOp + checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + 
checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + }; + checker.set_dtype(0, dtype::Int8()); + checker.set_dtype(1, dtype::Int8()); + run(); + checker.set_dtype(0, dtype::Int16()); + checker.set_dtype(1, dtype::Int16()); + run(); + checker.set_dtype(0, dtype::Int32()); + checker.set_dtype(1, dtype::Int32()); + run(); +} + +TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_FP32) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + + auto run = [&](Mode mode) { + // VEC_BCAST101x + checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); + // BCAST101x_VEC not powOp + checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); + checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); + checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); + checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); + }; + run(Mode::ADD); + run(Mode::FUSE_ADD_H_SWISH); + run(Mode::FUSE_ADD_RELU); + run(Mode::MAX); + run(Mode::MIN); + run(Mode::MUL); + run(Mode::SUB); + run(Mode::TRUE_DIV); + run(Mode::POW); +} + +TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW88_FP) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + 
checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(Mode::FUSE_ADD_RELU) + .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); + + auto run = [&](Mode mode) { + // VEC_BCAST101x + checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); + // BCAST101x_VEC not powOp + checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); + checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); + checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); + checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); + checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); + }; + auto run_all = [&]() { + run(Mode::ADD); + run(Mode::FUSE_ADD_H_SWISH); + run(Mode::FUSE_ADD_RELU); + run(Mode::MAX); + run(Mode::MIN); + run(Mode::MUL); + run(Mode::SUB); + run(Mode::TRUE_DIV); + run(Mode::POW); + }; + + { + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + run_all(); + } +} + +TEST_F(FALLBACK, ELEMWISE_FORWARD_N1HW_FP32_BCAST) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + + //! 2 dim + auto run = [&](Mode mode) { + // VEC_BCASTX0X + checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}}); + checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}}); + // BCASTX0X_VEC + checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}}); + checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}}); + }; + run(Mode::ADD); + run(Mode::MUL); + run(Mode::SUB); +} + +TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY_RECORD) { + using Mode = ElemwiseForward::Param::Mode; + TaskRecordChecker checker(0); + checker.set_param(Mode::FUSE_MUL_ADD3); + + auto run = [&] { + //! nchw44 + checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); + + //! 
nchw88 + checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); + + checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); + checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); + }; + + // case int + checker.set_dtype(0, dtype::Int32()); + checker.set_dtype(1, dtype::Int32()); + checker.set_dtype(2, dtype::Int32()); + run(); + + // case float + UniformFloatRNG rng(1e-5, 7e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.set_dtype(0, dtype::Float32()); + checker.set_dtype(1, dtype::Float32()); + checker.set_dtype(2, dtype::Float32()); + run(); +} + #if MEGDNN_WITH_BENCHMARK TEST_F(FALLBACK, BENCHMARK_ELEMWISE) { auto naive_handle = create_cpu_handle(2); diff --git a/dnn/test/x86/elemwise_bmark.cpp b/dnn/test/x86/elemwise_bmark.cpp index da620f3e..5f895083 100644 --- a/dnn/test/x86/elemwise_bmark.cpp +++ b/dnn/test/x86/elemwise_bmark.cpp @@ -164,3 +164,148 @@ TEST_F(X86, BENCHMARK_ELEM_EVERY_DTYPE) { // B.set_dtype(2, dtype::Int8()); // BENCHMARK_CASE_INT(1556011) } + +#if MEGDNN_WITH_BENCHMARK +namespace { +void run_elemwise_benchmark( + const TensorShapeArray& shapes, param::Elemwise::Mode mode, + const char* mode_str, DType type, Handle* handle_bench) { + auto handle_fallback = create_cpu_handle(1); + Benchmarker benchmarker_bench(handle_bench); + Benchmarker benchmarker_fallback(handle_fallback.get()); + + float throughput = 0; + SmallVector layouts; + std::string src_strs; + for (size_t i = 0; i < shapes.size(); i++) { + layouts.emplace_back(shapes[i], type); + throughput += layouts.back().span().dist_byte(); + src_strs += layouts.back().to_string(); + if (i != shapes.size() - 1) { + src_strs += ","; + } + } + constexpr size_t RUN = 50; + benchmarker_fallback.set_times(RUN).set_display(false); + benchmarker_bench.set_times(RUN).set_display(false); + + benchmarker_fallback.set_param(mode); + benchmarker_bench.set_param(mode); + + TensorLayout dst_layout; + auto opr = handle_bench->create_operator(); + opr->param() = mode; + opr->deduce_layout(layouts, dst_layout); + + float computations = + dst_layout.total_nr_elems() * (std::max(shapes.size(), 2) - 1); + throughput += dst_layout.span().dist_byte(); + computations *= (1e3 / (1024.0 * 1024)); + throughput *= (1e3 / (1024.0 * 1024)); + + layouts.emplace_back(dst_layout); + auto fallback_time = benchmarker_fallback.execl(layouts) / RUN; + auto bench_time = benchmarker_bench.execl(layouts) / RUN; + + float fallback_flops = computations / fallback_time; + float bench_flops = computations / bench_time; + float fallback_thr = throughput / fallback_time; + float bench_thr = throughput / bench_time; + + printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS " + "%fMB/s " + "computations: %fx, throughput: %fx\n", + src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str, + fallback_flops, fallback_thr, bench_flops, bench_thr, + bench_flops / fallback_flops, bench_thr / fallback_thr); +} +} // namespace + +#define INT_RUN(shape, mode) \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle()); + +#define FLOAT_RUN(shape, mode) \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \ + run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle()); + +#define BENCHMARK_CASES(shape) \ + 
INT_BENCHMARK_CASES(shape) \ + FLOAT_BENCHMARK_CASES(shape) + +TEST_F(X86, BENCHMARK_UNARY) { +#define INT_BENCHMARK_CASES(shape) \ + INT_RUN(shape, Mode::RELU); \ + INT_RUN(shape, Mode::ABS); + +#define FLOAT_BENCHMARK_CASES(shape) \ + FLOAT_RUN(shape, Mode::RELU); \ + FLOAT_RUN(shape, Mode::ABS); \ + FLOAT_RUN(shape, Mode::SIGMOID); \ + FLOAT_RUN(shape, Mode::EXP); \ + FLOAT_RUN(shape, Mode::TANH); \ + FLOAT_RUN(shape, Mode::FAST_TANH); + + using Mode = param::Elemwise::Mode; + BENCHMARK_CASES({{10000}}); + BENCHMARK_CASES({{50000}}); + +#undef INT_BENCHMARK_CASES +#undef FLOAT_BENCHMARK_CASES +} + +TEST_F(X86, BENCHMARK_BINARY) { +#define INT_BENCHMARK_CASES(shape) \ + INT_RUN(shape, Mode::MIN); \ + INT_RUN(shape, Mode::MAX); \ + INT_RUN(shape, Mode::ADD); \ + INT_RUN(shape, Mode::SUB); \ + INT_RUN(shape, Mode::MUL); \ + INT_RUN(shape, Mode::RMULH); \ + INT_RUN(shape, Mode::FUSE_ADD_RELU); + +#define FLOAT_BENCHMARK_CASES(shape) \ + FLOAT_RUN(shape, Mode::MIN); \ + FLOAT_RUN(shape, Mode::MAX); \ + FLOAT_RUN(shape, Mode::ADD); \ + FLOAT_RUN(shape, Mode::SUB); \ + FLOAT_RUN(shape, Mode::MUL); \ + FLOAT_RUN(shape, Mode::POW); \ + FLOAT_RUN(shape, Mode::TRUE_DIV); \ + FLOAT_RUN(shape, Mode::FUSE_ADD_RELU); + + using Mode = param::Elemwise::Mode; + TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}}; + BENCHMARK_CASES(shapes); + shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}}; + BENCHMARK_CASES(shapes); + shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}}; + BENCHMARK_CASES(shapes); + +#undef INT_BENCHMARK_CASES +#undef FLOAT_BENCHMARK_CASES +} + +TEST_F(X86, BENCHMARK_TERNARY_FMA3) { +#define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3); + +#define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3); + + using Mode = param::Elemwise::Mode; + TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}}; + BENCHMARK_CASES(shapes); + shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}}; + BENCHMARK_CASES(shapes); + shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}}; + BENCHMARK_CASES(shapes); + +#undef INT_BENCHMARK_CASES +#undef FLOAT_BENCHMARK_CASES +} + +#undef BENCHMARK_CASES +#undef INT_RUN +#undef FLOAT_RUN + +#endif
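Note on the broadcast-position callers touched by this patch: the VEC_BCAST101 specializations treat src0 as a full N x C x (H*W) tensor and src1 as one value per channel, restarted for every batch. The standalone sketch below is illustrative only and is not part of the patch; the function name binary_bcast101_ref, the toy shapes, and the lambda op are invented for this example. It mirrors only the scalar tail loop of fallback::OpCallerBinary<PowOp<...>, VEC_BCAST101> shown above, assuming contiguous NCHW data, and omits the SIMD main loop and the MEGDNN_FIX_AARCH32_BUG pragma.

    // Illustrative sketch (hypothetical names): scalar mock-up of the
    // batch / channel / channel_stride indexing used by VEC_BCAST101.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    template <typename T, typename Op>
    void binary_bcast101_ref(
            const T* src0, const T* src1, T* dst, size_t batch, size_t channel,
            size_t channel_stride, Op op) {
        for (size_t b = 0; b < batch; b++) {
            const T* src1_ptr = src1;                // per-channel values restart each batch
            for (size_t c = 0; c < channel; c++) {
                for (size_t i = 0; i < channel_stride; i++) {
                    *dst++ = op(*src0++, *src1_ptr); // src1 is broadcast over H*W
                }
                src1_ptr++;                          // advance once per channel
            }
        }
    }

    int main() {
        const size_t N = 1, C = 2, HW = 3;
        std::vector<float> a = {0, 1, 2, 3, 4, 5};   // N * C * HW elements
        std::vector<float> bias = {10, 20};          // one value per channel
        std::vector<float> out(N * C * HW);
        binary_bcast101_ref(
                a.data(), bias.data(), out.data(), N, C, HW,
                [](float x, float y) { return x + y; });
        for (float v : out)
            std::printf("%g ", v);                   // prints: 10 11 12 23 24 25
        std::printf("\n");
    }

The SIMD variants in the patch wrap the same indexing: ParamElemVisitorDup loads the single per-channel value into a vector register once per channel, the main loop consumes Op::SIMD_WIDTH * 2 elements of src0 per iteration, and the scalar loop above handles the remainder.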