GitOrigin-RevId: 96ff2e88cc
tags/v1.10.0
@@ -44,30 +44,29 @@ namespace {
            break;
#define FOR_NONLINEAR_UNARY(_op) \
    megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>::run( \
    megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \
            static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
            bias_type, dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>:: \
            run(static_cast<ctype*>(conv_dst_ptr), \
                reinterpret_cast<const ctype*>(bias_ptr), \
                reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
                OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
    megdnn::arm_common:: \
            OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101xX>::run( \
                    static_cast<ctype*>(conv_dst_ptr), \
                    reinterpret_cast<const ctype*>(bias_ptr), \
                    reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
                    N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
            static_cast<ctype*>(conv_dst_ptr), \
            reinterpret_cast<const ctype*>(bias_ptr), \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
            OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
            static_cast<ctype*>(conv_dst_ptr), \
            reinterpret_cast<const ctype*>(bias_ptr), \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
            OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
            static_cast<ctype*>(conv_dst_ptr), \
            reinterpret_cast<const ctype*>(bias_ptr), \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
            N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_mode) \
@@ -168,36 +167,33 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_BIAS
#undef HANDLE_IDENTITY
#define FOR_NONLINEAR_UNARY(_op) \
#define FOR_NONLINEAR_UNARY(_op) \
    megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \
            static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \
            bias_type, dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
    megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \
            run(static_cast<opctype*>(conv_dst_ptr), \
                reinterpret_cast<const opctype*>(bias_ptr), \
                reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
                N, OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
    megdnn::arm_common:: \
            OpCallerUnary<_op<opctype, opdtype>, megdnn::arm_common::VEC>::run( \
            OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
                    static_cast<opctype*>(conv_dst_ptr), \
                    reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \
                    N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
    megdnn::arm_common::OpCallerBinary< \
            _op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101>:: \
            run(static_cast<opctype*>(conv_dst_ptr), \
                reinterpret_cast<const opctype*>(bias_ptr), \
                reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
                N, OC, OH* OW);
                    reinterpret_cast<const opctype*>(bias_ptr), \
                    reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
                    dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
    megdnn::arm_common::OpCallerBinary< \
            _op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101xX>:: \
            run(static_cast<opctype*>(conv_dst_ptr), \
                reinterpret_cast<const opctype*>(bias_ptr), \
                reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
                N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
    megdnn::arm_common::OpCallerBinary< \
            _op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101xX>:: \
            run(static_cast<opctype*>(conv_dst_ptr), \
                reinterpret_cast<const opctype*>(bias_ptr), \
                reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
                N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
    megdnn::arm_common:: \
            OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
                    static_cast<opctype*>(conv_dst_ptr), \
                    reinterpret_cast<const opctype*>(bias_ptr), \
                    reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
                    dst_type, N, OC, OH* OW, pack_oc_size);
#define HANDLE_IDENTITY(_caller, _op) \
    case megdnn::NonlineMode::IDENTITY: \
@@ -271,26 +267,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR
#undef FOR_BIAS
#define FOR_BINARY_BROADCAST(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>:: \
            run(static_cast<ctype*>(conv_dst_ptr), \
                reinterpret_cast<const ctype*>(bias_ptr), \
                reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
                OC, OH* OW);
#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
    megdnn::arm_common:: \
            OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101xX>::run( \
                    static_cast<ctype*>(conv_dst_ptr), \
                    reinterpret_cast<const ctype*>(bias_ptr), \
                    reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
                    N, OC, OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
#define FOR_BINARY_BROADCAST(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
            static_cast<ctype*>(conv_dst_ptr), \
            reinterpret_cast<const ctype*>(bias_ptr), \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
            OH* OW);
#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
            static_cast<ctype*>(conv_dst_ptr), \
            reinterpret_cast<const ctype*>(bias_ptr), \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
            OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
    megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
            static_cast<ctype*>(conv_dst_ptr), \
            reinterpret_cast<const ctype*>(bias_ptr), \
            reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
            N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_bias_mode, OH, OW) \
@@ -89,163 +89,4 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
    fallback::ElemwiseImpl::exec(srcs, dst);
}

ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
    KernParam kern_param;
    kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE;
    kern_param.mode = opr->param().mode;
    kern_param.handle = opr->handle();
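    //! A layout is NHWC-friendly if it is fully contiguous (a plain vector),
    //! or 2-D with a contiguous innermost (channel) axis: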
    auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
        if (is_vector(l))
            return true;
        if (l.ndim == 2 && l.stride[1] == 1)
            return true;
        return false;
    };
    if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
        kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
        bool c_is_scalar;
        opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar);
        auto &src0 = kern_param.ternary_elparam[0],
             &src1 = kern_param.ternary_elparam[1],
             &src2 = kern_param.ternary_elparam[2];
        BroadcastChannelInfo binfo;
        if (is_vector(src0.layout) && is_vector(src1.layout) &&
            is_vector(src2.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_VEC_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) {
            kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR;
            return kern_param;
        }
        if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) &&
            src0.layout.eq_layout(src2.layout)) {
            kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101;
            return kern_param;
        }
        if (is_vector(src1.layout) &&
            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
            src0.layout.eq_layout(src2.layout)) {
            kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
            return kern_param;
        }
        if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
            is_broadcasted_channel_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
            return kern_param;
        }
        if (is_legal_layout_for_nhwc(src1.layout) &&
            is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
            src0.layout.eq_layout(src2.layout)) {
            kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
            return kern_param;
        }
        if (is_legal_layout_for_nhwc(src0.layout) &&
            src2.layout.eq_layout(src0.layout) &&
            is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && is_vector(src2.layout) &&
            is_broadcasted_scalar(src1.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
            is_broadcasted_scalar(src2.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR;
            return kern_param;
        }
    } else if (opr->m_src->size() == 2) {
        kern_param.binary_elparam = opr->make_elemwise_op_param<2>();
        auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1];
        BroadcastChannelInfo binfo;
        if (is_vector(src0.layout) && is_vector(src1.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_SCALAR;
            return kern_param;
        }
        if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) {
            kern_param.broad_cast_type = BcastType::SCALAR_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101;
            return kern_param;
        }
        if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::BCAST101_VEC;
            return kern_param;
        }
        if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
            return kern_param;
        }
        if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
            return kern_param;
        }
        if (is_legal_layout_for_nhwc(src1.layout) &&
            is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
            return kern_param;
        }
        if (is_legal_layout_for_nhwc(src0.layout) &&
            is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
            return kern_param;
        }
        if (is_vector(src0.layout) &&
            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
            return kern_param;
        }
        if (is_vector(src1.layout) &&
            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
            kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
            return kern_param;
        }
    } else if (opr->m_src->size() == 1) {
        kern_param.broad_cast_type = BcastType::VEC;
        kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
        return kern_param;
    }
    return kern_param;
}
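//! A few concrete mappings implied by the checks above (binary case):
//!   (N, C, H, W) op (N, C, H, W)                  -> VEC_VEC
//!   (N, C, H, W) op (1, C, 1, 1)                  -> VEC_BCAST101
//!   (N, C, H, W) op scalar                        -> VEC_SCALAR
//!   NHWC (N, H, W, C) op (1, 1, 1, C)             -> VEC_BCAST111C
//!   NCHW44 (N, C/4, H, W, 4) op (1, C/4, 1, 1, 4) -> VEC_BCAST101xX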
// vim: syntax=cpp.doxygen
@@ -18,22 +18,12 @@ namespace megdnn {
namespace arm_common {
class ElemwiseImpl final : public fallback::ElemwiseImpl {
public:
    using fallback::ElemwiseImpl::AlgoBase;
    using fallback::ElemwiseImpl::ElemwiseImpl;
    void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;
    const char* get_algorithm_set_name() const { return "ARM COMMON ELEMWISE"; }

private:
    struct KernParam {
        BcastType broad_cast_type;
        Mode mode;
        const TensorND* m_dst;
        Handle* handle;
        ElemwiseOpParamN<3> ternary_elparam;
        ElemwiseOpParamN<2> binary_elparam;
        ElemwiseOpParamN<1> unary_elparam;
    };
    KernParam make_kern_param(ElemwiseImpl* opr);
    class AlgoBase;
    class AlgoUnary;
    class AlgoBinaryVecVec;
    class AlgoBinaryVecScalar;
@@ -54,19 +44,6 @@ private:
    class AlgoPack;
};

/*!
 * \brief base class for Elemwise algo
 */
class ElemwiseImpl::AlgoBase : public detail::Algorithm {
public:
    virtual bool is_available(const KernParam&) const = 0;
    virtual void exec(const KernParam&) const = 0;
    virtual ~AlgoBase() = default;
    uint32_t type() const override { return INVALID_ALGO_TYPE; };
};

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define DISPATCH_TYPE(_case) \
    if (src0.layout.dtype == dtype::Float32{}) { \
@@ -15,10 +15,13 @@
#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/op_common.h"

namespace megdnn {
namespace arm_common {

using BcastType = megdnn::BcastType;

///////////////////////////////// ParamElemVisitor ///////////////////////////
template <typename ctype>
struct ParamElemVisitor;
@@ -99,36 +102,6 @@ cb(__fp16, __fp16, float16x8_t, f16);
#endif
#undef cb

/*!
 * \brief broadcast type
 * BCAST_x[0]x[1]...: x[i] == !stride[i]
 */
enum BcastType {
    VEC,
    VEC_VEC,
    VEC_BCAST101,
    VEC_BCASTX0X,
    VEC_BCAST111C,
    VEC_BCAST101xX,
    VEC_SCALAR,
    SCALAR_VEC,
    BCAST101_VEC,
    BCASTX0X_VEC,
    BCAST111C_VEC,
    BCAST101xX_VEC,
    VEC_VEC_VEC,
    VEC_VEC_SCALAR,
    BCAST101_VEC_BCAST101,
    BCAST111C_VEC_BCAST111C,
    BCAST101xX_VEC_BCAST101xX,
    VEC_BCAST101_VEC,
    VEC_BCAST111C_VEC,
    VEC_BCAST101xX_VEC,
    VEC_SCALAR_VEC,
    VEC_SCALAR_SCALAR,
    UNKNOWN_BCAST_TYPE
};
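//! Reading the naming rule above: collapse the layouts, then mark each dim
//! with 1 where its stride is 0 (broadcast) and 0 where it is not. E.g. a
//! bias of shape (1, C, 1, 1) applied to (N, C, H, W) collapses to strides
//! (0, s, 0) over dims (N, C, H*W), giving the pattern 101, hence BCAST101.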
///////////////////////////////// OpCaller /////////////////////////////
template <typename Op, BcastType bcast_type>
struct OpCallerUnary;
@@ -1,14 +1,7 @@
/**
 * \file dnn/src/fallback/elemwise/opr_binary_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * \file dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp
 */
#include "./opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
@@ -1,14 +1,7 @@
/**
 * \file dnn/src/fallback/elemwise/opr_unary_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * \file dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp
 */
#include "./opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"
#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
@@ -0,0 +1,535 @@
/**
 * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
 */
#include "src/fallback/elemwise/gi_impl/binary/algo.h"
#include "src/fallback/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_binary)

using namespace megdnn;
using namespace fallback;

namespace {
static inline bool is_available_common(Elemwise::Mode mode) {
    /**
     * Fused sigmoid & tanh may be slower than the naive algo, because the
     * time taken by the neon function `exp_ps_f32` depends on the input.
     */
    if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID ||
        mode == Elemwise::Mode::FUSE_ADD_TANH) {
        return false;
    }
    return true;
}
}  // anonymous namespace

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
    auto mode = kern_param.mode; \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU || \
        mode == Mode::FUSE_ADD_H_SWISH) \
        return true;
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
    auto mode = kern_param.mode; \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::FUSE_ADD_RELU) \
        return true;

bool ElemwiseImpl::AlgoBinaryVecVec::is_available(const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        (BcastType::VEC_VEC != kern_param.broad_cast_type))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    //! exactly match [x, y] + [x, y]
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecVec::is_available"_hash);
    return false;
}

bool ElemwiseImpl::AlgoBinaryVecScalar::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_SCALAR != kern_param.broad_cast_type) &&
         (BcastType::SCALAR_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::is_available"_hash);
    return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) &&
         (BcastType::BCAST101_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::is_available"_hash);
    return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) &&
         (BcastType::BCASTX0X_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::is_available"_hash);
    return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) &&
         (BcastType::BCAST111C_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::is_available"_hash);
    return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) &&
         (BcastType::BCAST101xX_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::is_available"_hash);
    return false;
}

#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
    switch (kern_param.mode) { \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
        DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
        DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
        DISPATCH_BINARY( \
                FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \
        default: \
            megdnn_throw(ssprintf( \
                    "No available algo found for: %d", \
                    static_cast<int>(kern_param.mode))); \
    }
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
    switch (kern_param.mode) { \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
        default: \
            megdnn_throw(ssprintf( \
                    "No available algo found for: %d", \
                    static_cast<int>(kern_param.mode))); \
    }
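//! How the dispatch layers compose below: DISPATCH_TYPE_FALLBACK (not shown
//! here) tests src0's dtype in the same way as DISPATCH_TYPE above and picks
//! DISPATCH_MODE_FLOAT or DISPATCH_MODE_INT; those switch on kern_param.mode;
//! and each DISPATCH_BINARY case binds OpCallerBinary<op, broadcast>::run into
//! a thin_function and queues it on the handle via MEGDNN_DISPATCH_CPU_KERN.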
void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];

    //! exactly match [x, y] + [x, y]
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t)> \
                    run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, \
                        src0.layout.total_nr_elems())); \
        } \
        MIDOUT_END(); \
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE_FALLBACK("AlgoBinaryVecVec::exec"_hash);
#undef DISPATCH_BINARY
    return;
}

void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);

    // Case 2: vector + scalar
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type, _type*, DType, DType, DType, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::VEC_SCALAR>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr())[0], \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, \
                        src0.layout.total_nr_elems())); \
        } \
        MIDOUT_END(); \
        return

    if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) {
        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::exec_vec_sca"_hash);
    }
#undef DISPATCH_BINARY

    // scalar + vector
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type, const _type*, _type*, DType, DType, DType, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::SCALAR_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr())[0], \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, \
                        src1.layout.total_nr_elems())); \
        } \
        MIDOUT_END(); \
        return

    if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) {
        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::exec_sca_vec"_hash);
    }
#undef DISPATCH_BINARY
    return;
}

void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;

    // Case 3: BcastType::VEC + BCAST_101
    if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::VEC_BCAST101>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                        binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }

    // BCAST_101 + BcastType::VEC
    if (BcastType::BCAST101_VEC == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::BCAST101_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                        binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}

void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;

    // Case: BcastType::VEC + BCAST_X0X
    if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type &&
        is_broadcasted_3dim_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                        binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }

    // BCAST_X0X + BcastType::VEC
    if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type &&
        is_broadcasted_3dim_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                        binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}

void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;

    // Case extra: BcastType::VEC + BCAST_111C
    if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type &&
        is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                        binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }

    // BCAST_111C + BcastType::VEC
    if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type &&
        is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                        binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}

void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;

    // BcastType::VEC + BCAST_101X
    if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) {
        megdnn_assert(
                is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
                        is_broadcastedx_channel_like<8>(src1.layout, binfo),
                "only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::VEC_BCAST101xX>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
                        binfo.y, binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }

    // BCAST_101x + BcastType::VEC
    if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) {
        megdnn_assert(
                is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
                        is_broadcastedx_channel_like<8>(src0.layout, binfo),
                "only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
    case Mode::_mode: \
        MIDOUT_BEGIN( \
                megdnn_fallback_elemwise_binary, midout_iv(_case), \
                midout_iv(Mode::_mode), _type_midout_id) { \
            thin_function<void( \
                    const _type*, const _type*, _type*, DType, DType, DType, size_t, \
                    size_t, size_t, size_t)> \
                    run = OpCallerBinary< \
                            _op<_type, _type>, BcastType::BCAST101xX_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN( \
                    static_cast<naive::HandleImpl*>(kern_param.handle), \
                    run(static_cast<const _type*>(src0.raw_ptr()), \
                        static_cast<const _type*>(src1.raw_ptr()), \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
                        binfo.y, binfo.z)); \
        } \
        MIDOUT_END(); \
        return

        size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}

#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen
@@ -0,0 +1,36 @@
/**
 * \file dnn/src/fallback/elemwise/gi_impl/binary/algo.h
 */
#pragma once
#include "src/fallback/elemwise/opr_impl.h"

namespace megdnn {
namespace fallback {

#define DECL_CB(case) \
    class ElemwiseImpl::AlgoBinary##case final : public ElemwiseImpl::AlgoBase { \
        mutable std::string m_name; \
        AlgoAttribute attribute() const override { \
            return AlgoAttribute::REPRODUCIBLE; \
        } \
        const char* name() const override { \
            if (m_name.empty()) { \
                m_name = ssprintf("Elemwise::AlgoBinaryCase" #case); \
            } \
            return m_name.c_str(); \
        } \
        bool is_available(const KernParam&) const override; \
        void exec(const KernParam&) const override; \
    };

DECL_CB(VecVec);
DECL_CB(VecScalar);
DECL_CB(VecBcast101);
DECL_CB(VecBcastX0X);
DECL_CB(VecBcast111C);
DECL_CB(VecBcast101xX);
#undef DECL_CB

}  // namespace fallback
}  // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -0,0 +1,383 @@
/**
 * \file dnn/src/fallback/elemwise/gi_mathfun.cpp
 *
 * This file has been modified by Megvii ("Megvii Modifications").
 * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights
 * reserved.
 */

/* NEON implementation of sin, cos, exp and log.
   Inspired by the Intel Approximate Math library, and based on the
   corresponding algorithms of the cephes math library.
*/

/* Copyright (C) 2011 Julien Pommier
   This software is provided 'as-is', without any express or implied
   warranty. In no event will the authors be held liable for any damages
   arising from the use of this software.
   Permission is granted to anyone to use this software for any purpose,
   including commercial applications, and to alter it and redistribute it
   freely, subject to the following restrictions:
   1. The origin of this software must not be misrepresented; you must not
      claim that you wrote the original software. If you use this software
      in a product, an acknowledgment in the product documentation would be
      appreciated but is not required.
   2. Altered source versions must be plainly marked as such, and must not be
      misrepresented as being the original software.
   3. This notice may not be removed or altered from any source distribution.
   (this is the zlib license)
*/

#include "./gi_mathfun.h"

namespace megdnn {
namespace fallback {

#define c_inv_mant_mask ~0x7f800000u
#define c_cephes_SQRTHF 0.707106781186547524
#define c_cephes_log_p0 7.0376836292E-2
#define c_cephes_log_p1 -1.1514610310E-1
#define c_cephes_log_p2 1.1676998740E-1
#define c_cephes_log_p3 -1.2420140846E-1
#define c_cephes_log_p4 +1.4249322787E-1
#define c_cephes_log_p5 -1.6668057665E-1
#define c_cephes_log_p6 +2.0000714765E-1
#define c_cephes_log_p7 -2.4999993993E-1
#define c_cephes_log_p8 +3.3333331174E-1
#define c_cephes_log_q1 -2.12194440e-4
#define c_cephes_log_q2 0.693359375

/**
 * natural logarithm computed for 4 simultaneous floats; returns NaN for x <= 0
 */
v4sf GiLogPsFloat32(v4sf x) {
    v4sf one = GiBroadcastFloat32(1);
    x = GiMaximumFloat32(
            x, GiBroadcastFloat32(0)); /* force flush to zero on denormal values */
    v4su invalid_mask = GiLessThanEqFloat32(x, GiBroadcastFloat32(0));

    v4si ux = GiReinterpretAsInt32(x);
    v4si emm0 = GiShiftRight23Int32(ux);

    /* keep only the fractional part */
    ux = GiAndInt32(ux, GiBroadcastInt32(c_inv_mant_mask));
    ux = GiOrInt32(ux, GiReinterpretAsInt32(GiBroadcastFloat32(0.5f)));
    x = GiReintInt32ToFloat32(ux);

    emm0 = GiSubtractInt32(emm0, GiBroadcastInt32(0x7f));
    v4sf e = GiCastToFloat32(emm0);

    e = GiAddFloat32(e, one);

    /* part2:
     * if( x < SQRTHF ) {
     *   e -= 1;
     *   x = x + x - 1.0;
     * } else { x = x - 1.0; }
     */
    v4su mask = GiLessThanFloat32(x, GiBroadcastFloat32(c_cephes_SQRTHF));
    v4sf tmp = GiAndFloat32(x, GiReintUint32ToFloat32(mask));
    x = GiSubtractFloat32(x, one);
    e = GiSubtractFloat32(e, GiAndFloat32(one, GiReintUint32ToFloat32(mask)));
    x = GiAddFloat32(x, tmp);

    v4sf z = GiMultiplyFloat32(x, x);

    v4sf y = GiBroadcastFloat32(c_cephes_log_p0);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p1), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p2), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p3), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p4), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p5), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p6), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p7), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p8), y, x);
    y = GiMultiplyFloat32(y, x);

    y = GiMultiplyFloat32(y, z);

    y = GiMultiplyAddFloat32(y, e, GiBroadcastFloat32(c_cephes_log_q1));

    y = GiMultiplySubFloat32(y, z, GiBroadcastFloat32(0.5f));

    x = GiAddFloat32(x, y);
    x = GiMultiplyAddFloat32(x, e, GiBroadcastFloat32(c_cephes_log_q2));
    x = GiOrFloat32(
            x, GiReintUint32ToFloat32(invalid_mask));  // negative arg will be NAN
    return x;
}
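/* For exposition, the same decomposition in scalar form: split x = m * 2^e
 * with m in [0.5, 1), then log(x) = P(m - 1) + e * log(2), with log(2)
 * applied in two parts (q2 + q1) for precision. A minimal sketch mirroring
 * the constants above, not used by the library: */
#include <cmath>  // for the sketches below: std::frexp, std::floor, std::ldexp

static inline float log_ps_scalar_sketch(float x) {
    if (x <= 0.f)
        return NAN;               // the vector code produces NaN via invalid_mask
    int e;
    float m = std::frexp(x, &e);  // x = m * 2^e, m in [0.5, 1)
    if (m < 0.707106781f) {       // m < sqrt(1/2): double m and borrow from e,
        m += m;                   // keeping the polynomial argument small
        e -= 1;
    }
    m -= 1.f;
    float z = m * m;
    float p = 7.0376836292E-2f;   // c_cephes_log_p0 .. p8 in Horner form
    p = p * m + -1.1514610310E-1f;
    p = p * m + 1.1676998740E-1f;
    p = p * m + -1.2420140846E-1f;
    p = p * m + 1.4249322787E-1f;
    p = p * m + -1.6668057665E-1f;
    p = p * m + 2.0000714765E-1f;
    p = p * m + -2.4999993993E-1f;
    p = p * m + 3.3333331174E-1f;
    float y = p * m * z;              // P(m) * m^3
    y += e * -2.12194440e-4f;         // c_cephes_log_q1
    y -= 0.5f * z;
    return m + y + e * 0.693359375f;  // c_cephes_log_q2
}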
#define c_exp_hi 88.3762626647949f
#define c_exp_lo -88.3762626647949f

#define c_cephes_LOG2EF 1.44269504088896341
#define c_cephes_exp_C1 0.693359375
#define c_cephes_exp_C2 -2.12194440e-4

#define c_cephes_exp_p0 1.9875691500E-4
#define c_cephes_exp_p1 1.3981999507E-3
#define c_cephes_exp_p2 8.3334519073E-3
#define c_cephes_exp_p3 4.1665795894E-2
#define c_cephes_exp_p4 1.6666665459E-1
#define c_cephes_exp_p5 5.0000001201E-1

/* exp() computed for 4 floats at once */
v4sf GiExpPsFloat32(v4sf x) {
    v4sf tmp, fx;

    v4sf one = GiBroadcastFloat32(1);
    x = GiMinimumFloat32(x, GiBroadcastFloat32(c_exp_hi));
    x = GiMaximumFloat32(x, GiBroadcastFloat32(c_exp_lo));

    /* express exp(x) as exp(g + n*log(2)) */
    fx = GiMultiplyAddFloat32(
            GiBroadcastFloat32(0.5f), x, GiBroadcastFloat32(c_cephes_LOG2EF));

    /* perform a floorf */
    tmp = GiCastToFloat32(GiCastToInt32(fx));

    /* if greater, subtract 1 */
    v4su mask = GiGreaterThanFloat32(tmp, fx);
    v4sf mask_float = GiAndFloat32(GiReintUint32ToFloat32(mask), one);

    fx = GiSubtractFloat32(tmp, mask_float);

    tmp = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C1));
    v4sf z = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C2));
    x = GiSubtractFloat32(x, tmp);
    x = GiSubtractFloat32(x, z);

    z = GiMultiplyFloat32(x, x);

    v4sf y = GiBroadcastFloat32(c_cephes_exp_p0);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p1), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p2), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p3), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p4), y, x);
    y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p5), y, x);

    y = GiMultiplyAddFloat32(x, y, z);
    y = GiAddFloat32(y, one);

    /* build 2^n */
    v4si mm;
    mm = GiCastToInt32(fx);
    mm = GiAddInt32(mm, GiBroadcastInt32(0x7f));
    mm = GiShiftLeft23Int32(mm);
    v4sf pow2n = GiReintInt32ToFloat32(mm);

    y = GiMultiplyFloat32(y, pow2n);
    return y;
}
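/* Likewise in scalar form: exp(x) = exp(g) * 2^n with n = round(x / log 2)
 * and g = x - n*log(2), with log(2) split into C1 + C2. A minimal sketch
 * mirroring the constants above, not used by the library: */
static inline float exp_ps_scalar_sketch(float x) {
    x = x > 88.3762626647949f ? 88.3762626647949f : x;       // c_exp_hi
    x = x < -88.3762626647949f ? -88.3762626647949f : x;     // c_exp_lo
    float n = std::floor(x * 1.44269504088896341f + 0.5f);   // c_cephes_LOG2EF
    x -= n * 0.693359375f;        // c_cephes_exp_C1
    x -= n * -2.12194440e-4f;     // c_cephes_exp_C2
    float z = x * x;
    float p = 1.9875691500E-4f;   // c_cephes_exp_p0 .. p5 in Horner form
    p = p * x + 1.3981999507E-3f;
    p = p * x + 8.3334519073E-3f;
    p = p * x + 4.1665795894E-2f;
    p = p * x + 1.6666665459E-1f;
    p = p * x + 5.0000001201E-1f;
    float y = p * z + x + 1.f;            // exp(g) ~= 1 + g + g^2 * P(g)
    return y * std::ldexp(1.f, (int)n);   // scale by 2^n, like the
}                                         // exponent-bit trick above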
#define c_minus_cephes_DP1 -0.78515625
#define c_minus_cephes_DP2 -2.4187564849853515625e-4
#define c_minus_cephes_DP3 -3.77489497744594108e-8
#define c_sincof_p0 -1.9515295891E-4
#define c_sincof_p1 8.3321608736E-3
#define c_sincof_p2 -1.6666654611E-1
#define c_coscof_p0 2.443315711809948E-005
#define c_coscof_p1 -1.388731625493765E-003
#define c_coscof_p2 4.166664568298827E-002
#define c_cephes_FOPI 1.27323954473516  // 4 / M_PI

/* evaluation of 4 sines & cosines at once.

   The code is the exact rewriting of the cephes sinf function.
   Precision is excellent as long as x < 8192 (I did not bother to
   take into account the special handling they have for greater values
   -- it does not return garbage for arguments over 8192, though, but
   the extra precision is missing).

   Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
   surprising but correct result.

   Note also that when you compute sin(x), cos(x) is available at
   almost no extra price, so both sin_ps_f32 and cos_ps_f32 make use of
   sincos_ps_f32.
*/
void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos) {
    // any x
    v4sf y;

    v4su emm2;

    v4su sign_mask_sin, sign_mask_cos;
    sign_mask_sin = GiLessThanFloat32(x, GiBroadcastFloat32(0));
    x = GiAbsFloat32(x);

    /* scale by 4/Pi */
    y = GiMultiplyFloat32(x, GiBroadcastFloat32(c_cephes_FOPI));

    /* store the integer part of y in mm0 */
    emm2 = GiReinterpretAsUint32(y);

    /* j=(j+1) & (~1) (see the cephes sources) */
    emm2 = GiAddUint32(emm2, GiBroadcastUint32(1));
    emm2 = GiAndUint32(emm2, GiBroadcastUint32(~1));

    y = GiReintUint32ToFloat32(emm2);

    /* get the polynomial selection mask:
     * there is one polynomial for 0 <= x <= Pi/4
     * and another one for Pi/4 < x <= Pi/2
     *
     * Both branches will be computed.
     */
    v4su poly_mask = GiTestAndSetUint32(emm2, GiBroadcastUint32(2));

    /* The magic pass: "Extended precision modular arithmetic"
     * x = ((x - y * DP1) - y * DP2) - y * DP3; */
    x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP1));
    x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP2));
    x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP3));

    sign_mask_sin =
            GiEOrUint32(sign_mask_sin, GiTestAndSetUint32(emm2, GiBroadcastUint32(4)));
    sign_mask_cos = GiTestAndSetUint32(
            GiSubtractUint32(emm2, GiBroadcastUint32(2)), GiBroadcastUint32(4));

    /* Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
     * and the second polynomial (Pi/4 < x <= Pi/2) in y2 */
    v4sf z = GiMultiplyFloat32(x, x);
    v4sf y1, y2;

    y1 = GiMultiplyAddFloat32(
            GiBroadcastFloat32(c_coscof_p1), z, GiBroadcastFloat32(c_coscof_p0));
    y2 = GiMultiplyAddFloat32(
            GiBroadcastFloat32(c_sincof_p1), z, GiBroadcastFloat32(c_sincof_p0));
    y1 = GiMultiplyAddFloat32(GiBroadcastFloat32(c_coscof_p2), y1, z);
    y2 = GiMultiplyAddFloat32(GiBroadcastFloat32(c_sincof_p2), y2, z);
    y1 = GiMultiplyFloat32(y1, z);
    y2 = GiMultiplyFloat32(y2, z);
    y1 = GiMultiplyFloat32(y1, z);
    y1 = GiMultiplySubFloat32(y1, z, GiBroadcastFloat32(0.5f));
    y2 = GiMultiplyAddFloat32(x, y2, x);
    y1 = GiAddFloat32(y1, GiBroadcastFloat32(1));

    /* select the correct result from the two polynomials */
    v4sf ys = GiBSLFloat32(poly_mask, y1, y2);
    v4sf yc = GiBSLFloat32(poly_mask, y2, y1);
    *ysin = GiBSLFloat32(sign_mask_sin, GiNegFloat32(ys), ys);
    *ycos = GiBSLFloat32(sign_mask_cos, yc, GiNegFloat32(yc));
}
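/* Summary of the scheme above: scale by 4/pi and round to get the octant
 * index j; range-reduce with the three-part -pi/4 constant (DP1 + DP2 + DP3)
 * for extended precision; evaluate both the sine and the cosine polynomial
 * on the reduced argument; finally select and negate per octant using the
 * masks derived from j. */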
v4sf GiSinPsFloat32(v4sf x) {
    v4sf ysin, ycos;
    GiSinCosPsFloat32(x, &ysin, &ycos);
    return ysin;
}

v4sf GiCosPsFloat32(v4sf x) {
    v4sf ysin, ycos;
    GiSinCosPsFloat32(x, &ysin, &ycos);
    return ycos;
}

v4sf GiTanPsFloat32(v4sf x) {
    v4sf ysin, ycos;
    GiSinCosPsFloat32(x, &ysin, &ycos);
    return GiDivideFloat32(ysin, ycos);
}

#undef c_exp_hi
#undef c_exp_lo
#undef c_cephes_LOG2EF
#undef c_cephes_exp_C1
#undef c_cephes_exp_C2
#undef c_cephes_exp_p0
#undef c_cephes_exp_p1
#undef c_cephes_exp_p2
#undef c_cephes_exp_p3
#undef c_cephes_exp_p4
#undef c_cephes_exp_p5
#undef c_minus_cephes_DP1
#undef c_minus_cephes_DP2
#undef c_minus_cephes_DP3
#undef c_sincof_p0
#undef c_sincof_p1
#undef c_sincof_p2
#undef c_coscof_p0
#undef c_coscof_p1
#undef c_coscof_p2
#undef c_cephes_FOPI
#undef c_inv_mant_mask
#undef c_cephes_SQRTHF
#undef c_cephes_log_p0
#undef c_cephes_log_p1
#undef c_cephes_log_p2
#undef c_cephes_log_p3
#undef c_cephes_log_p4
#undef c_cephes_log_p5
#undef c_cephes_log_p6
#undef c_cephes_log_p7
#undef c_cephes_log_p8
#undef c_cephes_log_q1
#undef c_cephes_log_q2

static const struct {
    float lower_range;
    float upper_range;
    float alpha_9;
    float alpha_7;
    float alpha_5;
    float alpha_3;
    float alpha_1;
    float beta_10;
    float beta_8;
    float beta_6;
    float beta_4;
    float beta_2;
    float beta_0;
    float one_half;
} sigmoid_constants = {
        -18.0f,
        18.0f,
        4.37031012579801e-11f,
        1.15627324459942e-07f,
        6.08574864600143e-05f,
        8.51377133304701e-03f,
        2.48287947061529e-01f,
        6.10247389755681e-13f,
        5.76102136993427e-09f,
        6.29106785017040e-06f,
        1.70198817374094e-03f,
        1.16817656904453e-01f,
        9.93151921023180e-01f,
        0.5f,
};

v4sf GiSigmoidPsFloat32(v4sf src) {
    auto val = GiMaximumFloat32(GiBroadcastFloat32(sigmoid_constants.lower_range), src);
    val = GiMinimumFloat32(GiBroadcastFloat32(sigmoid_constants.upper_range), val);
    auto squared = GiMultiplyFloat32(val, val);
    auto p = GiMultiplyAddFloat32(
            GiBroadcastFloat32(sigmoid_constants.alpha_7), squared,
            GiBroadcastFloat32(sigmoid_constants.alpha_9));
    p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_5), p, squared);
    p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_3), p, squared);
    p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_1), p, squared);
    p = GiMultiplyFloat32(p, val);
    auto q = GiMultiplyAddFloat32(
            GiBroadcastFloat32(sigmoid_constants.beta_8), squared,
            GiBroadcastFloat32(sigmoid_constants.beta_10));
    q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_6), q, squared);
    q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_4), q, squared);
    q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_2), q, squared);
    q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_0), q, squared);
    return GiAddFloat32(
            GiDivideFloat32(p, q), GiBroadcastFloat32(sigmoid_constants.one_half));
}
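/* The same rational approximation in scalar form: sigmoid(v) ~= P(v)/Q(v) +
 * 1/2 on [-18, 18], with an odd degree-9 numerator and an even degree-10
 * denominator. A minimal sketch mirroring sigmoid_constants above, not used
 * by the library: */
static inline float sigmoid_ps_scalar_sketch(float v) {
    v = v < -18.f ? -18.f : (v > 18.f ? 18.f : v);  // lower_range / upper_range
    float s = v * v;
    float p = 4.37031012579801e-11f;    // alpha_9 .. alpha_1, Horner in v^2
    p = p * s + 1.15627324459942e-07f;
    p = p * s + 6.08574864600143e-05f;
    p = p * s + 8.51377133304701e-03f;
    p = p * s + 2.48287947061529e-01f;
    p *= v;                             // odd numerator: v * poly(v^2)
    float q = 6.10247389755681e-13f;    // beta_10 .. beta_0, Horner in v^2
    q = q * s + 5.76102136993427e-09f;
    q = q * s + 6.29106785017040e-06f;
    q = q * s + 1.70198817374094e-03f;
    q = q * s + 1.16817656904453e-01f;
    q = q * s + 9.93151921023180e-01f;
    return p / q + 0.5f;                // one_half
}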
}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,55 @@
/**
 * \file dnn/src/fallback/elemwise/gi_impl/gi_mathfun.h
 */
#pragma once

#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"

namespace megdnn {
namespace fallback {

typedef GI_FLOAT32_t v4sf;  // vector of 4 floats
typedef GI_INT32_t v4si;    // vector of 4 int32
typedef GI_UINT32_t v4su;   // vector of 4 uint32

/**
 * \brief natural logarithm computed for 4 simultaneous floats;
 * returns NaN for x <= 0
 */
v4sf GiLogPsFloat32(v4sf x);

//! exp() computed for 4 floats at once
v4sf GiExpPsFloat32(v4sf x);

/**
 * \brief evaluation of 4 sines & cosines at once.
 *
 * The code is the exact rewriting of the cephes sinf function.
 * Precision is excellent as long as x < 8192 (I did not bother to
 * take into account the special handling they have for greater values
 * -- it does not return garbage for arguments over 8192, though, but
 * the extra precision is missing).
 *
 * Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
 * surprising but correct result.
 *
 * Note also that when you compute sin(x), cos(x) is available at
 * almost no extra price, so both sin_ps_f32 and cos_ps_f32 make use of
 * sincos_ps_f32.
 */
void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos);

v4sf GiSinPsFloat32(v4sf x);
v4sf GiCosPsFloat32(v4sf x);
v4sf GiTanPsFloat32(v4sf x);
v4sf GiSigmoidPsFloat32(v4sf x);
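//! Usage sketch (assuming the GiLoadFloat32 / GiStoreFloat32 helpers declared
//! in src/fallback/general_intrinsic/gi_float.h):
//!
//!     float in[4] = {0.f, 1.f, 2.f, 3.f}, out[4];
//!     v4sf v = GiLoadFloat32(in);
//!     GiStoreFloat32(out, GiExpPsFloat32(v));  // out[i] ~= expf(in[i])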
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,459 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp | |||
| */ | |||
| #include "src/fallback/elemwise/gi_impl/ternary/algo.h" | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_ternary) | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | |||
| auto mode = kern_param.mode; \ | |||
| if (mode == Mode::FUSE_MUL_ADD3) \ | |||
| return true; | |||
| #define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT | |||
| #define DECL_AVAILABLE(case, type) \ | |||
| bool ElemwiseImpl::AlgoTernaryFma3##case ::is_available( \ | |||
| const KernParam& kern_param) const { \ | |||
| if (type == kern_param.broad_cast_type) { \ | |||
| auto& elparam = kern_param.ternary_elparam; \ | |||
| auto& src0 = elparam[0]; \ | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3::is_available" #case##_hash); \ | |||
| } \ | |||
| return false; \ | |||
| } | |||
| DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); | |||
| DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); | |||
| DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); | |||
| DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C); | |||
| DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX); | |||
| DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); | |||
| DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC); | |||
| DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC); | |||
| DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); | |||
| DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); | |||
| #undef DECL_AVAILABLE | |||
| #undef DISPATCH_MODE_FLOAT | |||
| #undef DISPATCH_MODE_INT | |||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | |||
| switch (kern_param.mode) { \ | |||
| DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, FuseMulAdd3Op); \ | |||
| default: \ | |||
| megdnn_throw(ssprintf( \ | |||
| "No avaiable algo find for: %d", \ | |||
| static_cast<int>(kern_param.mode))); \ | |||
| } | |||
| #define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT | |||
| void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 1: the shapes of src0, src1 and src2 match exactly | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, const _type*, _type*, DType, DType, \ | |||
| DType, DType, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_VEC_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr()), \ | |||
| static_cast<const _type*>(src2.raw_ptr()), \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| src0.layout.total_nr_elems())); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecVecVec::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 2: (src2 is a scalar) && (src0 and src1 have the same shape) | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, const _type, _type*, DType, DType, \ | |||
| DType, DType, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr()), \ | |||
| static_cast<const _type*>(src2.raw_ptr())[0], \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| src0.layout.total_nr_elems())); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecVecScalar::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 3: the shapes of src0 and src2 are {1, C, 1, 1} | |||
| BroadcastChannelInfo binfo; | |||
| is_broadcasted_channel_like(src0.layout, binfo); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, const _type*, _type*, DType, DType, \ | |||
| DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ | |||
| auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ | |||
| binfo, dst, run](size_t task_id, size_t) { \ | |||
| size_t offset = task_id * nr_channels_per_thread; \ | |||
| size_t nr_channels_thread = \ | |||
| std::min(nr_channels - offset, nr_channels_per_thread); \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()) + offset, \ | |||
| static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \ | |||
| static_cast<const _type*>(src2.raw_ptr()) + offset, \ | |||
| static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ | |||
| src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||
| dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ | |||
| binfo.y * binfo.z); \ | |||
| }; \ | |||
| MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||
| kernel); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||
| ->megcore_dispatcher() | |||
| ->nr_threads(); | |||
| size_t nr_channels = binfo.y; | |||
| size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 3: the shapes of src0 and src2 are {1, 1, 1, C} | |||
| BroadcastChannelInfo binfo; | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, size_t, const _type*, _type*, DType, \ | |||
| DType, DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, \ | |||
| BcastType::BCAST111C_VEC_BCAST111C>::run; \ | |||
| auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ | |||
| binfo, dst, run](size_t task_id, size_t) { \ | |||
| size_t offset = task_id * nr_channels_per_thread; \ | |||
| size_t nr_channels_thread = \ | |||
| std::min(nr_channels - offset, nr_channels_per_thread); \ | |||
| size_t src1_offset = \ | |||
| is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr()) + \ | |||
| offset * (binfo.z + src1_offset), \ | |||
| src1_offset, static_cast<const _type*>(src2.raw_ptr()), \ | |||
| static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ | |||
| src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ | |||
| dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ | |||
| binfo.y * binfo.z); \ | |||
| }; \ | |||
| MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||
| kernel); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||
| ->megcore_dispatcher() | |||
| ->nr_threads(); | |||
| size_t nr_channels = binfo.y; | |||
| size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| BroadcastChannelInfo binfo; | |||
| megdnn_assert( | |||
| is_broadcastedx_channel_like<4>(src0.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src0.layout, binfo), | |||
| "only nchw44 and nchw88 supported"); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, const _type*, _type*, DType, DType, \ | |||
| DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, \ | |||
| BcastType::BCAST101xX_VEC_BCAST101xX>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr()), \ | |||
| static_cast<const _type*>(src2.raw_ptr()), \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| batch_size, binfo.x, binfo.y, binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| BroadcastChannelInfo binfo; | |||
| megdnn_assert( | |||
| is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo), | |||
| "only nchw44 and nchw88 supported"); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, const _type*, _type*, DType, DType, \ | |||
| DType, DType, size_t, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr()), \ | |||
| static_cast<const _type*>(src2.raw_ptr()), \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| batch_size, binfo.x, binfo.y, binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast101xXVec::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 4: the shape of src1 is {1, C, 1, 1}, and src0 and src2 are contiguous | |||
| BroadcastChannelInfo binfo; | |||
| is_broadcasted_channel_like(src1.layout, binfo); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type*, const _type*, _type*, DType, DType, \ | |||
| DType, DType, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr()), \ | |||
| static_cast<const _type*>(src2.raw_ptr()), \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| binfo.x, binfo.y, binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast101Vec::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 4: the shape of src1 is {1, 1, 1, C}, and src0 and src2 are contiguous | |||
| BroadcastChannelInfo binfo; | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo); | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, size_t, const _type*, const _type*, size_t, _type*, \ | |||
| DType, DType, DType, DType, size_t, size_t, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \ | |||
| static_cast<const _type*>(src1.raw_ptr()), \ | |||
| static_cast<const _type*>(src2.raw_ptr()), \ | |||
| is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| binfo.x, binfo.y, binfo.z)); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast111CVec::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 5: (src1 is a scalar) && (src0 and src2 have the same shape) | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type, const _type*, _type*, DType, DType, \ | |||
| DType, DType, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr())[0], \ | |||
| static_cast<const _type*>(src2.raw_ptr()), \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| src0.layout.total_nr_elems())); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecScalarVec::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec( | |||
| const KernParam& kern_param) const { | |||
| auto& elparam = kern_param.ternary_elparam; | |||
| auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; | |||
| // Case 6: (src1 and src2 are scalars) && (src0 is a vector) | |||
| #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_ternary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), _type_midout_id) { \ | |||
| thin_function<void( \ | |||
| const _type*, const _type, const _type, _type*, DType, DType, \ | |||
| DType, DType, size_t)> \ | |||
| run = OpCallerTernary< \ | |||
| _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()), \ | |||
| static_cast<const _type*>(src1.raw_ptr())[0], \ | |||
| static_cast<const _type*>(src2.raw_ptr())[0], \ | |||
| static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ | |||
| src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ | |||
| src0.layout.total_nr_elems())); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto&& dst = *(kern_param.m_dst); | |||
| DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecScalarScalar::exec"_hash); | |||
| #undef DISPATCH_TERNARY | |||
| return; | |||
| } | |||
| #undef DISPATCH_MODE_FLOAT | |||
| #undef DISPATCH_MODE_INT | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/gi_impl/ternary/algo.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| #define DECL_CB(case) \ | |||
| class ElemwiseImpl::AlgoTernaryFma3##case final : public ElemwiseImpl::AlgoBase { \ | |||
| mutable std::string m_name; \ | |||
| AlgoAttribute attribute() const override { \ | |||
| return AlgoAttribute::REPRODUCIBLE; \ | |||
| } \ | |||
| const char* name() const override { \ | |||
| if (m_name.empty()) { \ | |||
| m_name = ssprintf("Elemwise::AlgoTernaryFma3" #case); \ | |||
| } \ | |||
| return m_name.c_str(); \ | |||
| } \ | |||
| bool is_available(const KernParam&) const override; \ | |||
| void exec(const KernParam&) const override; \ | |||
| }; | |||
| DECL_CB(VecVecVec); | |||
| DECL_CB(VecVecScalar); | |||
| DECL_CB(Bcast101VecBcast101); | |||
| DECL_CB(Bcast111CVecBcast111C); | |||
| DECL_CB(Bcast101xXVecBcast101xX); | |||
| DECL_CB(VecBcast101Vec); | |||
| DECL_CB(VecBcast111CVec); | |||
| DECL_CB(VecBcast101xXVec); | |||
| DECL_CB(VecScalarVec); | |||
| DECL_CB(VecScalarScalar); | |||
| #undef DECL_CB | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,125 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp | |||
| */ | |||
| #include "src/fallback/elemwise/gi_impl/unary/algo.h" | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_unary) | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { | |||
| if (BcastType::VEC != kern_param.broad_cast_type) | |||
| return false; | |||
| if (kern_param.m_dst->layout.dtype.category() != DTypeCategory::FLOAT && | |||
| (kern_param.mode == Mode::EXP || kern_param.mode == Mode::SIGMOID || | |||
| kern_param.mode == Mode::TANH || kern_param.mode == Mode::FAST_TANH || | |||
| kern_param.mode == Mode::H_SWISH)) { | |||
| return false; | |||
| } | |||
| //! The `NEGATE` mode is so simple that the compiler already generates | |||
| //! well-vectorized code for it, while other modes such as `ABS` contain | |||
| //! branches, so the compiler may not match hand-written intrinsics. | |||
| if (kern_param.mode == Mode::NEGATE) { | |||
| return false; | |||
| } | |||
| auto& elparam = kern_param.unary_elparam; | |||
| if (!elparam[0].layout.is_contiguous()) | |||
| return false; | |||
| megdnn_assert(elparam[0].layout.ndim == 1); | |||
| auto& src0 = elparam[0]; | |||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | |||
| auto mode = kern_param.mode; \ | |||
| if (mode == Mode::RELU || mode == Mode::ABS || mode == Mode::SIGMOID || \ | |||
| mode == Mode::EXP || mode == Mode::TANH || mode == Mode::FAST_TANH || \ | |||
| mode == Mode::H_SWISH) \ | |||
| return true; | |||
| #define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ | |||
| auto mode = kern_param.mode; \ | |||
| if (mode == Mode::RELU || mode == Mode::ABS) \ | |||
| return true; | |||
| DISPATCH_TYPE_FALLBACK("AlgoUnary::is_available"_hash); | |||
| return false; | |||
| #undef DISPATCH_MODE_FLOAT | |||
| #undef DISPATCH_MODE_INT | |||
| } | |||
| void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const { | |||
| #define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \ | |||
| case Mode::_mode: \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_fallback_elemwise_unary, midout_iv(_case), \ | |||
| midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \ | |||
| thin_function<void(const _type*, _type*, DType, DType, size_t)> run = \ | |||
| OpCallerUnary<_op<_type, _type>, BcastType::VEC>::run; \ | |||
| auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, run]( \ | |||
| size_t task_id, size_t) { \ | |||
| size_t offset = task_id * nr_elems_per_thread; \ | |||
| size_t nr_elems_thread = \ | |||
| std::min(nr_elems - offset, nr_elems_per_thread); \ | |||
| run(static_cast<const _type*>(src0.raw_ptr()) + offset, \ | |||
| static_cast<_type*>(dst_tensor.raw_ptr()) + offset, \ | |||
| src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \ | |||
| }; \ | |||
| MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ | |||
| kernel); \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return | |||
| auto& elparam = kern_param.unary_elparam; | |||
| megdnn_assert(elparam[0].layout.ndim == 1); | |||
| auto& src0 = elparam[0]; | |||
| auto& dst_tensor = *(kern_param.m_dst); | |||
| size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) | |||
| ->megcore_dispatcher() | |||
| ->nr_threads(); | |||
| size_t nr_elems = src0.layout.total_nr_elems(); | |||
| size_t nr_elems_per_thread = (nr_elems + nr_threads - 1) / nr_threads; | |||
| #define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ | |||
| switch (kern_param.mode) { \ | |||
| DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ | |||
| DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ | |||
| DISPATCH_UNARY(SIGMOID, _case, _type, _type_midout_id, SigmoidOp); \ | |||
| DISPATCH_UNARY(EXP, _case, _type, _type_midout_id, ExpOp); \ | |||
| DISPATCH_UNARY(TANH, _case, _type, _type_midout_id, TanhOp); \ | |||
| DISPATCH_UNARY(FAST_TANH, _case, _type, _type_midout_id, FastTanhOp); \ | |||
| DISPATCH_UNARY(H_SWISH, _case, _type, _type_midout_id, HSwishOp); \ | |||
| default: \ | |||
| megdnn_throw(ssprintf( \ | |||
| "No avaiable algo find for: %d", \ | |||
| static_cast<int>(kern_param.mode))); \ | |||
| } | |||
| #define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ | |||
| switch (kern_param.mode) { \ | |||
| DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ | |||
| DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ | |||
| default: \ | |||
| megdnn_throw(ssprintf( \ | |||
| "No avaiable algo find for: %d", \ | |||
| static_cast<int>(kern_param.mode))); \ | |||
| } | |||
| DISPATCH_TYPE_FALLBACK("AlgoUnary::exec"_hash); | |||
| #undef DISPATCH_MODE_FLOAT | |||
| #undef DISPATCH_MODE_INT | |||
| #undef DISPATCH_UNARY | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/gi_impl/unary/algo.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase { | |||
| mutable std::string m_name; | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { | |||
| if (m_name.empty()) { | |||
| m_name = ssprintf("Elemwise::AlgoUnary"); | |||
| } | |||
| return m_name.c_str(); | |||
| } | |||
| bool is_available(const KernParam&) const override; | |||
| void exec(const KernParam&) const override; | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -12,6 +12,9 @@ | |||
| #include "src/common/elemwise/kern_defs.cuh" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback//elemwise/gi_impl/unary/algo.h" | |||
| #include "src/fallback/elemwise/gi_impl/binary/algo.h" | |||
| #include "src/fallback/elemwise/gi_impl/ternary/algo.h" | |||
| #include "src/naive/handle.h" | |||
| #include "midout.h" | |||
| @@ -21,13 +24,22 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT) | |||
| MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT) | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| if (!dst.layout.is_contiguous()) { | |||
| return naive::ElemwiseForwardImpl::exec(srcs, dst); | |||
| } | |||
| if (!exec_gi_intrinsic(srcs, dst)) { | |||
| return exec_fallback(srcs, dst); | |||
| } | |||
| } | |||
| void ElemwiseImpl::exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| if (!dst.layout.is_contiguous()) { | |||
| return naive::ElemwiseForwardImpl::exec(srcs, dst); | |||
| } | |||
| m_src = &srcs; | |||
| m_dst = &dst; | |||
| @@ -82,7 +94,229 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| naive::ElemwiseForwardImpl::exec(srcs, dst); | |||
| } | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| class ElemwiseImpl::AlgoPack { | |||
| #if !(MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||
| AlgoUnary algo_unary; | |||
| AlgoBinaryVecVec algo_binary_vec_vec; | |||
| AlgoBinaryVecScalar algo_binary_vec_sca; | |||
| AlgoBinaryVecBcast101 algo_binary_vec_bcast101; | |||
| AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X; | |||
| AlgoBinaryVecBcast111C algo_binary_vec_bcast111C; | |||
| AlgoBinaryVecBcast101xX algo_binary_vec_bcast101xX; | |||
| AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; | |||
| AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; | |||
| AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; | |||
| AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast111C_vec_bcast111C; | |||
| AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX; | |||
| AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; | |||
| AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast111C_vec; | |||
| AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; | |||
| AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; | |||
| AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; | |||
| #endif | |||
| public: | |||
| AlgoPack() { | |||
| #if !(MEGDNN_AARCH64 || MEGDNN_ARMV7) | |||
| all_algos.emplace_back(&algo_unary); | |||
| all_algos.emplace_back(&algo_binary_vec_vec); | |||
| all_algos.emplace_back(&algo_binary_vec_sca); | |||
| all_algos.emplace_back(&algo_binary_vec_bcast101); | |||
| all_algos.emplace_back(&algo_binary_vec_bcastX0X); | |||
| all_algos.emplace_back(&algo_binary_vec_bcast110); | |||
| all_algos.emplace_back(&algo_binary_VEC_BCAST101xX); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); | |||
| all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); | |||
| all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110); | |||
| all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); | |||
| all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); | |||
| #endif | |||
| } | |||
| SmallVector<AlgoBase*> all_algos; | |||
| }; | |||
| bool ElemwiseImpl::exec_gi_intrinsic( | |||
| const TensorNDArray& srcs, _megdnn_tensor_out dst) { | |||
| m_src = &srcs; | |||
| m_dst = &dst; | |||
| if (m_dst->layout.dtype == dtype::Float32() || | |||
| m_dst->layout.dtype == dtype::Int32() || | |||
| m_dst->layout.dtype == dtype::Int16() || m_dst->layout.dtype == dtype::Int8()) { | |||
| auto kern_param = make_kern_param(this); | |||
| kern_param.m_dst = &dst; | |||
| static AlgoPack m_algo_pack; | |||
| for (auto& m_algo : m_algo_pack.all_algos) { | |||
| if (m_algo->is_available(kern_param)) { | |||
| m_algo->exec(kern_param); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { | |||
| KernParam kern_param; | |||
| kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE; | |||
| kern_param.mode = opr->param().mode; | |||
| kern_param.handle = opr->handle(); | |||
| auto is_legal_layout_for_nhwc = [](const TensorLayout& l) { | |||
| if (is_vector(l)) | |||
| return true; | |||
| if (l.ndim == 2 && l.stride[1] == 1) | |||
| return true; | |||
| return false; | |||
| }; | |||
| if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { | |||
| kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); | |||
| bool c_is_scalar; | |||
| opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar); | |||
| auto &src0 = kern_param.ternary_elparam[0], | |||
| &src1 = kern_param.ternary_elparam[1], | |||
| &src2 = kern_param.ternary_elparam[2]; | |||
| BroadcastChannelInfo binfo; | |||
| if (is_vector(src0.layout) && is_vector(src1.layout) && | |||
| is_vector(src2.layout)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_VEC_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) { | |||
| kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) && | |||
| src0.layout.eq_layout(src2.layout)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && | |||
| (is_broadcastedx_channel_like<4>(src0.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src0.layout, binfo)) && | |||
| src0.layout.eq_layout(src2.layout)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | |||
| is_broadcasted_channel_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src1.layout) && | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo) && | |||
| src0.layout.eq_layout(src2.layout)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src0.layout) && | |||
| src2.layout.eq_layout(src0.layout) && | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && | |||
| (is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_vector(src2.layout) && | |||
| is_broadcasted_scalar(src1.layout)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) && | |||
| is_broadcasted_scalar(src2.layout)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR; | |||
| return kern_param; | |||
| } | |||
| } else if (opr->m_src->size() == 2) { | |||
| kern_param.binary_elparam = opr->make_elemwise_op_param<2>(); | |||
| auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1]; | |||
| BroadcastChannelInfo binfo; | |||
| if (is_vector(src0.layout) && is_vector(src1.layout)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_SCALAR; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) { | |||
| kern_param.broad_cast_type = BcastType::SCALAR_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST101; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST101_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCASTX0X; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::BCASTX0X_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src1.layout) && | |||
| is_NHWC_broadcasted_channel_like(src0.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::BCAST111C_VEC; | |||
| return kern_param; | |||
| } | |||
| if (is_legal_layout_for_nhwc(src0.layout) && | |||
| is_NHWC_broadcasted_channel_like(src1.layout, binfo)) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST111C; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src0.layout) && | |||
| (is_broadcastedx_channel_like<4>(src1.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src1.layout, binfo))) { | |||
| kern_param.broad_cast_type = BcastType::VEC_BCAST101xX; | |||
| return kern_param; | |||
| } | |||
| if (is_vector(src1.layout) && | |||
| (is_broadcastedx_channel_like<4>(src0.layout, binfo) || | |||
| is_broadcastedx_channel_like<8>(src0.layout, binfo))) { | |||
| kern_param.broad_cast_type = BcastType::BCAST101xX_VEC; | |||
| return kern_param; | |||
| } | |||
| } else if (opr->m_src->size() == 1) { | |||
| kern_param.broad_cast_type = BcastType::VEC; | |||
| kern_param.unary_elparam = opr->make_elemwise_op_param<1>(); | |||
| return kern_param; | |||
| } | |||
| return kern_param; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -9,6 +9,7 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_op.h" | |||
| #include "src/naive/elemwise/opr_impl.h" | |||
| namespace megdnn { | |||
| @@ -33,13 +34,69 @@ class ElemwiseImpl : public naive::ElemwiseForwardImpl { | |||
| template <uint32_t mode> | |||
| void exec_BINARY_FLOAT(); | |||
| void exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst); | |||
| bool exec_gi_intrinsic(const TensorNDArray& srcs, _megdnn_tensor_out dst); | |||
| private: | |||
| class AlgoUnary; | |||
| class AlgoBinaryVecVec; | |||
| class AlgoBinaryVecScalar; | |||
| class AlgoBinaryVecBcast101; | |||
| class AlgoBinaryVecBcastX0X; | |||
| class AlgoBinaryVecBcast111C; | |||
| class AlgoBinaryVecBcast101xX; | |||
| class AlgoTernaryFma3VecVecVec; | |||
| class AlgoTernaryFma3VecVecScalar; | |||
| class AlgoTernaryFma3Bcast101VecBcast101; | |||
| class AlgoTernaryFma3Bcast111CVecBcast111C; | |||
| class AlgoTernaryFma3Bcast101xXVecBcast101xX; | |||
| class AlgoTernaryFma3VecBcast101Vec; | |||
| class AlgoTernaryFma3VecBcast111CVec; | |||
| class AlgoTernaryFma3VecBcast101xXVec; | |||
| class AlgoTernaryFma3VecScalarVec; | |||
| class AlgoTernaryFma3VecScalarScalar; | |||
| class AlgoPack; | |||
| public: | |||
| class AlgoBase; | |||
| struct KernParam { | |||
| BcastType broad_cast_type; | |||
| Mode mode; | |||
| const TensorND* m_dst; | |||
| Handle* handle; | |||
| ElemwiseOpParamN<3> ternary_elparam; | |||
| ElemwiseOpParamN<2> binary_elparam; | |||
| ElemwiseOpParamN<1> unary_elparam; | |||
| }; | |||
| KernParam make_kern_param(ElemwiseImpl* opr); | |||
| using naive::ElemwiseForwardImpl::ElemwiseForwardImpl; | |||
| void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override; | |||
| const char* get_algorithm_set_name() const { return "FALLBACK ELEMWISE"; } | |||
| bool is_thread_safe() const override { return true; } | |||
| }; | |||
| /*! | |||
| * \brief base class for Elemwise algo | |||
| */ | |||
| class ElemwiseImpl::AlgoBase : public detail::Algorithm { | |||
| public: | |||
| virtual bool is_available(const KernParam&) const = 0; | |||
| virtual void exec(const KernParam&) const = 0; | |||
| virtual ~AlgoBase() = default; | |||
| uint32_t type() const override { return INVALID_ALGO_TYPE; } | |||
| }; | |||
| //! fallback only supports float32, int32 and int8 | |||
| #define DISPATCH_TYPE_FALLBACK(_case) \ | |||
| if (src0.layout.dtype == dtype::Float32{}) { \ | |||
| DISPATCH_MODE_FLOAT(_case, float, 0); \ | |||
| } else if (src0.layout.dtype == dtype::Int32{}) { \ | |||
| DISPATCH_MODE_INT(_case, int, 2); \ | |||
| } else if (src0.layout.dtype == dtype::Int8{}) { \ | |||
| DISPATCH_MODE_INT(_case, dt_int8, 4); \ | |||
| } | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/abs.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct AbsOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : (-src); } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct AbsOp; | |||
| #define OP(_ctype, _gi_type, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct AbsOp<_ctype> : AbsOpBase<_ctype> { \ | |||
| using AbsOpBase::AbsOpBase; \ | |||
| using AbsOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _gi_type& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _gi_type operator()(const _gi_type& src) const { \ | |||
| auto vitem0 = GiAbs##_func_suffix(src.val[0]); \ | |||
| auto vitem1 = GiAbs##_func_suffix(src.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(dt_float32)) | |||
| OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32)) | |||
| OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8)) | |||
| #undef OP | |||
| template <> | |||
| struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| void operator()(const dt_qint8& src, dt_qint8* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src) const { | |||
| float fsrc = src.as_int8() * this->scale; | |||
| fsrc = fsrc > 0 ? fsrc : -fsrc; | |||
| return QConverter::convert<dt_qint8, float>(fsrc); | |||
| } | |||
| }; | |||
| template <> | |||
| struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> { | |||
| using AbsOpBase::AbsOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| using AbsOpBase::operator(); | |||
| void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | |||
| OPERATOR_UNARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); | |||
| auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); | |||
| vitem0 = GiAbsFloat32(vitem0); | |||
| vitem1 = GiAbsFloat32(vitem1); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/add.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct AddOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return src0 + src1; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct AddOp; | |||
| #define OP(_ctype, _gi_type, _gi_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct AddOp<_ctype> : AddOpBase<_ctype> { \ | |||
| using AddOpBase::AddOpBase; \ | |||
| using AddOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _gi_type2& src0, const _gi_type2& src1, dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _gi_type2 operator()(const _gi_type2& src0, const _gi_type2& src1) const { \ | |||
| auto vitem0 = GiAdd##_func_suffix(src0.val[0], src1.val[0]); \ | |||
| auto vitem1 = GiAdd##_func_suffix(src0.val[1], src1.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _gi_type& src0, const _gi_type& src1, dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _gi_type operator()(const _gi_type& src0, const _gi_type& src1) const { \ | |||
| return GiAdd##_func_suffix(src0, src1); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, | |||
| GI_SIMD_LEN_BYTE / sizeof(dt_float32)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8)) | |||
| #undef OP | |||
| template <> | |||
| struct AddOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { | |||
| return QConverter::convert<dt_qint8, float>( | |||
| src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct AddOp<dt_qint8, dt_qint8> : AddOpBase<dt_qint8, dt_qint8> { | |||
| using AddOpBase::AddOpBase; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| using AddOpBase::operator(); | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| template <> | |||
| struct AddOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { | |||
| return QConverter::convert<dt_qint8, float>( | |||
| src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct AddOp<dt_qint32, dt_qint8> : AddOpBase<dt_qint32, dt_qint8> { | |||
| using AddOpBase::AddOpBase; | |||
| using AddOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||
| void operator()( | |||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | |||
| dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1)); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/exp.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct ExpOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { | |||
| float tmp = src; | |||
| return exp(tmp); | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct ExpOp; | |||
| #define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct ExpOp<_ctype> : ExpOpBase<_ctype> { \ | |||
| using ExpOpBase::ExpOpBase; \ | |||
| using ExpOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _simd_type& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { \ | |||
| auto vitem0 = GiExpPs##_func_suffix(src.val[0]); \ | |||
| auto vitem1 = GiExpPs##_func_suffix(src.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| //! tanh = x * (27 + x^2) / (27 + 9 * x^2) | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FastTanhOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { | |||
| float x = src; | |||
| return x * (27.f + x * x) / (27.f + 9.f * x * x); | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FastTanhOp; | |||
| #define OP(_ctype, _simd_type, _func_suffix, _fix_func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct FastTanhOp<_ctype> : FastTanhOpBase<_ctype> { \ | |||
| using FastTanhOpBase::FastTanhOpBase; \ | |||
| using FastTanhOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _simd_type& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { \ | |||
| auto val_27 = GiBroadcast##_func_suffix(27.f); \ | |||
| auto val_9 = GiBroadcast##_func_suffix(9.f); \ | |||
| auto valx = src.val[0]; \ | |||
| auto valx1 = src.val[1]; \ | |||
| auto valxp2 = GiMultiply##_fix_func_suffix(valx, valx); \ | |||
| auto valx1p2 = GiMultiply##_fix_func_suffix(valx1, valx1); \ | |||
| auto denominator = GiAdd##_fix_func_suffix(valxp2, val_27); \ | |||
| auto denominator1 = GiAdd##_fix_func_suffix(valx1p2, val_27); \ | |||
| valx = GiMultiply##_fix_func_suffix(valx, denominator); \ | |||
| valx1 = GiMultiply##_fix_func_suffix(valx1, denominator1); \ | |||
| denominator = GiMultiplyAdd##_fix_func_suffix(val_27, valxp2, val_9); \ | |||
| denominator1 = GiMultiplyAdd##_fix_func_suffix(val_27, valx1p2, val_9); \ | |||
| auto r_denominator = GiRecpe##_func_suffix(denominator); \ | |||
| auto r_denominator1 = GiRecpe##_func_suffix(denominator1); \ | |||
| r_denominator = GiMultiply##_fix_func_suffix( \ | |||
| GiRecpeS##_func_suffix(denominator, r_denominator), \ | |||
| r_denominator); \ | |||
| r_denominator1 = GiMultiply##_fix_func_suffix( \ | |||
| GiRecpeS##_func_suffix(denominator1, r_denominator1), \ | |||
| r_denominator1); \ | |||
| valx = GiMultiply##_fix_func_suffix(valx, r_denominator); \ | |||
| valx1 = GiMultiply##_fix_func_suffix(valx1, r_denominator1); \ | |||
| return {{valx, valx1}}; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_V2_t, Float32, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
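| // The GiRecpe/GiRecpeS pair above performs one Newton-Raphson refinement of | |||
| // the hardware reciprocal estimate: with r0 = recpe(d), recps(d, r0) | |||
| // returns (2 - d * r0) (NEON semantics; assumed to carry over to the GI | |||
| // fallback), so r1 = r0 * (2 - d * r0) converges toward 1 / d. The scalar | |||
| // reference of the approximation itself is simply: | |||
| //   float fast_tanh_ref(float x) { return x * (27.f + x * x) / (27.f + 9.f * x * x); } | |||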
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddHSwishOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| float tmp = src0 + src1; | |||
| tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; | |||
| return tmp; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddHSwishOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \ | |||
| using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \ | |||
| using FuseAddHSwishOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto val1 = src0.val[0]; \ | |||
| auto val2 = src0.val[1]; \ | |||
| auto val3 = src1.val[0]; \ | |||
| auto val4 = src1.val[1]; \ | |||
| val1 = GiAdd##_func_suffix(val1, val3); \ | |||
| val2 = GiAdd##_func_suffix(val2, val4); \ | |||
| H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| auto val1 = src0; \ | |||
| auto val2 = src1; \ | |||
| val1 = GiAdd##_func_suffix(val1, val2); \ | |||
| H_SWISH_KERN_N1_FALLBACK(_func_suffix, val1); \ | |||
| return val1; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| template <> | |||
| struct FuseAddHSwishOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { | |||
| float tmp = | |||
| src0.as_int32() * this->scale_src0 + src1.as_int32() * this->scale_src1; | |||
| tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; | |||
| tmp *= this->scale_dst; | |||
| return QConverter::convert<dt_qint8, float>(tmp); | |||
| } | |||
| }; | |||
| template <> | |||
| struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> { | |||
| using FuseAddHSwishOpBase::FuseAddHSwishOpBase; | |||
| using FuseAddHSwishOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = 4; | |||
| void operator()( | |||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | |||
| dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1)); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| GI_FLOAT32_t vitem0, vitem1; | |||
| vitem0 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1)); | |||
| vitem1 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1)); | |||
| H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1); | |||
| vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst); | |||
| vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| #include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h" | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
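| The qint32 -> qint8 path above dequantizes each source with its own scale, applies h-swish in float, then multiplies by scale_dst, which is stored as 1.f / dst.scale so requantization is a multiply rather than a divide. A minimal standalone model of that arithmetic, with made-up scales and with QConverter's rounding assumed to be round-to-nearest with saturation: | |||
| #include <algorithm> | |||
| #include <cmath> | |||
| #include <cstdint> | |||
| // Hypothetical scales, for illustration only. | |||
| int8_t fuse_add_hswish_q(int32_t q0, int32_t q1) { | |||
|     const float s0 = 0.02f, s1 = 0.03f, s_dst = 0.1f; | |||
|     float t = q0 * s0 + q1 * s1;                          // dequantize + add | |||
|     t = t * std::max(std::min(t + 3.f, 6.f), 0.f) / 6.f;  // h-swish | |||
|     t *= 1.f / s_dst;                                     // requantize (multiply by 1/dst scale) | |||
|     long q = std::lround(t);                              // rounding mode assumed | |||
|     return static_cast<int8_t>(std::min(127L, std::max(-128L, q))); | |||
| } | |||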
| @@ -0,0 +1,162 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h | |||
| */ | |||
| #pragma once | |||
| #include "gi_util_impl_helper.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddReluOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| auto tmp = src0 + src1; | |||
| return tmp > 0 ? tmp : 0; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddReluOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \ | |||
| using FuseAddReluOpBase::FuseAddReluOpBase; \ | |||
| using FuseAddReluOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto val1 = src0.val[0]; \ | |||
| auto val2 = src0.val[1]; \ | |||
| auto val3 = src1.val[0]; \ | |||
| auto val4 = src1.val[1]; \ | |||
| FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, _func_suffix); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| auto val1 = src0; \ | |||
| auto val2 = src1; \ | |||
| FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, _func_suffix); \ | |||
| return val1; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <typename ctype> | |||
| struct FuseAddReluOpCommon; | |||
| template <> | |||
| struct FuseAddReluOpCommon<float> { | |||
| inline static GI_FLOAT32_t vzero() { return GiBroadcastFloat32(0); } | |||
| }; | |||
| template <> | |||
| struct FuseAddReluOpCommon<int> { | |||
| inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); } | |||
| }; | |||
| template <> | |||
| struct FuseAddReluOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { | |||
| return QConverter::convert<dt_qint8, float>(std::max<float>( | |||
| src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, 0.f)); | |||
| } | |||
| }; | |||
| template <> | |||
| struct FuseAddReluOp<dt_qint8, dt_qint8> : FuseAddReluOpBase<dt_qint8, dt_qint8>, | |||
| FuseAddReluOpCommon<float> { | |||
| using FuseAddReluOpBase::FuseAddReluOpBase; | |||
| using FuseAddReluOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t); | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| vitem0 = GiMaximumFloat32(vitem0, this->vzero()); | |||
| vitem1 = GiMaximumFloat32(vitem1, this->vzero()); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| template <> | |||
| struct FuseAddReluOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { | |||
| return QConverter::convert<dt_qint8, float>(std::max<float>( | |||
| src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, 0.f)); | |||
| } | |||
| }; | |||
| template <> | |||
| struct FuseAddReluOp<dt_qint32, dt_qint8> : FuseAddReluOpBase<dt_qint32, dt_qint8>, | |||
| FuseAddReluOpCommon<float> { | |||
| using FuseAddReluOpBase::FuseAddReluOpBase; | |||
| using FuseAddReluOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = 4; | |||
| void operator()( | |||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | |||
| dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1)); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiAddFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| vitem0 = GiMaximumFloat32(vitem0, this->vzero()); | |||
| vitem1 = GiMaximumFloat32(vitem1, this->vzero()); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddSigmoidOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| float tmpf = src0 + src1; | |||
| tmpf = exp(-tmpf); | |||
| tmpf = 1.f / (1.f + tmpf); | |||
| return tmpf; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddSigmoidOp; | |||
| #define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \ | |||
| using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \ | |||
| using FuseAddSigmoidOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| auto val1 = src0.val[0]; \ | |||
| auto val2 = src0.val[1]; \ | |||
| auto val3 = src1.val[0]; \ | |||
| auto val4 = src1.val[1]; \ | |||
| val1 = GiAdd##_func_suffix(val1, val3); \ | |||
| val2 = GiAdd##_func_suffix(val2, val4); \ | |||
| val1 = GiSigmoidPs##_func_suffix(val1); \ | |||
| val2 = GiSigmoidPs##_func_suffix(val2); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddTanhOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| float tmpf = exp(src0 + src1); | |||
| float tmpf2 = 1 / tmpf; | |||
| return (tmpf - tmpf2) / (tmpf + tmpf2); | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseAddTanhOp; | |||
| #define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \ | |||
| using FuseAddTanhOpBase::FuseAddTanhOpBase; \ | |||
| using FuseAddTanhOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| auto val1 = src0.val[0]; \ | |||
| auto val2 = src0.val[1]; \ | |||
| auto val3 = src1.val[0]; \ | |||
| auto val4 = src1.val[1]; \ | |||
| val1 = GiAdd##_func_suffix(val1, val3); \ | |||
| val2 = GiAdd##_func_suffix(val2, val4); \ | |||
| auto exp1 = GiExpPs##_func_suffix(val1); \ | |||
| auto exp2 = GiExpPs##_func_suffix(val2); \ | |||
| auto rexp1 = GiRecpe##_func_suffix(exp1); \ | |||
| auto rexp2 = GiRecpe##_func_suffix(exp2); \ | |||
| rexp1 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \ | |||
| rexp2 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \ | |||
| val1 = GiSubtract##_func_suffix(exp1, rexp1); \ | |||
| val2 = GiSubtract##_func_suffix(exp2, rexp2); \ | |||
| exp1 = GiAdd##_func_suffix(exp1, rexp1); \ | |||
| exp2 = GiAdd##_func_suffix(exp2, rexp2); \ | |||
| rexp1 = GiRecpe##_func_suffix(exp1); \ | |||
| rexp2 = GiRecpe##_func_suffix(exp2); \ | |||
| rexp1 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \ | |||
| rexp2 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \ | |||
| val1 = GiMultiply##_func_suffix(val1, rexp1); \ | |||
| val2 = GiMultiply##_func_suffix(val2, rexp2); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
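| GiRecpe returns only a low-precision reciprocal estimate (on Arm, vrecpe is good to roughly 8 bits), so each GiRecpeS/GiMultiply pair above is one Newton-Raphson refinement step; the kernel uses it once to form e^-x from e^x and again to divide by the denominator e^x + e^-x. Assuming GiRecpeS follows Arm's vrecps semantics, where vrecps(x, r) returns 2 - x * r, the refinement reads: | |||
| // One Newton-Raphson step toward 1/x from the estimate r, i.e. what | |||
| // GiMultiply##suffix(GiRecpeS##suffix(x, r), r) computes lane-wise. | |||
| float refine_recip(float x, float r) { | |||
|     return (2.0f - x * r) * r; | |||
| } | |||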
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseMulAdd3OpBase : TernaryOpBase<src_ctype, dst_ctype> { | |||
| using TernaryOpBase<src_ctype, dst_ctype>::TernaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, const src_ctype& src2, | |||
| dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1, src2); | |||
| } | |||
| dst_ctype operator()( | |||
| const src_ctype& src0, const src_ctype& src1, const src_ctype& src2) const { | |||
| return (src0 * src1) + src2; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct FuseMulAdd3Op; | |||
| #define OP(_ctype, _simd_type, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \ | |||
| using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \ | |||
| using FuseMulAdd3OpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| const _simd_type& src2, dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1, src2); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| const _simd_type& src2) const { \ | |||
| auto vitem0 = GiMultiplyAdd##_func_suffix( \ | |||
| src2.val[0], src0.val[0], src1.val[0]); \ | |||
| auto vitem1 = GiMultiplyAdd##_func_suffix( \ | |||
| src2.val[1], src0.val[1], src1.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise/gi_impl/gi_util_impl_helper.h | |||
| */ | |||
| #pragma once | |||
| /*! | |||
| * \brief compute fuse_add_relu on two simd packs | |||
| * | |||
| * Compute | |||
| * | |||
| * val1 = fuse_add_relu(val1, val3) | |||
| * val2 = fuse_add_relu(val2, val4) | |||
| * | |||
| * This algorithm handles int overflow. | |||
| */ | |||
| #define FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, func_suffix) \ | |||
| do { \ | |||
| val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val3)); \ | |||
| val2 = GiMaximum##func_suffix(val2, GiNeg##func_suffix(val4)); \ | |||
| val1 = GiAdd##func_suffix(val1, val3); \ | |||
| val2 = GiAdd##func_suffix(val2, val4); \ | |||
| } while (0) | |||
| #define FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, func_suffix) \ | |||
| do { \ | |||
| val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val2)); \ | |||
| val1 = GiAdd##func_suffix(val1, val2); \ | |||
| } while (0) | |||
| // vim: syntax=cpp.doxygen | |||
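| The rewrite relu(a + b) = max(a, -b) + b behind these macros matches add-then-relu whenever a >= -b (both yield a + b), and otherwise cancels to exactly 0 without ever materializing the negative intermediate a + b, which is how it sidesteps wraparound on the integer instantiations. A scalar sketch of the identity: | |||
| #include <algorithm> | |||
| #include <cstdint> | |||
| // relu(a + b) without forming a negative a + b, mirroring | |||
| // FUSE_ADD_RELU_SIMD_PACK_FALLBACK lane-wise. | |||
| int32_t fuse_add_relu(int32_t a, int32_t b) { | |||
|     return std::max(a, -b) + b;  // a + b if a >= -b, else exactly 0 | |||
| } | |||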
| @@ -0,0 +1,108 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/hswish.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct HSwishOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { | |||
| float tmp = src; | |||
| tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; | |||
| return tmp; | |||
| } | |||
| }; | |||
| //! h_swish(x) = x * clip(x + 3, 0, 6) / 6 | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct HSwishOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \ | |||
| using HSwishOpBase::HSwishOpBase; \ | |||
| using HSwishOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _simd_type2& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| void operator()(const _simd_type& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type2 operator()(const _simd_type2& src) const { \ | |||
| auto val1 = src.val[0]; \ | |||
| auto val2 = src.val[1]; \ | |||
| H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { \ | |||
| auto val_zero = GiBroadcast##_func_suffix(0.f); \ | |||
| auto val_six = GiBroadcast##_func_suffix(6.f); \ | |||
| auto val_three = GiBroadcast##_func_suffix(3.f); \ | |||
| auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \ | |||
| auto clip1 = GiMaximum##_func_suffix( \ | |||
| GiMinimum##_func_suffix( \ | |||
| GiAdd##_func_suffix(src, val_three), val_six), \ | |||
| val_zero); \ | |||
| return GiMultiply##_func_suffix( \ | |||
| GiMultiply##_func_suffix(src, clip1), val_rec_six); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| template <> | |||
| struct HSwishOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| void operator()(const dt_qint32& src, dt_qint8* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dt_qint8 operator()(const dt_qint32& src) const { | |||
| float tmp = src.as_int32() * this->scale_src; | |||
| tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; | |||
| tmp *= this->scale_dst; | |||
| return QConverter::convert<dt_qint8, float>(tmp); | |||
| } | |||
| }; | |||
| template <> | |||
| struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> { | |||
| using HSwishOpBase::HSwishOpBase; | |||
| using HSwishOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = 4; | |||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src); | |||
| auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src); | |||
| H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1); | |||
| vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst); | |||
| vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| #include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h" | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,7 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h | |||
| */ | |||
| #undef H_SWISH_KERN_FALLBACK | |||
| #undef H_SWISH_KERN_N1_FALLBACK | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h | |||
| */ | |||
| #define H_SWISH_KERN_FALLBACK(_func_suffix, _val1, _val2) \ | |||
| do { \ | |||
| auto val_zero = GiBroadcast##_func_suffix(0.f); \ | |||
| auto val_six = GiBroadcast##_func_suffix(6.f); \ | |||
| auto val_three = GiBroadcast##_func_suffix(3.f); \ | |||
| auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \ | |||
| auto clip1 = GiMaximum##_func_suffix( \ | |||
| GiMinimum##_func_suffix( \ | |||
| GiAdd##_func_suffix(_val1, val_three), val_six), \ | |||
| val_zero); \ | |||
| auto clip2 = GiMaximum##_func_suffix( \ | |||
| GiMinimum##_func_suffix( \ | |||
| GiAdd##_func_suffix(_val2, val_three), val_six), \ | |||
| val_zero); \ | |||
| _val1 = GiMultiply##_func_suffix( \ | |||
| GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \ | |||
| _val2 = GiMultiply##_func_suffix( \ | |||
| GiMultiply##_func_suffix(_val2, clip2), val_rec_six); \ | |||
| } while (0) | |||
| #define H_SWISH_KERN_N1_FALLBACK(_func_suffix, _val1) \ | |||
| do { \ | |||
| auto val_zero = GiBroadcast##_func_suffix(0.f); \ | |||
| auto val_six = GiBroadcast##_func_suffix(6.f); \ | |||
| auto val_three = GiBroadcast##_func_suffix(3.f); \ | |||
| auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \ | |||
| auto clip1 = GiMaximum##_func_suffix( \ | |||
| GiMinimum##_func_suffix( \ | |||
| GiAdd##_func_suffix(_val1, val_three), val_six), \ | |||
| val_zero); \ | |||
| _val1 = GiMultiply##_func_suffix( \ | |||
| GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \ | |||
| } while (0) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/max.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct MaxOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return src0 > src1 ? src0 : src1; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct MaxOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct MaxOp<_ctype> : MaxOpBase<_ctype> { \ | |||
| using MaxOpBase::MaxOpBase; \ | |||
| using MaxOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto vitem0 = GiMaximum##_func_suffix(src0.val[0], src1.val[0]); \ | |||
| auto vitem1 = GiMaximum##_func_suffix(src0.val[1], src1.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| return GiMaximum##_func_suffix(src0, src1); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <> | |||
| struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using src_ctype = dt_qint8; | |||
| using dst_ctype = dt_qint8; | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| float fsrc0 = src0.as_int8() * this->scale0; | |||
| float fsrc1 = src1.as_int8() * this->scale1; | |||
| return QConverter::convert<dst_ctype, float>(fsrc0 > fsrc1 ? fsrc0 : fsrc1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> { | |||
| using MaxOpBase::MaxOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| using MaxOpBase::operator(); | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiMaximumFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiMaximumFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,99 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/min.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct MinOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return src0 < src1 ? src0 : src1; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct MinOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct MinOp<_ctype> : MinOpBase<_ctype> { \ | |||
| using MinOpBase::MinOpBase; \ | |||
| using MinOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto vitem0 = GiMinimum##_func_suffix(src0.val[0], src1.val[0]); \ | |||
| auto vitem1 = GiMinimum##_func_suffix(src0.val[1], src1.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| return GiMinimum##_func_suffix(src0, src1); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <> | |||
| struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { | |||
| float fsrc0 = src0.as_int8() * this->scale0; | |||
| float fsrc1 = src1.as_int8() * this->scale1; | |||
| return QConverter::convert<dt_qint8, float>(fsrc0 < fsrc1 ? fsrc0 : fsrc1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> { | |||
| using MinOpBase::MinOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| using MinOpBase::operator(); | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiMinimumFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiMinimumFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,99 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/mul.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct MulOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return src0 * src1; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct MulOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct MulOp<_ctype> : MulOpBase<_ctype> { \ | |||
| using MulOpBase::MulOpBase; \ | |||
| using MulOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto vitem0 = GiMultiply##_func_suffix(src0.val[0], src1.val[0]); \ | |||
| auto vitem1 = GiMultiply##_func_suffix(src0.val[1], src1.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| return GiMultiply##_func_suffix(src0, src1); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <> | |||
| struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { | |||
| return QConverter::convert<dt_qint8, float>( | |||
| src0.as_int8() * scale_src0 * src1.as_int8() * scale1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> { | |||
| using MulOpBase::MulOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| using MulOpBase::operator(); | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiMultiplyFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiMultiplyFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/none.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct NoneOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| dst_ctype operator()(const src_ctype& src) const { return src; } | |||
| }; | |||
| template <typename src_ctype, typename dst_type = src_ctype> | |||
| struct NoneOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ | |||
| NoneOp() {} \ | |||
| NoneOp(float, float) {} \ | |||
| using NoneOpBase::NoneOpBase; \ | |||
| using NoneOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| _simd_type2 operator()(const _simd_type2& src) const { return src; } \ | |||
| void operator()(const _simd_type2& src, _ctype* dst) const { \ | |||
| GiStore##_func_suffix(dst, src.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \ | |||
| } \ | |||
| void operator()(const _simd_type& src, _ctype* dst) const { \ | |||
| GiStore##_func_suffix(dst, src); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { return src; } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <> | |||
| struct NoneOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| void operator()(const dt_qint8& src, dt_qint8* dst) const { *dst = src; } | |||
| }; | |||
| template <> | |||
| struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| void operator()(const dt_qint32& src, dt_qint8* dst) const { | |||
| *(reinterpret_cast<dt_qint32*>(dst)) = src; | |||
| } | |||
| }; | |||
| #pragma GCC diagnostic ignored "-Waddress-of-packed-member" | |||
| template <> | |||
| struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> { | |||
| using NoneOpBase::NoneOpBase; | |||
| using NoneOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]); | |||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]); | |||
| } | |||
| void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | |||
| GiStoreInt32(reinterpret_cast<int32_t*>(dst), src); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,450 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/op_base.h | |||
| */ | |||
| #pragma once | |||
| #include <cmath> | |||
| #include "megdnn/dtype.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/elemwise/gi_impl/gi_mathfun.h" | |||
| #include "src/fallback/quantized_converter.h" | |||
| #include "src/fallback/general_intrinsic/gi_float.h" | |||
| #include "src/fallback/general_intrinsic/gi_int.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| ////////////////////////// unary ////////////////////////// | |||
| template <typename _src_ctype, typename _dst_ctype = _src_ctype> | |||
| struct OpBase { | |||
| using src_ctype = _src_ctype; | |||
| using dst_ctype = _dst_ctype; | |||
| OpBase() = default; | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct UnaryOpBase : OpBase<src_ctype, dst_ctype> { | |||
| using OpBase<src_ctype, dst_ctype>::OpBase; | |||
| UnaryOpBase() = default; | |||
| UnaryOpBase(DType /*src_dtype*/, DType /*dst_dtype*/) {} | |||
| }; | |||
| #define OPERATOR_UNARY_QINT8_FALLBACK \ | |||
| GI_INT16_t vsrct0 = GiMoveLowLongInt8(vsrc.val[0]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst), operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0), \ | |||
| GiMoveHighLongInt16(vsrct0)}})); \ | |||
| GI_INT16_t vsrct1 = GiMoveHighLongInt8(vsrc.val[0]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 8), \ | |||
| operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \ | |||
| GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 16), \ | |||
| operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | |||
| GI_INT16_t vsrct3 = GiMoveHighLongInt8(vsrc.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 24), \ | |||
| operator()({{GiMoveLowLongInt16(vsrct3), GiMoveHighLongInt16(vsrct3)}})) | |||
| //! scale_src = src.scale; scale_dst = 1.f / dst.scale (div -> mul) | |||
| //! scale = src.scale / dst.scale | |||
| template <> | |||
| struct UnaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> { | |||
| using OpBase::OpBase; | |||
| float scale_src, scale_dst; | |||
| GI_FLOAT32_t vscale_src, vscale_dst; | |||
| float scale; | |||
| GI_FLOAT32_t vscale; | |||
| void init(float src_scale, float dst_scale) { | |||
| scale_src = src_scale; | |||
| vscale_src = GiBroadcastFloat32(scale_src); | |||
| scale_dst = 1.f / dst_scale; | |||
| vscale_dst = GiBroadcastFloat32(scale_dst); | |||
| scale = src_scale / dst_scale; | |||
| vscale = GiBroadcastFloat32(scale); | |||
| } | |||
| UnaryOpBase(DType src_dtype, DType dst_dtype) { | |||
| float src_scale = src_dtype.param<dtype::QuantizedS8>().scale; | |||
| float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale; | |||
| init(src_scale, dst_scale); | |||
| } | |||
| UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); } | |||
| }; | |||
| template <> | |||
| struct UnaryOpBase<dt_qint32, dt_qint8> : OpBase<dt_qint32, dt_qint8> { | |||
| using OpBase::OpBase; | |||
| using src_ctype = dt_qint32; | |||
| using dst_ctype = dt_qint8; | |||
| float scale; | |||
| GI_FLOAT32_t vscale; | |||
| float scale_src, scale_dst; | |||
| GI_FLOAT32_t vscale_src, vscale_dst; | |||
| void init(float src_scale, float dst_scale) { | |||
| scale_src = src_scale; | |||
| vscale_src = GiBroadcastFloat32(src_scale); | |||
| scale_dst = 1 / dst_scale; | |||
| vscale_dst = GiBroadcastFloat32(scale_dst); | |||
| scale = src_scale / dst_scale; | |||
| vscale = GiBroadcastFloat32(scale); | |||
| } | |||
| UnaryOpBase(DType src_dtype, DType dst_dtype) { | |||
| float src_scale = src_dtype.param<dtype::QuantizedS32>().scale; | |||
| float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale; | |||
| init(src_scale, dst_scale); | |||
| } | |||
| UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); } | |||
| }; | |||
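| Both the folded scale and the scale_src/scale_dst pair are kept because different ops need different forms: a positively homogeneous op such as relu commutes with scaling, so the single fused multiply suffices, while a nonlinear op such as h-swish must dequantize first. A small standalone check, where s_dst plays the role of scale_dst (already the reciprocal of the destination scale) and the values are chosen so the relu case is exact in binary floating point: | |||
| #include <algorithm> | |||
| #include <cassert> | |||
| float hswish(float x) { return x * std::max(std::min(x + 3.f, 6.f), 0.f) / 6.f; } | |||
| int main() { | |||
|     float q = 10.f, s_src = 0.25f, s_dst = 4.f, scale = s_src * s_dst;  // scale == 1.f | |||
|     // relu commutes with positive scaling, so the folded scale works: | |||
|     assert(std::max(q * s_src, 0.f) * s_dst == std::max(q, 0.f) * scale); | |||
|     // h-swish does not, so it must use scale_src and scale_dst separately: | |||
|     assert(hswish(q * s_src) * s_dst != hswish(q) * scale); | |||
| } | |||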
| ////////////////////////// binary ////////////////////////// | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct BinaryOpBase : OpBase<src_ctype, dst_ctype> { | |||
| using OpBase<src_ctype, dst_ctype>::OpBase; | |||
| BinaryOpBase() = default; | |||
| BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*dst_dtype*/) {} | |||
| }; | |||
| /* ================= binary op for quantized types ================== */ | |||
| #define OPERATOR_BINARY_QINT8_FALLBACK \ | |||
| GI_INT16_t vsrct0_0 = GiMoveLowLongInt8(vsrc0.val[0]); \ | |||
| GI_INT16_t vsrct1_0 = GiMoveLowLongInt8(vsrc1.val[0]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0_0), GiMoveHighLongInt16(vsrct0_0)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1_0), GiMoveHighLongInt16(vsrct1_0)}})); \ | |||
| GI_INT16_t vsrct0_1 = GiMoveHighLongInt8(vsrc0.val[0]); \ | |||
| GI_INT16_t vsrct1_1 = GiMoveHighLongInt8(vsrc1.val[0]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 8), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0_1), GiMoveHighLongInt16(vsrct0_1)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1_1), GiMoveHighLongInt16(vsrct1_1)}})); \ | |||
| GI_INT16_t vsrct0_2 = GiMoveLowLongInt8(vsrc0.val[1]); \ | |||
| GI_INT16_t vsrct1_2 = GiMoveLowLongInt8(vsrc1.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 16), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0_2), GiMoveHighLongInt16(vsrct0_2)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1_2), GiMoveHighLongInt16(vsrct1_2)}})); \ | |||
| GI_INT16_t vsrct0_3 = GiMoveHighLongInt8(vsrc0.val[1]); \ | |||
| GI_INT16_t vsrct1_3 = GiMoveHighLongInt8(vsrc1.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 24), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0_3), GiMoveHighLongInt16(vsrct0_3)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1_3), GiMoveHighLongInt16(vsrct1_3)}})) | |||
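| Each invocation of this macro consumes 32 qint8 lanes: every 16-byte half is widened int8 -> int16 -> int32 in two steps, handed to the GI_INT32_V2_t operator() eight lanes at a time, and the narrowed 8-byte result is written back with GiStoreLowInt8. A scalar model of that widen/compute/narrow flow, using add as the stand-in op and assuming round-to-nearest requantization: | |||
| #include <algorithm> | |||
| #include <cmath> | |||
| #include <cstdint> | |||
| // Scalar stand-in for the pipeline OPERATOR_BINARY_QINT8_FALLBACK vectorizes. | |||
| void binary_qint8_model(const int8_t* a, const int8_t* b, int8_t* dst, | |||
|                         float scale0, float scale1, int n) { | |||
|     for (int i = 0; i < n; ++i) { | |||
|         float f = a[i] * scale0 + b[i] * scale1;  // scales already folded with 1/dst_scale | |||
|         long q = std::lround(f); | |||
|         dst[i] = static_cast<int8_t>(std::min(127L, std::max(-128L, q))); | |||
|     } | |||
| } | |||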
| //! scale_src0 = src0.scale; scale_src1 = src1.scale; scale_dst = 1.f / dst.scale | |||
| //! scale0 = src0.scale / dst.scale; scale1 = src1.scale / dst.scale | |||
| template <> | |||
| struct BinaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> { | |||
| using OpBase::OpBase; | |||
| using src_ctype = dt_qint8; | |||
| using dst_ctype = dt_qint8; | |||
| float scale_src0, scale_src1, scale_dst; | |||
| GI_FLOAT32_t vscale_src0, vscale_src1, vscale_dst; | |||
| float scale0, scale1; | |||
| GI_FLOAT32_t vscale0, vscale1; | |||
| void init(float src0_scale, float src1_scale, float dst_scale) { | |||
| scale_src0 = src0_scale; | |||
| vscale_src0 = GiBroadcastFloat32(scale_src0); | |||
| scale_src1 = src1_scale; | |||
| vscale_src1 = GiBroadcastFloat32(scale_src1); | |||
| scale_dst = 1.f / dst_scale; | |||
| vscale_dst = GiBroadcastFloat32(scale_dst); | |||
| scale0 = src0_scale / dst_scale; | |||
| vscale0 = GiBroadcastFloat32(scale0); | |||
| scale1 = src1_scale / dst_scale; | |||
| vscale1 = GiBroadcastFloat32(scale1); | |||
| } | |||
| BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { | |||
| float src0_scale = src0_dtype.param<dtype::QuantizedS8>().scale; | |||
| float src1_scale = src1_dtype.param<dtype::QuantizedS8>().scale; | |||
| float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale; | |||
| init(src0_scale, src1_scale, dst_scale); | |||
| } | |||
| BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) { | |||
| init(src0_scale, src1_scale, dst_scale); | |||
| } | |||
| }; | |||
| template <> | |||
| struct BinaryOpBase<dt_qint32, dt_qint8> : OpBase<dt_qint32, dt_qint8> { | |||
| using OpBase::OpBase; | |||
| using src_ctype = dt_qint32; | |||
| using dst_ctype = dt_qint8; | |||
| float scale0, scale1; | |||
| GI_FLOAT32_t vscale0, vscale1; | |||
| float scale_src0, scale_src1, scale_dst; | |||
| GI_FLOAT32_t vscale_src0, vscale_src1, vscale_dst; | |||
| void init(float src0_scale, float src1_scale, float dst_scale) { | |||
| scale_src0 = src0_scale; | |||
| vscale_src0 = GiBroadcastFloat32(src0_scale); | |||
| scale_src1 = src1_scale; | |||
| vscale_src1 = GiBroadcastFloat32(src1_scale); | |||
| scale_dst = 1 / dst_scale; | |||
| vscale_dst = GiBroadcastFloat32(scale_dst); | |||
| scale0 = src0_scale / dst_scale; | |||
| vscale0 = GiBroadcastFloat32(scale0); | |||
| scale1 = src1_scale / dst_scale; | |||
| vscale1 = GiBroadcastFloat32(scale1); | |||
| } | |||
| BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) { | |||
| float src0_scale = src0_dtype.param<dtype::QuantizedS32>().scale; | |||
| float src1_scale = src1_dtype.param<dtype::QuantizedS32>().scale; | |||
| float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale; | |||
| init(src0_scale, src1_scale, dst_scale); | |||
| } | |||
| BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) { | |||
| init(src0_scale, src1_scale, dst_scale); | |||
| } | |||
| }; | |||
| ////////////////////////// ternary ////////////////////////// | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct TernaryOpBase : OpBase<src_ctype, dst_ctype> { | |||
| using OpBase<src_ctype, dst_ctype>::OpBase; | |||
| TernaryOpBase() = default; | |||
| TernaryOpBase( | |||
| DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*src2_dtype*/, | |||
| DType /*dst_dtype*/) {} | |||
| }; | |||
| #define OPERATOR_TERNARY_QINT8_FALLBACK \ | |||
| GI_INT16_t vsrct0 = GiMoveLowLongInt8(vsrc0.val[0]); \ | |||
| GI_INT16_t vsrct1 = GiMoveLowLongInt8(vsrc1.val[0]); \ | |||
| GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc2.val[0]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | |||
| vsrct0 = GiMoveHighLongInt8(vsrc0.val[0]); \ | |||
| vsrct1 = GiMoveHighLongInt8(vsrc1.val[0]); \ | |||
| vsrct2 = GiMoveHighLongInt8(vsrc2.val[0]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 8), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | |||
| vsrct0 = GiMoveLowLongInt8(vsrc0.val[1]); \ | |||
| vsrct1 = GiMoveLowLongInt8(vsrc1.val[1]); \ | |||
| vsrct2 = GiMoveLowLongInt8(vsrc2.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 16), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \ | |||
| vsrct0 = GiMoveHighLongInt8(vsrc0.val[1]); \ | |||
| vsrct1 = GiMoveHighLongInt8(vsrc1.val[1]); \ | |||
| vsrct2 = GiMoveHighLongInt8(vsrc2.val[1]); \ | |||
| GiStoreLowInt8( \ | |||
| reinterpret_cast<int8_t*>(dst + 24), \ | |||
| operator()( \ | |||
| {{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \ | |||
| {{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})) | |||
| /*========================= ternary op for quantized ====================*/ | |||
| template <> | |||
| struct TernaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> { | |||
| using OpBase::OpBase; | |||
| using src_ctype = dt_qint8; | |||
| using dst_ctype = dt_qint8; | |||
| float scale_src0, scale_src1, scale_src2, scale_dst; | |||
| GI_FLOAT32_t vscale_src0, vscale_src1, vscale_src2, vscale_dst; | |||
| float scale0, scale1, scale2; | |||
| GI_FLOAT32_t vscale0, vscale1, vscale2; | |||
| void init(float src0_scale, float src1_scale, float src2_scale, float dst_scale) { | |||
| scale_src0 = src0_scale; | |||
| scale_src1 = src1_scale; | |||
| scale_src2 = src2_scale; | |||
| scale_dst = 1.f / dst_scale; | |||
| vscale_src0 = GiBroadcastFloat32(scale_src0); | |||
| vscale_src1 = GiBroadcastFloat32(scale_src1); | |||
| vscale_src2 = GiBroadcastFloat32(scale_src2); | |||
| vscale_dst = GiBroadcastFloat32(scale_dst); | |||
| scale0 = src0_scale / dst_scale; | |||
| scale1 = src1_scale / dst_scale; | |||
| scale2 = src2_scale / dst_scale; | |||
| vscale0 = GiBroadcastFloat32(scale0); | |||
| vscale1 = GiBroadcastFloat32(scale1); | |||
| vscale2 = GiBroadcastFloat32(scale2); | |||
| } | |||
| TernaryOpBase( | |||
| DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype) { | |||
| float src0_scale = src0_dtype.param<dtype::QuantizedS8>().scale; | |||
| float src1_scale = src1_dtype.param<dtype::QuantizedS8>().scale; | |||
| float src2_scale = src2_dtype.param<dtype::QuantizedS8>().scale; | |||
| float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale; | |||
| init(src0_scale, src1_scale, src2_scale, dst_scale); | |||
| } | |||
| TernaryOpBase( | |||
| float src0_scale, float src1_scale, float src2_scale, float dst_scale) { | |||
| init(src0_scale, src1_scale, src2_scale, dst_scale); | |||
| } | |||
| }; | |||
| ////////////////////////// fixup ////////////////////////// | |||
| struct FixupBase { | |||
| GI_INT32_t vmultiplier, vshift; | |||
| FixupBase(float scale) { | |||
| //! ignore Fixup if scale >= 0.5, using typecvt instead of shift & | |||
| //! multiplier, as it may introduce errors. | |||
| if (scale >= 0.5) | |||
| return; | |||
| int shift = static_cast<int>(::ceilf(::log2f(0.5 / scale))); | |||
| scale *= ::powf(2, shift); | |||
| //! Using double can get full precision here, but it can be ignored. | |||
| vmultiplier = GiBroadcastInt32( | |||
| std::round(static_cast<double>(scale) * ((2LL) << 30))); | |||
| vshift = GiBroadcastInt32(-shift); | |||
| } | |||
| }; | |||
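| For scale < 0.5 the fixup normalizes the scale into [0.5, 1) by a power of two, stores it as a Q31 multiplier (note that (2LL) << 30 equals 1LL << 31), and keeps the compensating shift negated so a shift-left instruction can apply it as a right shift. A rough scalar model of applying the fixup; the SIMD path would use a saturating rounding doubling-high multiply instead of the plain 64-bit product here: | |||
| #include <cmath> | |||
| #include <cstdint> | |||
| // Approximate x * scale (0 < scale < 0.5) with an integer multiply plus shifts, | |||
| // mirroring what FixupBase precomputes. Illustration only. | |||
| int32_t fixup_apply(int32_t x, float scale) { | |||
|     int shift = static_cast<int>(std::ceil(std::log2(0.5 / scale))); | |||
|     double norm = static_cast<double>(scale) * std::pow(2.0, shift);  // in [0.5, 1) | |||
|     int32_t mult = static_cast<int32_t>(std::round(norm * (1LL << 31))); | |||
|     int64_t hi = (static_cast<int64_t>(x) * mult) >> 31;  // ~ x * norm | |||
|     return static_cast<int32_t>(hi >> shift);             // undo the normalization | |||
| } | |||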
| //////////////////////// quantization common //////////////////// | |||
| template <typename src_type, typename dst_type, typename Op> | |||
| struct UnaryQuantizationOp; | |||
| template <typename Op> | |||
| struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| Op op; | |||
| void operator()(const dt_qint8& src, dt_qint8* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src) const { | |||
| float fsrc = src.as_int8() * this->scale_src; | |||
| fsrc = op(fsrc); | |||
| fsrc = fsrc * this->scale_dst; | |||
| return QConverter::convert<dt_qint8, float>(fsrc); | |||
| } | |||
| void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | |||
| OPERATOR_UNARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src); | |||
| auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src); | |||
| auto val = this->op({{vitem0, vitem1}}); | |||
| val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); | |||
| val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val); | |||
| } | |||
| }; | |||
| template <typename src_type, typename dst_type, typename Op> | |||
| struct BinaryQuantizationOp; | |||
| template <typename Op> | |||
| struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| Op op; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { | |||
| float fsrc0 = src0.as_int8() * this->scale_src0; | |||
| float fsrc1 = src1.as_int8() * this->scale_src1; | |||
| float fdst = op(fsrc0, fsrc1); | |||
| fdst = fdst * this->scale_dst; | |||
| return QConverter::convert<dt_qint8, float>(fdst); | |||
| } | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto val0 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0); | |||
| auto val1 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0); | |||
| auto val2 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1); | |||
| auto val3 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1); | |||
| auto val = op({{val0, val1}}, {{val2, val3}}); | |||
| val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); | |||
| val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val); | |||
| } | |||
| }; | |||
| template <typename src_type, typename dst_type, typename Op> | |||
| struct TernaryQuantizationOp; | |||
| template <typename Op> | |||
| struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op> | |||
| : TernaryOpBase<dt_qint8, dt_qint8> { | |||
| using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| Op op; | |||
| void operator()( | |||
| const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2, | |||
| dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1, src2); | |||
| } | |||
| dt_qint8 operator()( | |||
| const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2) const { | |||
| float fsrc0 = src0.as_int8() * this->scale_src0; | |||
| float fsrc1 = src1.as_int8() * this->scale_src1; | |||
| float fsrc2 = src2.as_int8() * this->scale_src2; | |||
| float fdst = op(fsrc0, fsrc1, fsrc2); | |||
| fdst = fdst * this->scale_dst; | |||
| return QConverter::convert<dt_qint8, float>(fdst); | |||
| } | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, | |||
| const GI_INT8_V2_t& vsrc2, dt_qint8* dst) const { | |||
| OPERATOR_TERNARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()( | |||
| const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1, | |||
| const GI_INT32_V2_t& vsrc2) const { | |||
| auto val0 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0); | |||
| auto val1 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0); | |||
| auto val2 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1); | |||
| auto val3 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1); | |||
| auto val4 = GiMultiplyFloat32(GiCastToFloat32(vsrc2.val[0]), this->vscale_src2); | |||
| auto val5 = GiMultiplyFloat32(GiCastToFloat32(vsrc2.val[1]), this->vscale_src2); | |||
| auto val = op({{val0, val1}}, {{val2, val3}}, {{val4, val5}}); | |||
| val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst); | |||
| val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/pow.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| /////////////////////// POW float only //////////////////////////// | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct PowOp : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 1; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return powf(src0, src1); | |||
| } | |||
| }; | |||
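| //! Usage sketch (values hypothetical): with SIMD_WIDTH == 1 the elemwise | |||
| //! framework falls back to one scalar call per element, e.g. | |||
| //!   PowOp<float> op; | |||
| //!   float out; | |||
| //!   op(2.f, 10.f, &out);  // out == powf(2.f, 10.f) == 1024.f | |||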
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,188 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/relu.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : 0; } | |||
| }; | |||
| template <typename src_ctype, typename dst_type = src_ctype> | |||
| struct ReluOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct ReluOp<_ctype> : ReluOpBase<_ctype> { \ | |||
| using ReluOpBase::ReluOpBase; \ | |||
| using ReluOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _simd_type2& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()(const _simd_type2& src) const { \ | |||
| auto vzero = GiBroadcast##_func_suffix(0); \ | |||
| auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero); \ | |||
| auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| void operator()(const _simd_type& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { \ | |||
| auto vzero = GiBroadcast##_func_suffix(0); \ | |||
| return GiMaximum##_func_suffix(src, vzero); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <> | |||
| struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| void operator()(const dt_qint8& src, dt_qint8* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dt_qint8 operator()(const dt_qint8& src) const { | |||
| float fsrc = src.as_int8() * this->scale; | |||
| fsrc = std::max<float>(fsrc, 0.f); | |||
| return QConverter::convert<dt_qint8, float>(fsrc); | |||
| } | |||
| }; | |||
| template <> | |||
| struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> { | |||
| using ReluOpBase::ReluOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| using ReluOpBase::operator(); | |||
| void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const { | |||
| OPERATOR_UNARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { | |||
| auto vzero = GiBroadcastFloat32(0.f); | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); | |||
| auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); | |||
| vitem0 = GiMaximumFloat32(vitem0, vzero); | |||
| vitem1 = GiMaximumFloat32(vitem1, vzero); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| template <> | |||
| struct ReluOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| void operator()(const dt_qint32& src, dt_qint8* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dt_qint8 operator()(const dt_qint32& src) const { | |||
| float fsrc = src.as_int32() * this->scale; | |||
| fsrc = std::max<float>(fsrc, 0.f); | |||
| return QConverter::convert<dt_qint8, float>(fsrc); | |||
| } | |||
| }; | |||
| //! on old armv7 (pre-ARMv8), specialize relu with a fixed-point fixup | |||
| #if defined(__ARM_ARCH) && __ARM_ARCH < 8 | |||
| template <> | |||
| struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase { | |||
| using ReluOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = 4; | |||
| ReluOp(DType src_dtype, DType dst_dtype) | |||
| : ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {} | |||
| ReluOp(float src_scale, float dst_scale) | |||
| : ReluOpBase(src_scale, dst_scale), FixupBase(scale) {} | |||
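| //! FixupBase converts the float `scale` into a fixed-point (vmultiplier, | |||
| //! vshift) pair so that requantization runs on integer-only NEON: | |||
| //!   x * scale ~= vrshlq_s32(vqrdmulhq_s32(x, vmultiplier), vshift) | |||
| //! which the int32 paths below use instead of float multiplies. | |||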
| void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const { | |||
| vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||
| } | |||
| int8x8_t operator()(const int32x4x2_t& vsrc) const { | |||
| int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier); | |||
| int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); | |||
| vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | |||
| vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); | |||
| return vqmovn_s16(vcombine_s16( | |||
| vqmovn_s32(vrshlq_s32(vitem0, vshift)), | |||
| vqmovn_s32(vrshlq_s32(vitem1, vshift)))); | |||
| } | |||
| int8x8_t operator()(const float32x4_t& vsrc) const { | |||
| int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); | |||
| vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); | |||
| vitem0 = vrshlq_s32(vitem0, vshift); | |||
| int16x4_t vitem = vqmovn_s32(vitem0); | |||
| return vqmovn_s16(vcombine_s16(vitem, vitem)); | |||
| } | |||
| void operator()(const int32x4_t& src, dt_qint8* dst) const { | |||
| auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale); | |||
| vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | |||
| auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||
| vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||
| } | |||
| void operator()(const float32x4_t& src, dt_qint8* dst) const { | |||
| auto vitem0 = vmulq_f32(src, this->vscale); | |||
| vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); | |||
| auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0); | |||
| vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0); | |||
| } | |||
| }; | |||
| #else | |||
| template <> | |||
| struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> { | |||
| using ReluOpBase::ReluOpBase; | |||
| using ReluOpBase::operator(); | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t); | |||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||
| } | |||
| void operator()(const GI_INT32_t& src, dt_qint8* dst) const { | |||
| GiStoreLane0Int32( | |||
| reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(src))); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); | |||
| auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); | |||
| vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); | |||
| vitem1 = GiMaximumFloat32(vitem1, QConverterBase::vfzero()); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_t& src) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale); | |||
| vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0); | |||
| } | |||
| GI_INT8_t operator()(const GI_FLOAT32_t& src) const { | |||
| auto vitem0 = GiMultiplyFloat32(src, this->vscale); | |||
| vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero()); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0); | |||
| } | |||
| }; | |||
| #endif | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct SigmoidOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { | |||
| float tmpf = src; | |||
| tmpf = exp(-tmpf); | |||
| tmpf = 1.f / (1.f + tmpf); | |||
| return tmpf; | |||
| } | |||
| }; | |||
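| //! e.g. the scalar path gives operator()(0.f) == 0.5f; the float SIMD | |||
| //! specialization below defers to the vectorized GiSigmoidPsFloat32 instead. | |||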
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct SigmoidOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ | |||
| using SigmoidOpBase::SigmoidOpBase; \ | |||
| using SigmoidOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _simd_type2& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| void operator()(const _simd_type& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type2 operator()(const _simd_type2& src) const { \ | |||
| return {{operator()(src.val[0]), operator()(src.val[1])}}; \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { \ | |||
| return GiSigmoidPs##_func_suffix(src); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/sub.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct SubOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return src0 - src1; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct SubOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct SubOp<_ctype> : SubOpBase<_ctype> { \ | |||
| using SubOpBase::SubOpBase; \ | |||
| using SubOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto vitem0 = GiSubtract##_func_suffix(src0.val[0], src1.val[0]); \ | |||
| auto vitem1 = GiSubtract##_func_suffix(src0.val[1], src1.val[1]); \ | |||
| return {{vitem0, vitem1}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| return GiSubtract##_func_suffix(src0, src1); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t)) | |||
| OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t)) | |||
| #undef OP | |||
| template <> | |||
| struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> { | |||
| using BinaryOpBase::BinaryOpBase; | |||
| void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
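| //! note: scale0/scale1 from BinaryOpBase<dt_qint8, dt_qint8> already carry | |||
| //! the dst scale divided in (the SIMD path below likewise multiplies by | |||
| //! vscale0/vscale1 only), so no separate scale_dst factor appears here. | |||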
| dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { | |||
| return QConverter::convert<dt_qint8, float>( | |||
| src0.as_int8() * scale0 - src1.as_int8() * scale1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> { | |||
| using SubOpBase::SubOpBase; | |||
| constexpr static size_t SIMD_WIDTH = 16; | |||
| using SubOpBase::operator(); | |||
| void operator()( | |||
| const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const { | |||
| OPERATOR_BINARY_QINT8_FALLBACK; | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const { | |||
| auto vitem0 = GiSubtractFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1)); | |||
| auto vitem1 = GiSubtractFloat32( | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0), | |||
| GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1)); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/tanh.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> { | |||
| using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase; | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src) const { | |||
| float tmp = src; | |||
| return tanh(tmp); | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_type = src_ctype> | |||
| struct TanhOp; | |||
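| //! the SIMD code below evaluates tanh through the identity | |||
| //!   tanh(x) = 1 - 2 / (1 + exp(2x)) | |||
| //! approximating the reciprocal with GiRecpe plus one GiRecpeS | |||
| //! Newton-Raphson refinement step. | |||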
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct TanhOp<_ctype> : TanhOpBase<_ctype> { \ | |||
| using TanhOpBase::TanhOpBase; \ | |||
| using TanhOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()(const _simd_type2& src, _ctype* dst) const { \ | |||
| auto vitem = operator()(src); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()(const _simd_type2& src) const { \ | |||
| auto one_val = GiBroadcast##_func_suffix(1.f); \ | |||
| auto two_val = GiBroadcast##_func_suffix(2.f); \ | |||
| auto val1 = src.val[0]; \ | |||
| auto val2 = src.val[1]; \ | |||
| val1 = GiMultiply##_func_suffix(two_val, val1); \ | |||
| val2 = GiMultiply##_func_suffix(two_val, val2); \ | |||
| val1 = GiExpPs##_func_suffix(val1); \ | |||
| val2 = GiExpPs##_func_suffix(val2); \ | |||
| val1 = GiAdd##_func_suffix(one_val, val1); \ | |||
| val2 = GiAdd##_func_suffix(one_val, val2); \ | |||
| auto rval1 = GiRecpe##_func_suffix(val1); \ | |||
| auto rval2 = GiRecpe##_func_suffix(val2); \ | |||
| rval1 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(val1, rval1), rval1); \ | |||
| rval2 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(val2, rval2), rval2); \ | |||
| val1 = GiMultiply##_func_suffix(two_val, rval1); \ | |||
| val2 = GiMultiply##_func_suffix(two_val, rval2); \ | |||
| val1 = GiSubtract##_func_suffix(one_val, val1); \ | |||
| val2 = GiSubtract##_func_suffix(one_val, val2); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src) const { \ | |||
| auto one_val = GiBroadcast##_func_suffix(1.f); \ | |||
| auto two_val = GiBroadcast##_func_suffix(2.f); \ | |||
| auto val1 = src; \ | |||
| val1 = GiMultiply##_func_suffix(two_val, val1); \ | |||
| val1 = GiExpPs##_func_suffix(val1); \ | |||
| val1 = GiAdd##_func_suffix(one_val, val1); \ | |||
| auto rval1 = GiRecpe##_func_suffix(val1); \ | |||
| rval1 = GiMultiply##_func_suffix( \ | |||
| GiRecpeS##_func_suffix(val1, rval1), rval1); \ | |||
| val1 = GiMultiply##_func_suffix(two_val, rval1); \ | |||
| val1 = GiSubtract##_func_suffix(one_val, val1); \ | |||
| return val1; \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/true_div.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| //! use a couple of Newton-Raphson steps to refine the reciprocal estimate: | |||
| //! A / B => 1. rB = vrecpeq_f32(B)  2. rB = vmulq_f32(vrecpsq_f32(B, rB), rB) | |||
| //! 3. A * rB | |||
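| //! (the generic code below calls GiDivide, which, depending on the target, | |||
| //! lowers to a native divide or an estimate-and-refine sequence like this) | |||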
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct TrueDivOpBase : BinaryOpBase<src_ctype, dst_ctype> { | |||
| using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase; | |||
| void operator()( | |||
| const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { | |||
| *dst = operator()(src0, src1); | |||
| } | |||
| dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { | |||
| return src0 / src1; | |||
| } | |||
| }; | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct TrueDivOp; | |||
| #define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \ | |||
| template <> \ | |||
| struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \ | |||
| using TrueDivOpBase::TrueDivOpBase; \ | |||
| using TrueDivOpBase::operator(); \ | |||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | |||
| void operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem.val[0]); \ | |||
| GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | |||
| } \ | |||
| _simd_type2 operator()( \ | |||
| const _simd_type2& src0, const _simd_type2& src1) const { \ | |||
| auto val1 = src0.val[0]; \ | |||
| auto val2 = src0.val[1]; \ | |||
| auto val3 = src1.val[0]; \ | |||
| auto val4 = src1.val[1]; \ | |||
| val1 = GiDivide##_func_suffix(val1, val3); \ | |||
| val2 = GiDivide##_func_suffix(val2, val4); \ | |||
| return {{val1, val2}}; \ | |||
| } \ | |||
| void operator()( \ | |||
| const _simd_type& src0, const _simd_type& src1, \ | |||
| dst_ctype* dst) const { \ | |||
| auto vitem = operator()(src0, src1); \ | |||
| GiStore##_func_suffix(dst, vitem); \ | |||
| } \ | |||
| _simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \ | |||
| return GiDivide##_func_suffix(src0, src1); \ | |||
| } \ | |||
| }; | |||
| OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float)) | |||
| #undef OP | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/kimpl/typecvt.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/op_base.h" | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| template <typename src_ctype, typename dst_ctype = src_ctype> | |||
| struct TypeCvtOp; | |||
| template <> | |||
| struct TypeCvtOp<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> { | |||
| using UnaryOpBase::UnaryOpBase; | |||
| constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float); | |||
| void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc)); | |||
| } | |||
| void operator()(const GI_INT32_t& vsrc, dt_qint8* dst) const { | |||
| GiStoreLane0Int32( | |||
| reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(vsrc))); | |||
| } | |||
| void operator()(const src_ctype& src, dst_ctype* dst) const { | |||
| *dst = operator()(src); | |||
| } | |||
| dt_qint8 operator()(const dt_qint32& src) const { | |||
| float fsrc = src.as_int32() * this->scale; | |||
| return QConverter::convert<dt_qint8, float>(fsrc); | |||
| } | |||
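| //! e.g. with scale == 0.25f, qint32 src == 514 maps to 128.5f, which | |||
| //! saturates to qint8(127) on conversion. | |||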
| GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale); | |||
| auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}}); | |||
| } | |||
| GI_INT8_t operator()(const GI_INT32_t& src) const { | |||
| auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0); | |||
| } | |||
| GI_INT8_t operator()(const GI_FLOAT32_t& src) const { | |||
| auto vitem0 = GiMultiplyFloat32(src, this->vscale); | |||
| return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0); | |||
| } | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/op_binary.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/add.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/fuse_add_relu.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/max.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/min.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/mul.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/pow.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/sub.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/true_div.h" | |||
| //////////////////// quantization ////////////////////////////// | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| #define cb(op) \ | |||
| template <> \ | |||
| struct op<dt_qint8, dt_qint8> \ | |||
| : BinaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \ | |||
| using BinaryQuantizationOp< \ | |||
| dt_qint8, dt_qint8, op<float, float>>::BinaryQuantizationOp; \ | |||
| }; | |||
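| // e.g. cb(TrueDivOp) specializes TrueDivOp<dt_qint8, dt_qint8> to reuse the | |||
| // float TrueDivOp through the dequantize -> op -> requantize | |||
| // BinaryQuantizationOp wrapper. | |||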
| cb(TrueDivOp); | |||
| cb(FuseAddSigmoidOp); | |||
| cb(FuseAddTanhOp); | |||
| cb(FuseAddHSwishOp); | |||
| #undef cb | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/op_common.h | |||
| */ | |||
| #pragma once | |||
| namespace megdnn { | |||
| /*! | |||
| * \brief broadcast type | |||
| * BCAST_x[0]x[1]...: x[i] == !stride[i], i.e. x[i] is 1 iff dim i is | |||
| * broadcast (stride 0) and 0 iff it is contiguous | |||
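| * e.g. VEC_BCAST101 pairs a full vector operand with a {1, C, 1(, 1)}-shaped | |||
| * one (only the channel dim varies); the 101xX variants are the same pattern | |||
| * for NCHWxX layouts such as NCHW44/NCHW88. | |||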
| */ | |||
| enum BcastType { | |||
| VEC, | |||
| VEC_VEC, | |||
| VEC_BCAST101, | |||
| VEC_BCASTX0X, | |||
| VEC_BCAST111C, | |||
| VEC_BCAST101xX, | |||
| VEC_SCALAR, | |||
| SCALAR_VEC, | |||
| BCAST101_VEC, | |||
| BCASTX0X_VEC, | |||
| BCAST111C_VEC, | |||
| BCAST101xX_VEC, | |||
| VEC_VEC_VEC, | |||
| VEC_VEC_SCALAR, | |||
| BCAST101_VEC_BCAST101, | |||
| BCAST111C_VEC_BCAST111C, | |||
| BCAST101xX_VEC_BCAST101xX, | |||
| VEC_BCAST101_VEC, | |||
| VEC_BCAST111C_VEC, | |||
| VEC_BCAST101xX_VEC, | |||
| VEC_SCALAR_VEC, | |||
| VEC_SCALAR_SCALAR, | |||
| UNKNOWN_BCAST_TYPE | |||
| }; | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/op_ternary.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h" | |||
| //////////////////// quantization ////////////////////////////// | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| #define cb(op) \ | |||
| template <> \ | |||
| struct op<dt_qint8, dt_qint8> \ | |||
| : TernaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \ | |||
| using TernaryQuantizationOp< \ | |||
| dt_qint8, dt_qint8, op<float, float>>::TernaryQuantizationOp; \ | |||
| }; | |||
| cb(FuseMulAdd3Op); | |||
| #undef cb | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * \file dnn/src/fallback/elemwise_helper/op_unary.h | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/elemwise_helper/kimpl/abs.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/exp.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/fast_tanh.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/hswish.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/none.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/relu.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/sigmoid.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/tanh.h" | |||
| #include "src/fallback/elemwise_helper/kimpl/typecvt.h" | |||
| //////////////////// quantization ////////////////////////////// | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| #define cb(op) \ | |||
| template <> \ | |||
| struct op<dt_qint8, dt_qint8> \ | |||
| : UnaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \ | |||
| using UnaryQuantizationOp< \ | |||
| dt_qint8, dt_qint8, op<float, float>>::UnaryQuantizationOp; \ | |||
| }; | |||
| cb(SigmoidOp); | |||
| cb(ExpOp); | |||
| cb(TanhOp); | |||
| cb(FastTanhOp); | |||
| cb(HSwishOp); | |||
| #undef cb | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -19,7 +19,7 @@ | |||
| #include <windows.h> | |||
| #else | |||
| #if defined(__arm__) || defined(__aarch64__) | |||
| #include <arm_neon.h> | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #endif | |||
| #if defined(__x86_64__) || defined(__i386__) | |||
| #include <cpuid.h> | |||
| @@ -21,13 +21,13 @@ namespace megdnn { | |||
| namespace fallback { | |||
| struct QConverterBase { | |||
| inline static GI_INT32 vzero() { return GiBroadcastInt32(0); } | |||
| inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); } | |||
| inline static GI_FLOAT32 vfzero() { return GiBroadcastFloat32(0.f); } | |||
| inline static GI_FLOAT32_t vfzero() { return GiBroadcastFloat32(0.f); } | |||
| inline static GI_FLOAT32 vfhalf() { return GiBroadcastFloat32(0.5f); } | |||
| inline static GI_FLOAT32_t vfhalf() { return GiBroadcastFloat32(0.5f); } | |||
| inline static GI_FLOAT32 vfneg_half() { return GiBroadcastFloat32(-0.5f); } | |||
| inline static GI_FLOAT32_t vfneg_half() { return GiBroadcastFloat32(-0.5f); } | |||
| }; | |||
| struct QConverter { | |||
| @@ -56,23 +56,23 @@ inline dt_qint32 QConverter::convert(const float& src) { | |||
| } | |||
| template <> | |||
| inline GI_FLOAT32_V2 QConverter::convert(const GI_INT16& vsrc) { | |||
| GI_INT32 vhi = GiMoveHighLongInt16(vsrc); | |||
| GI_INT32 vlo = GiMoveLowLongInt16(vsrc); | |||
| inline GI_FLOAT32_V2_t QConverter::convert(const GI_INT16_t& vsrc) { | |||
| GI_INT32_t vhi = GiMoveHighLongInt16(vsrc); | |||
| GI_INT32_t vlo = GiMoveLowLongInt16(vsrc); | |||
| return {{GiCastToFloat32(vlo), GiCastToFloat32(vhi)}}; | |||
| } | |||
| template <> | |||
| inline GI_INT8 QConverter::convert(const GI_FLOAT32_V2& vsrc) { | |||
| inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V2_t& vsrc) { | |||
| return GiCvtFromFloat32V2ToInt8(vsrc); | |||
| } | |||
| template <> | |||
| inline GI_INT8 QConverter::convert(const GI_FLOAT32& src) { | |||
| inline GI_INT8_t QConverter::convert(const GI_FLOAT32_t& src) { | |||
| return GiCvtFromFloat32ToInt8(src); | |||
| } | |||
| template <> | |||
| inline GI_INT32 QConverter::round(const GI_FLOAT32& vsrc) { | |||
| inline GI_INT32_t QConverter::round(const GI_FLOAT32_t& vsrc) { | |||
| return GiRoundAsInt32(vsrc); | |||
| } | |||
| } // namespace fallback | |||
| @@ -38,6 +38,309 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) { | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{10, 10, 32}, {10, 10, 32}, {}}); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| checker.set_param(Mode::FUSE_MUL_ADD3); | |||
| auto run = [&] { | |||
| //! nchw44 | |||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| //! nchw44 | |||
| checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||
| checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| //! nchw88 | |||
| checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||
| checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); | |||
| //! nchw88 | |||
| checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); | |||
| checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); | |||
| checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); | |||
| checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||
| checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); | |||
| checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); | |||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); | |||
| checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); | |||
| checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}}); | |||
| checker.execs({{1, 7}, {1, 7}, {1, 7}, {}}); | |||
| checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}}); | |||
| checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}}); | |||
| checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}}); | |||
| checker.execs({{3, 4, 5}, {1}, {1}, {}}); | |||
| checker.execs({{1}, {3, 4, 5}, {1}, {}}); | |||
| }; | |||
| // case int | |||
| checker.set_dtype(0, dtype::Int8()); | |||
| checker.set_dtype(1, dtype::Int8()); | |||
| checker.set_dtype(2, dtype::Int8()); | |||
| run(); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Int16()); | |||
| checker.set_dtype(2, dtype::Int16()); | |||
| run(); | |||
| checker.set_dtype(0, dtype::Int32()); | |||
| checker.set_dtype(1, dtype::Int32()); | |||
| checker.set_dtype(2, dtype::Int32()); | |||
| run(); | |||
| // case float | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| run(); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| auto run = [&]() { | |||
| // VEC_BCAST101x; PowOp is not vectorized on this path | |||
| checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| // BCAST101x_VEC; PowOp is not vectorized on this path | |||
| checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| }; | |||
| checker.set_dtype(0, dtype::Int8()); | |||
| checker.set_dtype(1, dtype::Int8()); | |||
| run(); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Int16()); | |||
| run(); | |||
| checker.set_dtype(0, dtype::Int32()); | |||
| checker.set_dtype(1, dtype::Int32()); | |||
| run(); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_FP32) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| auto run = [&](Mode mode) { | |||
| // VEC_BCAST101x | |||
| checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}}); | |||
| checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}}); | |||
| // BCAST101x_VEC; PowOp is not vectorized on this path | |||
| checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}}); | |||
| checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}}); | |||
| checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); | |||
| checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); | |||
| }; | |||
| run(Mode::ADD); | |||
| run(Mode::FUSE_ADD_H_SWISH); | |||
| run(Mode::FUSE_ADD_RELU); | |||
| run(Mode::MAX); | |||
| run(Mode::MIN); | |||
| run(Mode::MUL); | |||
| run(Mode::SUB); | |||
| run(Mode::TRUE_DIV); | |||
| run(Mode::POW); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW88_FP) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||
| checker.set_param(Mode::FUSE_ADD_RELU) | |||
| .execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); | |||
| auto run = [&](Mode mode) { | |||
| // VEC_BCAST101x | |||
| checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}}); | |||
| checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||
| checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}}); | |||
| // BCAST101x_VEC; PowOp is not vectorized on this path | |||
| checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}}); | |||
| checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}}); | |||
| checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}}); | |||
| checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}}); | |||
| checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}}); | |||
| }; | |||
| auto run_all = [&]() { | |||
| run(Mode::ADD); | |||
| run(Mode::FUSE_ADD_H_SWISH); | |||
| run(Mode::FUSE_ADD_RELU); | |||
| run(Mode::MAX); | |||
| run(Mode::MIN); | |||
| run(Mode::MUL); | |||
| run(Mode::SUB); | |||
| run(Mode::TRUE_DIV); | |||
| run(Mode::POW); | |||
| }; | |||
| { | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| run_all(); | |||
| } | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_N1HW_FP32_BCAST) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| Checker<ElemwiseForward> checker(handle()); | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| //! 2 dim | |||
| auto run = [&](Mode mode) { | |||
| // VEC_BCASTX0X | |||
| checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}}); | |||
| checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}}); | |||
| // BCASTX0X_VEC | |||
| checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}}); | |||
| checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}}); | |||
| }; | |||
| run(Mode::ADD); | |||
| run(Mode::MUL); | |||
| run(Mode::SUB); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY_RECORD) { | |||
| using Mode = ElemwiseForward::Param::Mode; | |||
| TaskRecordChecker<ElemwiseForward> checker(0); | |||
| checker.set_param(Mode::FUSE_MUL_ADD3); | |||
| auto run = [&] { | |||
| //! nchw44 | |||
| checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}}); | |||
| //! nchw88 | |||
| checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}}); | |||
| checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); | |||
| checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); | |||
| }; | |||
| // case int | |||
| checker.set_dtype(0, dtype::Int32()); | |||
| checker.set_dtype(1, dtype::Int32()); | |||
| checker.set_dtype(2, dtype::Int32()); | |||
| run(); | |||
| // case float | |||
| UniformFloatRNG rng(1e-5, 7e1); | |||
| checker.set_rng(0, &rng); | |||
| checker.set_epsilon(1e-5); | |||
| checker.set_dtype(0, dtype::Float32()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| run(); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(FALLBACK, BENCHMARK_ELEMWISE) { | |||
| auto naive_handle = create_cpu_handle(2); | |||
| @@ -164,3 +164,148 @@ TEST_F(X86, BENCHMARK_ELEM_EVERY_DTYPE) { | |||
| // B.set_dtype(2, dtype::Int8()); | |||
| // BENCHMARK_CASE_INT(1556011) | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| namespace { | |||
| void run_elemwise_benchmark( | |||
| const TensorShapeArray& shapes, param::Elemwise::Mode mode, | |||
| const char* mode_str, DType type, Handle* handle_bench) { | |||
| auto handle_fallback = create_cpu_handle(1); | |||
| Benchmarker<Elemwise> benchmarker_bench(handle_bench); | |||
| Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get()); | |||
| float throughput = 0; | |||
| SmallVector<TensorLayout> layouts; | |||
| std::string src_strs; | |||
| for (size_t i = 0; i < shapes.size(); i++) { | |||
| layouts.emplace_back(shapes[i], type); | |||
| throughput += layouts.back().span().dist_byte(); | |||
| src_strs += layouts.back().to_string(); | |||
| if (i != shapes.size() - 1) { | |||
| src_strs += ","; | |||
| } | |||
| } | |||
| constexpr size_t RUN = 50; | |||
| benchmarker_fallback.set_times(RUN).set_display(false); | |||
| benchmarker_bench.set_times(RUN).set_display(false); | |||
| benchmarker_fallback.set_param(mode); | |||
| benchmarker_bench.set_param(mode); | |||
| TensorLayout dst_layout; | |||
| auto opr = handle_bench->create_operator<Elemwise>(); | |||
| opr->param() = mode; | |||
| opr->deduce_layout(layouts, dst_layout); | |||
| float computations = | |||
| dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1); | |||
| throughput += dst_layout.span().dist_byte(); | |||
| computations *= (1e3 / (1024.0 * 1024)); | |||
| throughput *= (1e3 / (1024.0 * 1024)); | |||
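| // execl() reports time in ms, so the 1e3 / 2^20 factors above make the | |||
| // later computations/time and throughput/time divisions come out directly | |||
| // as MFLOPS and MB/s (Mi-based). | |||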
| layouts.emplace_back(dst_layout); | |||
| auto fallback_time = benchmarker_fallback.execl(layouts) / RUN; | |||
| auto bench_time = benchmarker_bench.execl(layouts) / RUN; | |||
| float fallback_flops = computations / fallback_time; | |||
| float bench_flops = computations / bench_time; | |||
| float fallback_thr = throughput / fallback_time; | |||
| float bench_thr = throughput / bench_time; | |||
| printf("%s = %s (type: %s, mode: %s) fallback=%fMFLOPS %fMB/s, bench=%fMFLOPS " | |||
| "%fMB/s " | |||
| "computations: %fx, throughput: %fx\n", | |||
| src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str, | |||
| fallback_flops, fallback_thr, bench_flops, bench_thr, | |||
| bench_flops / fallback_flops, bench_thr / fallback_thr); | |||
| } | |||
| } // namespace | |||
| #define INT_RUN(shape, mode) \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle()); | |||
| #define FLOAT_RUN(shape, mode) \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \ | |||
| run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle()); | |||
| #define BENCHMARK_CASES(shape) \ | |||
| INT_BENCHMARK_CASES(shape) \ | |||
| FLOAT_BENCHMARK_CASES(shape) | |||
| TEST_F(X86, BENCHMARK_UNARY) { | |||
| #define INT_BENCHMARK_CASES(shape) \ | |||
| INT_RUN(shape, Mode::RELU); \ | |||
| INT_RUN(shape, Mode::ABS); | |||
| #define FLOAT_BENCHMARK_CASES(shape) \ | |||
| FLOAT_RUN(shape, Mode::RELU); \ | |||
| FLOAT_RUN(shape, Mode::ABS); \ | |||
| FLOAT_RUN(shape, Mode::SIGMOID); \ | |||
| FLOAT_RUN(shape, Mode::EXP); \ | |||
| FLOAT_RUN(shape, Mode::TANH); \ | |||
| FLOAT_RUN(shape, Mode::FAST_TANH); | |||
| using Mode = param::Elemwise::Mode; | |||
| BENCHMARK_CASES({{10000}}); | |||
| BENCHMARK_CASES({{50000}}); | |||
| #undef INT_BENCHMARK_CASES | |||
| #undef FLOAT_BENCHMARK_CASES | |||
| } | |||
| TEST_F(X86, BENCHMARK_BINARY) { | |||
| #define INT_BENCHMARK_CASES(shape) \ | |||
| INT_RUN(shape, Mode::MIN); \ | |||
| INT_RUN(shape, Mode::MAX); \ | |||
| INT_RUN(shape, Mode::ADD); \ | |||
| INT_RUN(shape, Mode::SUB); \ | |||
| INT_RUN(shape, Mode::MUL); \ | |||
| INT_RUN(shape, Mode::RMULH); \ | |||
| INT_RUN(shape, Mode::FUSE_ADD_RELU); | |||
| #define FLOAT_BENCHMARK_CASES(shape) \ | |||
| FLOAT_RUN(shape, Mode::MIN); \ | |||
| FLOAT_RUN(shape, Mode::MAX); \ | |||
| FLOAT_RUN(shape, Mode::ADD); \ | |||
| FLOAT_RUN(shape, Mode::SUB); \ | |||
| FLOAT_RUN(shape, Mode::MUL); \ | |||
| FLOAT_RUN(shape, Mode::POW); \ | |||
| FLOAT_RUN(shape, Mode::TRUE_DIV); \ | |||
| FLOAT_RUN(shape, Mode::FUSE_ADD_RELU); | |||
| using Mode = param::Elemwise::Mode; | |||
| TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}}; | |||
| BENCHMARK_CASES(shapes); | |||
| shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}}; | |||
| BENCHMARK_CASES(shapes); | |||
| shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}}; | |||
| BENCHMARK_CASES(shapes); | |||
| #undef INT_BENCHMARK_CASES | |||
| #undef FLOAT_BENCHMARK_CASES | |||
| } | |||
| TEST_F(X86, BENCHMARK_TERNARY_FMA3) { | |||
| #define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3); | |||
| #define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3); | |||
| using Mode = param::Elemwise::Mode; | |||
| TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}}; | |||
| BENCHMARK_CASES(shapes); | |||
| shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}}; | |||
| BENCHMARK_CASES(shapes); | |||
| shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}}; | |||
| BENCHMARK_CASES(shapes); | |||
| #undef INT_BENCHMARK_CASES | |||
| #undef FLOAT_BENCHMARK_CASES | |||
| } | |||
| #undef BENCHMARK_CASES | |||
| #undef INT_RUN | |||
| #undef FLOAT_RUN | |||
| #endif | |||