Browse Source

feat(fallback): support general intrinsic in elemwise in fallback

GitOrigin-RevId: 96ff2e88cc
tags/v1.10.0
Megvii Engine Team 3 years ago
parent
commit
547945e854
50 changed files with 6439 additions and 310 deletions
  1. +59
    -64
      dnn/src/arm_common/conv_bias/postprocess_helper.h
  2. +0
    -159
      dnn/src/arm_common/elemwise/opr_impl.cpp
  3. +1
    -24
      dnn/src/arm_common/elemwise/opr_impl.h
  4. +3
    -30
      dnn/src/arm_common/elemwise_op.h
  5. +2
    -9
      dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp
  6. +2
    -9
      dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp
  7. +535
    -0
      dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
  8. +36
    -0
      dnn/src/fallback/elemwise/gi_impl/binary/algo.h
  9. +383
    -0
      dnn/src/fallback/elemwise/gi_impl/gi_mathfun.cpp
  10. +55
    -0
      dnn/src/fallback/elemwise/gi_impl/gi_mathfun.h
  11. +459
    -0
      dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp
  12. +39
    -0
      dnn/src/fallback/elemwise/gi_impl/ternary/algo.h
  13. +125
    -0
      dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
  14. +25
    -0
      dnn/src/fallback/elemwise/gi_impl/unary/algo.h
  15. +238
    -4
      dnn/src/fallback/elemwise/opr_impl.cpp
  16. +57
    -0
      dnn/src/fallback/elemwise/opr_impl.h
  17. +78
    -0
      dnn/src/fallback/elemwise_helper/kimpl/abs.h
  18. +134
    -0
      dnn/src/fallback/elemwise_helper/kimpl/add.h
  19. +49
    -0
      dnn/src/fallback/elemwise_helper/kimpl/exp.h
  20. +70
    -0
      dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h
  21. +118
    -0
      dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
  22. +162
    -0
      dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h
  23. +60
    -0
      dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h
  24. +77
    -0
      dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h
  25. +60
    -0
      dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h
  26. +31
    -0
      dnn/src/fallback/elemwise_helper/kimpl/gi_util_impl_helper.h
  27. +108
    -0
      dnn/src/fallback/elemwise_helper/kimpl/hswish.h
  28. +7
    -0
      dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h
  29. +39
    -0
      dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h
  30. +102
    -0
      dnn/src/fallback/elemwise_helper/kimpl/max.h
  31. +99
    -0
      dnn/src/fallback/elemwise_helper/kimpl/min.h
  32. +99
    -0
      dnn/src/fallback/elemwise_helper/kimpl/mul.h
  33. +77
    -0
      dnn/src/fallback/elemwise_helper/kimpl/none.h
  34. +450
    -0
      dnn/src/fallback/elemwise_helper/kimpl/op_base.h
  35. +28
    -0
      dnn/src/fallback/elemwise_helper/kimpl/pow.h
  36. +188
    -0
      dnn/src/fallback/elemwise_helper/kimpl/relu.h
  37. +56
    -0
      dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h
  38. +97
    -0
      dnn/src/fallback/elemwise_helper/kimpl/sub.h
  39. +81
    -0
      dnn/src/fallback/elemwise_helper/kimpl/tanh.h
  40. +68
    -0
      dnn/src/fallback/elemwise_helper/kimpl/true_div.h
  41. +53
    -0
      dnn/src/fallback/elemwise_helper/kimpl/typecvt.h
  42. +39
    -0
      dnn/src/fallback/elemwise_helper/op_binary.h
  43. +39
    -0
      dnn/src/fallback/elemwise_helper/op_common.h
  44. +24
    -0
      dnn/src/fallback/elemwise_helper/op_ternary.h
  45. +36
    -0
      dnn/src/fallback/elemwise_helper/op_unary.h
  46. +1432
    -0
      dnn/src/fallback/elemwise_op.h
  47. +1
    -1
      dnn/src/fallback/general_intrinsic/gi_common.h
  48. +10
    -10
      dnn/src/fallback/quantized_converter.h
  49. +303
    -0
      dnn/test/fallback/elemwise.cpp
  50. +145
    -0
      dnn/test/x86/elemwise_bmark.cpp

+ 59
- 64
dnn/src/arm_common/conv_bias/postprocess_helper.h View File

@@ -44,30 +44,29 @@ namespace {
break;

#define FOR_NONLINEAR_UNARY(_op) \
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>::run( \
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), reinterpret_cast<ctype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW* pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW);

#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);

#define FOR_NONLINEAR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW);

#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);

#define FOR_NONLINEAR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);

#define FOR_BIAS(_mode) \
@@ -168,36 +167,33 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_BIAS
#undef HANDLE_IDENTITY

#define FOR_NONLINEAR_UNARY(_op) \
#define FOR_NONLINEAR_UNARY(_op) \
megdnn::arm_common::OpCallerUnary<_op<opctype, opdtype>, megdnn::VEC>::run( \
static_cast<opctype*>(conv_dst_ptr), reinterpret_cast<opdtype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW* pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW);

#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerUnary<_op<opctype, opdtype>, megdnn::arm_common::VEC>::run( \
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW);
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);

#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common::OpCallerBinary< \
_op<opctype, opdtype>, megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<opctype, opdtype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);

#define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY: \
@@ -271,26 +267,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR
#undef FOR_BIAS

#define FOR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW);

#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N, OC, OH* OW, pack_oc_size);

#define FOR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
#define FOR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW);

#define FOR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_BCAST101xX>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, N, OC, \
OH* OW, pack_oc_size);

#define FOR_BINARY(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, megdnn::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);

#define FOR_BIAS(_bias_mode, OH, OW) \


+ 0
- 159
dnn/src/arm_common/elemwise/opr_impl.cpp View File

@@ -89,163 +89,4 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
fallback::ElemwiseImpl::exec(srcs, dst);
}

ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
KernParam kern_param;
kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE;
kern_param.mode = opr->param().mode;
kern_param.handle = opr->handle();

auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
if (is_vector(l))
return true;
if (l.ndim == 2 && l.stride[1] == 1)
return true;
return false;
};

if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
bool c_is_scalar;
opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar);
auto &src0 = kern_param.ternary_elparam[0],
&src1 = kern_param.ternary_elparam[1],
&src2 = kern_param.ternary_elparam[2];
BroadcastChannelInfo binfo;

if (is_vector(src0.layout) && is_vector(src1.layout) &&
is_vector(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) {
kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101;
return kern_param;
}

if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
return kern_param;
}

if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
return kern_param;
}

if (is_legal_layout_for_nhwc(src0.layout) &&
src2.layout.eq_layout(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
return kern_param;
}

if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_vector(src2.layout) &&
is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
is_broadcasted_scalar(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR;
return kern_param;
}
} else if (opr->m_src->size() == 2) {
kern_param.binary_elparam = opr->make_elemwise_op_param<2>();
auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && is_vector(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) {
kern_param.broad_cast_type = BcastType::SCALAR_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
return kern_param;
}

if (is_vector(src0.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
return kern_param;
}

if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
return kern_param;
}
} else if (opr->m_src->size() == 1) {
kern_param.broad_cast_type = BcastType::VEC;
kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
return kern_param;
}

return kern_param;
}

// vim: syntax=cpp.doxygen

+ 1
- 24
dnn/src/arm_common/elemwise/opr_impl.h View File

@@ -18,22 +18,12 @@ namespace megdnn {
namespace arm_common {
class ElemwiseImpl final : public fallback::ElemwiseImpl {
public:
using fallback::ElemwiseImpl::AlgoBase;
using fallback::ElemwiseImpl::ElemwiseImpl;
void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;
const char* get_algorithm_set_name() const { return "ARM COMMON ELEMWISE"; }

private:
struct KernParam {
BcastType broad_cast_type;
Mode mode;
const TensorND* m_dst;
Handle* handle;
ElemwiseOpParamN<3> ternary_elparam;
ElemwiseOpParamN<2> binary_elparam;
ElemwiseOpParamN<1> unary_elparam;
};
KernParam make_kern_param(ElemwiseImpl* opr);
class AlgoBase;
class AlgoUnary;
class AlgoBinaryVecVec;
class AlgoBinaryVecScalar;
@@ -54,19 +44,6 @@ private:
class AlgoPack;
};

/*!
*
* \brief base class for Elemwise algo
*
*/
class ElemwiseImpl::AlgoBase : public detail::Algorithm {
public:
virtual bool is_available(const KernParam&) const = 0;
virtual void exec(const KernParam&) const = 0;
virtual ~AlgoBase() = default;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
};

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#define DISPATCH_TYPE(_case) \
if (src0.layout.dtype == dtype::Float32{}) { \


+ 3
- 30
dnn/src/arm_common/elemwise_op.h View File

@@ -15,10 +15,13 @@
#include "src/arm_common/elemwise_helper/op_binary.h"
#include "src/arm_common/elemwise_helper/op_ternary.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/fallback/elemwise_helper/op_common.h"

namespace megdnn {
namespace arm_common {

using BcastType = megdnn::BcastType;

///////////////////////////////// ParamElemVistor ///////////////////////////
template <typename ctype>
struct ParamElemVisitor;
@@ -99,36 +102,6 @@ cb(__fp16, __fp16, float16x8_t, f16);
#endif
#undef cb

/*!
* \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i]
*/
enum BcastType {
VEC,
VEC_VEC,
VEC_BCAST101,
VEC_BCASTX0X,
VEC_BCAST111C,
VEC_BCAST101xX,
VEC_SCALAR,
SCALAR_VEC,
BCAST101_VEC,
BCASTX0X_VEC,
BCAST111C_VEC,
BCAST101xX_VEC,
VEC_VEC_VEC,
VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101,
BCAST111C_VEC_BCAST111C,
BCAST101xX_VEC_BCAST101xX,
VEC_BCAST101_VEC,
VEC_BCAST111C_VEC,
VEC_BCAST101xX_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR,
UNKNOWN_BCAST_TYPE
};

///////////////////////////////// OpCaller /////////////////////////////
template <typename Op, BcastType bcast_type>
struct OpCallerUnary;


dnn/src/fallback/elemwise/opr_binary_impl.cpp → dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp View File

@@ -1,14 +1,7 @@
/**
* \file dnn/src/fallback/elemwise/opr_binary_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* \file dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp
*/
#include "./opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"

dnn/src/fallback/elemwise/opr_unary_impl.cpp → dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp View File

@@ -1,14 +1,7 @@
/**
* \file dnn/src/fallback/elemwise/opr_unary_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* \file dnn/src/fallback/elemwise/fallback_impl/opr_unary_impl.cpp
*/
#include "./opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"

+ 535
- 0
dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp View File

@@ -0,0 +1,535 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/binary/algo.cpp
*/
#include "src/fallback/elemwise/gi_impl/binary/algo.h"
#include "src/fallback/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_binary)

using namespace megdnn;
using namespace fallback;

namespace {
static inline bool is_available_common(Elemwise::Mode mode) {
/**
* Fused sigmoid & tanh may be slower than the naive algo, because the
* time used by neon function `exp_ps_f32` is decided by the input.
*/
if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID ||
mode == Elemwise::Mode::FUSE_ADD_TANH) {
return false;
}

return true;
}
} // anonymous namespace

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU || \
mode == Mode::FUSE_ADD_H_SWISH) \
return true;

#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
mode == Mode::SUB || mode == Mode::MUL || mode == Mode::FUSE_ADD_RELU) \
return true;

bool ElemwiseImpl::AlgoBinaryVecVec::is_available(const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
(BcastType::VEC_VEC != kern_param.broad_cast_type))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

//! exactly match [x, y] + [x, y]
DISPATCH_TYPE_FALLBACK("AlgoBinaryVecVec::is_available"_hash);

return false;
}

bool ElemwiseImpl::AlgoBinaryVecScalar::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_SCALAR != kern_param.broad_cast_type) &&
(BcastType::SCALAR_VEC != kern_param.broad_cast_type)))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::is_available"_hash);
return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) &&
(BcastType::BCAST101_VEC != kern_param.broad_cast_type)))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::is_available"_hash);

return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) &&
(BcastType::BCASTX0X_VEC != kern_param.broad_cast_type)))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::is_available"_hash);

return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) &&
(BcastType::BCAST111C_VEC != kern_param.broad_cast_type)))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::is_available"_hash);

return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) &&
(BcastType::BCAST101xX_VEC != kern_param.broad_cast_type)))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::is_available"_hash);

return false;
}

#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \
DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
DISPATCH_BINARY( \
FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}

#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \
DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \
DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \
DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \
DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \
DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}

void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];

//! exactly match [x, y] + [x, y]
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t)> \
run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoBinaryVecVec::exec"_hash);

#undef DISPATCH_BINARY

return;
}

void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);

// Case 2: vector + scalar
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type, _type*, DType, DType, DType, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr())[0], \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) {
DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::exec_vec_sca"_hash);
}
#undef DISPATCH_BINARY

// scalar + vector
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type, const _type*, _type*, DType, DType, DType, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::SCALAR_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr())[0], \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, \
src1.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) {
DISPATCH_TYPE_FALLBACK("AlgoBinaryVecScalar::exec_sca_vec"_hash);
}
#undef DISPATCH_BINARY

return;
}

void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;

// Case 3: BcastType::VEC + BCAST_101
if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type &&
is_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCAST101>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::exec_vec_b"_hash);

#undef DISPATCH_BINARY
}

// BCAST_101 + BcastType::VEC
if (BcastType::BCAST101_VEC == kern_param.broad_cast_type &&
is_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCAST101_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101::exec_b_vec"_hash);

#undef DISPATCH_BINARY
}
return;
}

void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;

// Case: BcastType::VEC + BCAST_X0X
if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type &&
is_broadcasted_3dim_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::exec_vec_b"_hash);

#undef DISPATCH_BINARY
}

// BCAST_X0X + BcastType::VEC
if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type &&
is_broadcasted_3dim_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcastX0X::exec_b_vec"_hash);

#undef DISPATCH_BINARY
}
return;
}

void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;

// Case extra: BcastType::VEC + BCAST_111C
if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCAST111C>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::exec_vec_b"_hash);

#undef DISPATCH_BINARY
}

// BCAST_111C + BcastType::VEC
if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCAST111C_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast111C::exec_b_vec"_hash);

#undef DISPATCH_BINARY
}
return;
}

void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;

// BcastType::VEC + BCAST_101X
if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) {
megdnn_assert(
is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::exec_vec_b"_hash);

#undef DISPATCH_BINARY
}

// BCAST_101x + BcastType::VEC
if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) {
megdnn_assert(
is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);

DISPATCH_TYPE_FALLBACK("AlgoBinaryVecBcast101xX::exec_b_vec"_hash);

#undef DISPATCH_BINARY
}
return;
}

#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen

+ 36
- 0
dnn/src/fallback/elemwise/gi_impl/binary/algo.h View File

@@ -0,0 +1,36 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/binary/algo.h
*/

#pragma once
#include "src/fallback/elemwise/opr_impl.h"

namespace megdnn {
namespace fallback {

#define DECL_CB(case) \
class ElemwiseImpl::AlgoBinary##case final : public ElemwiseImpl::AlgoBase { \
mutable std::string m_name; \
AlgoAttribute attribute() const override { \
return AlgoAttribute::REPRODUCIBLE; \
} \
const char* name() const override { \
if (m_name.empty()) { \
m_name = ssprintf("Elemwise::AlgoBinaryCase" #case); \
} \
return m_name.c_str(); \
} \
bool is_available(const KernParam&) const override; \
void exec(const KernParam&) const override; \
};

DECL_CB(VecVec);
DECL_CB(VecScalar);
DECL_CB(VecBcast101);
DECL_CB(VecBcastX0X);
DECL_CB(VecBcast111C);
DECL_CB(VecBcast101xX);
#undef DECL_CB
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen

+ 383
- 0
dnn/src/fallback/elemwise/gi_impl/gi_mathfun.cpp View File

@@ -0,0 +1,383 @@
/**
* \file dnn/src/fallback/elemwise/gi_mathfun.cpp
*
* This file has been modified by Megvii ("Megvii Modifications").
* All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights
* reserved.
*
*/

/* NEON implementation of sin, cos, exp and log

Inspired by Intel Approximate Math library, and based on the
corresponding algorithms of the cephes math library
*/

/* Copyright (C) 2011 Julien Pommier

This software is provided 'as-is', without any express or implied
warranty. In no event will the authors be held liable for any damages
arising from the use of this software.

Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:

1. The origin of this software must not be misrepresented; you must not
claim that you wrote the original software. If you use this software
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original software.
3. This notice may not be removed or altered from any source distribution.

(this is the zlib license)
*/

#include "./gi_mathfun.h"

namespace megdnn {
namespace fallback {

#define c_inv_mant_mask ~0x7f800000u
#define c_cephes_SQRTHF 0.707106781186547524
#define c_cephes_log_p0 7.0376836292E-2
#define c_cephes_log_p1 -1.1514610310E-1
#define c_cephes_log_p2 1.1676998740E-1
#define c_cephes_log_p3 -1.2420140846E-1
#define c_cephes_log_p4 +1.4249322787E-1
#define c_cephes_log_p5 -1.6668057665E-1
#define c_cephes_log_p6 +2.0000714765E-1
#define c_cephes_log_p7 -2.4999993993E-1
#define c_cephes_log_p8 +3.3333331174E-1
#define c_cephes_log_q1 -2.12194440e-4
#define c_cephes_log_q2 0.693359375

/**
* natural logarithm computed for 4 simultaneous float return NaN for x <= 0
*/
v4sf GiLogPsFloat32(v4sf x) {
v4sf one = GiBroadcastFloat32(1);

x = GiMaximumFloat32(
x, GiBroadcastFloat32(0)); /* force flush to zero on denormal values */
v4su invalid_mask = GiLessThanEqFloat32(x, GiBroadcastFloat32(0));

v4si ux = GiReinterpretAsInt32(x);

v4si emm0 = GiShiftRight23Int32(ux);

/* keep only the fractional part */
ux = GiAndInt32(ux, GiBroadcastInt32(c_inv_mant_mask));
ux = GiOrInt32(ux, GiReinterpretAsInt32(GiBroadcastFloat32(0.5f)));
x = GiReintInt32ToFloat32(ux);

emm0 = GiSubtractInt32(emm0, GiBroadcastInt32(0x7f));
v4sf e = GiCastToFloat32(emm0);

e = GiAddFloat32(e, one);

/* part2:
* if( x < SQRTHF ) {
* e -= 1;
* x = x + x - 1.0;
* } else { x = x - 1.0; }
*/
v4su mask = GiLessThanFloat32(x, GiBroadcastFloat32(c_cephes_SQRTHF));
v4sf tmp = GiAndFloat32(x, GiReintUint32ToFloat32(mask));
x = GiSubtractFloat32(x, one);
e = GiSubtractFloat32(e, GiAndFloat32(one, GiReintUint32ToFloat32(mask)));
x = GiAddFloat32(x, tmp);

v4sf z = GiMultiplyFloat32(x, x);

v4sf y = GiBroadcastFloat32(c_cephes_log_p0);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p1), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p2), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p3), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p4), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p5), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p6), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p7), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_log_p8), y, x);
y = GiMultiplyFloat32(y, x);

y = GiMultiplyFloat32(y, z);

y = GiMultiplyAddFloat32(y, e, GiBroadcastFloat32(c_cephes_log_q1));

y = GiMultiplySubFloat32(y, z, GiBroadcastFloat32(0.5f));

x = GiAddFloat32(x, y);
x = GiMultiplyAddFloat32(x, e, GiBroadcastFloat32(c_cephes_log_q2));
x = GiOrFloat32(
x, GiReintUint32ToFloat32(invalid_mask)); // negative arg will be NAN
return x;
}

#define c_exp_hi 88.3762626647949f
#define c_exp_lo -88.3762626647949f

#define c_cephes_LOG2EF 1.44269504088896341
#define c_cephes_exp_C1 0.693359375
#define c_cephes_exp_C2 -2.12194440e-4

#define c_cephes_exp_p0 1.9875691500E-4
#define c_cephes_exp_p1 1.3981999507E-3
#define c_cephes_exp_p2 8.3334519073E-3
#define c_cephes_exp_p3 4.1665795894E-2
#define c_cephes_exp_p4 1.6666665459E-1
#define c_cephes_exp_p5 5.0000001201E-1

/* exp() computed for 4 float at once */
v4sf GiExpPsFloat32(v4sf x) {
v4sf tmp, fx;

v4sf one = GiBroadcastFloat32(1);
x = GiMinimumFloat32(x, GiBroadcastFloat32(c_exp_hi));
x = GiMaximumFloat32(x, GiBroadcastFloat32(c_exp_lo));

/* express exp(x) as exp(g + n*log(2)) */
fx = GiMultiplyAddFloat32(
GiBroadcastFloat32(0.5f), x, GiBroadcastFloat32(c_cephes_LOG2EF));

/* perform a floorf */
tmp = GiCastToFloat32(GiCastToInt32(fx));

/* if greater, subtract 1 */
v4su mask = GiGreaterThanFloat32(tmp, fx);
v4sf mask_float = GiAndFloat32(GiReintUint32ToFloat32(mask), one);

fx = GiSubtractFloat32(tmp, mask_float);

tmp = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C1));
v4sf z = GiMultiplyFloat32(fx, GiBroadcastFloat32(c_cephes_exp_C2));
x = GiSubtractFloat32(x, tmp);
x = GiSubtractFloat32(x, z);

z = GiMultiplyFloat32(x, x);

v4sf y = GiBroadcastFloat32(c_cephes_exp_p0);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p1), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p2), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p3), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p4), y, x);
y = GiMultiplyAddFloat32(GiBroadcastFloat32(c_cephes_exp_p5), y, x);

y = GiMultiplyAddFloat32(x, y, z);
y = GiAddFloat32(y, one);

/* build 2^n */
v4si mm;
mm = GiCastToInt32(fx);
mm = GiAddInt32(mm, GiBroadcastInt32(0x7f));
mm = GiShiftLeft23Int32(mm);
v4sf pow2n = GiReintInt32ToFloat32(mm);

y = GiMultiplyFloat32(y, pow2n);
return y;
}

#define c_minus_cephes_DP1 -0.78515625
#define c_minus_cephes_DP2 -2.4187564849853515625e-4
#define c_minus_cephes_DP3 -3.77489497744594108e-8
#define c_sincof_p0 -1.9515295891E-4
#define c_sincof_p1 8.3321608736E-3
#define c_sincof_p2 -1.6666654611E-1
#define c_coscof_p0 2.443315711809948E-005
#define c_coscof_p1 -1.388731625493765E-003
#define c_coscof_p2 4.166664568298827E-002
#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI

/* evaluation of 4 sines & cosines at once.

The code is the exact rewriting of the cephes sinf function.
Precision is excellent as long as x < 8192 (I did not bother to
take into account the special handling they have for greater values
-- it does not return garbage for arguments over 8192, though, but
the extra precision is missing).

Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
surprising but correct result.

Note also that when you compute sin(x), cos(x) is available at
almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of
sincos_ps_f32..
*/
void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos) {
// any x
v4sf y;

v4su emm2;

v4su sign_mask_sin, sign_mask_cos;
sign_mask_sin = GiLessThanFloat32(x, GiBroadcastFloat32(0));
x = GiAbsFloat32(x);

/* scale by 4/Pi */
y = GiMultiplyFloat32(x, GiBroadcastFloat32(c_cephes_FOPI));

/* store the integer part of y in mm0 */
emm2 = GiReinterpretAsUint32(y);
/* j=(j+1) & (~1) (see the cephes sources) */
emm2 = GiAddUint32(emm2, GiBroadcastUint32(1));
emm2 = GiAddUint32(emm2, GiBroadcastUint32(~1));
y = GiReintUint32ToFloat32(emm2);

/* get the polynom selection mask
* there is one polynom for 0 <= x <= Pi/4
* and another one for Pi/4<x<=Pi/2
*
* Both branches will be computed.
*/
v4su poly_mask = GiTestAndSetUint32(emm2, GiBroadcastUint32(2));

/* The magic pass: "Extended precision modular arithmetic"
* x = ((x - y * DP1) - y * DP2) - y * DP3; */
x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP1));
x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP2));
x = GiMultiplyAddFloat32(x, y, GiBroadcastFloat32(c_minus_cephes_DP3));

sign_mask_sin =
GiEOrUint32(sign_mask_sin, GiTestAndSetUint32(emm2, GiBroadcastUint32(4)));
sign_mask_cos = GiTestAndSetUint32(
GiSubtractUint32(emm2, GiBroadcastUint32(2)), GiBroadcastUint32(4));

/* Evaluate the first polynom (0 <= x <= Pi/4) in y1,
* and the second polynom (Pi/4 <= x <= 0) in y2 */
v4sf z = GiMultiplyFloat32(x, x);
v4sf y1, y2;

y1 = GiMultiplyAddFloat32(
GiBroadcastFloat32(c_coscof_p1), z, GiBroadcastFloat32(c_coscof_p0));
y2 = GiMultiplyAddFloat32(
GiBroadcastFloat32(c_sincof_p1), z, GiBroadcastFloat32(c_sincof_p0));
y1 = GiMultiplyAddFloat32(GiBroadcastFloat32(c_coscof_p2), y1, z);
y2 = GiMultiplyAddFloat32(GiBroadcastFloat32(c_sincof_p2), y2, z);
y1 = GiMultiplyFloat32(y1, z);
y2 = GiMultiplyFloat32(y2, z);
y1 = GiMultiplyFloat32(y1, z);
y1 = GiMultiplySubFloat32(y1, z, GiBroadcastFloat32(0.5f));
y2 = GiMultiplyAddFloat32(x, y2, x);
y1 = GiAddFloat32(y1, GiBroadcastFloat32(1));

/* select the correct result from the two polynoms */
v4sf ys = GiBSLFloat32(poly_mask, y1, y2);
v4sf yc = GiBSLFloat32(poly_mask, y2, y1);
*ysin = GiBSLFloat32(sign_mask_sin, GiNegFloat32(ys), ys);
*ycos = GiBSLFloat32(sign_mask_cos, yc, GiNegFloat32(yc));
}

v4sf GiSinPsFloat32(v4sf x) {
v4sf ysin, ycos;
GiSinCosPsFloat32(x, &ysin, &ycos);
return ysin;
}

v4sf GiCosPsFloat32(v4sf x) {
v4sf ysin, ycos;
GiSinCosPsFloat32(x, &ysin, &ycos);
return ycos;
}

v4sf GiTanPsFloat32(v4sf x) {
v4sf ysin, ycos;
GiSinCosPsFloat32(x, &ysin, &ycos);
return ysin / ycos;
}

#undef c_exp_hi
#undef c_exp_lo
#undef c_cephes_LOG2EF
#undef c_cephes_exp_C1
#undef c_cephes_exp_C2
#undef c_cephes_exp_p0
#undef c_cephes_exp_p1
#undef c_cephes_exp_p2
#undef c_cephes_exp_p3
#undef c_cephes_exp_p4
#undef c_cephes_exp_p5

#undef c_minus_cephes_DP1
#undef c_minus_cephes_DP2
#undef c_minus_cephes_DP3
#undef c_sincof_p0
#undef c_sincof_p1
#undef c_sincof_p2
#undef c_coscof_p0
#undef c_coscof_p1
#undef c_coscof_p2
#undef c_cephes_FOPI

#undef c_inv_mant_mask
#undef c_cephes_SQRTHF
#undef c_cephes_log_p0
#undef c_cephes_log_p1
#undef c_cephes_log_p2
#undef c_cephes_log_p3
#undef c_cephes_log_p4
#undef c_cephes_log_p5
#undef c_cephes_log_p6
#undef c_cephes_log_p7
#undef c_cephes_log_p8
#undef c_cephes_log_q1
#undef c_cephes_log_q2

static const struct {
float lower_range;
float upper_range;
float alpha_9;
float alpha_7;
float alpha_5;
float alpha_3;
float alpha_1;
float beta_10;
float beta_8;
float beta_6;
float beta_4;
float beta_2;
float beta_0;
float one_half;
} sigmoid_constants = {
-18.0f,
18.0f,
4.37031012579801e-11f,
1.15627324459942e-07f,
6.08574864600143e-05f,
8.51377133304701e-03f,
2.48287947061529e-01f,
6.10247389755681e-13f,
5.76102136993427e-09f,
6.29106785017040e-06f,
1.70198817374094e-03f,
1.16817656904453e-01f,
9.93151921023180e-01f,
0.5f,
};

v4sf GiSigmoidPsFloat32(v4sf src) {
auto val = GiMaximumFloat32(GiBroadcastFloat32(sigmoid_constants.lower_range), src);
val = GiMinimumFloat32(GiBroadcastFloat32(sigmoid_constants.upper_range), val);
auto squared = GiMultiplyFloat32(val, val);
auto p = GiMultiplyAddFloat32(
GiBroadcastFloat32(sigmoid_constants.alpha_7), squared,
GiBroadcastFloat32(sigmoid_constants.alpha_9));
p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_5), p, squared);
p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_3), p, squared);
p = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.alpha_1), p, squared);
p = GiMultiplyFloat32(p, val);
auto q = GiMultiplyAddFloat32(
GiBroadcastFloat32(sigmoid_constants.beta_8), squared,
GiBroadcastFloat32(sigmoid_constants.beta_10));
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_6), q, squared);
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_4), q, squared);
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_2), q, squared);
q = GiMultiplyAddFloat32(GiBroadcastFloat32(sigmoid_constants.beta_0), q, squared);
return GiAddFloat32(
GiDivideFloat32(p, q), GiBroadcastFloat32(sigmoid_constants.one_half));
}

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 55
- 0
dnn/src/fallback/elemwise/gi_impl/gi_mathfun.h View File

@@ -0,0 +1,55 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/gi_mathfun.h
*/

#pragma once

#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"

namespace megdnn {
namespace fallback {

typedef GI_FLOAT32_t v4sf; // vector of 4 float
typedef GI_INT32_t v4si; // vector of 4 int32
typedef GI_UINT32_t v4su; // vector of 4 uint32

/**
* \brief natural logarithm computed for 4 simultaneous float
* return NaN for x <= 0
*/
v4sf GiLogPsFloat32(v4sf x);

//! exp() computed for 4 float at once
v4sf GiExpPsFloat32(v4sf x);

/**
* \brief evaluation of 4 sines & cosines at once.
*
* The code is the exact rewriting of the cephes sinf function.
* Precision is excellent as long as x < 8192 (I did not bother to
* take into account the special handling they have for greater values
* -- it does not return garbage for arguments over 8192, though, but
* the extra precision is missing).
*
* Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
* surprising but correct result.
*
* Note also that when you compute sin(x), cos(x) is available at
* almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of
* sincos_ps_f32..
*/
void GiSinCosPsFloat32(v4sf x, v4sf* ysin, v4sf* ycos);

v4sf GiSinPsFloat32(v4sf x);

v4sf GiCosPsFloat32(v4sf x);

v4sf GiTanPsFloat32(v4sf x);

v4sf GiSigmoidPsFloat32(v4sf x);

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 459
- 0
dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp View File

@@ -0,0 +1,459 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/ternary/algo.cpp
*/

#include "src/fallback/elemwise/gi_impl/ternary/algo.h"
#include "src/fallback/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_ternary)

using namespace megdnn;
using namespace fallback;

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::FUSE_MUL_ADD3) \
return true;
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT

#define DECL_AVAILABLE(case, type) \
bool ElemwiseImpl::AlgoTernaryFma3##case ::is_available( \
const KernParam& kern_param) const { \
if (type == kern_param.broad_cast_type) { \
auto& elparam = kern_param.ternary_elparam; \
auto& src0 = elparam[0]; \
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3::is_available" #case##_hash); \
} \
return false; \
}

DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C);
DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC);
DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
#undef DECL_CB
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, FuseMulAdd3Op); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT
void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 1: shape of (src0, src2) and src1 are exactly match
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_VEC_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecVecVec::exec"_hash);
#undef DISPATCH_TERNARY

return;
}
void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 2: (src2 is a scalar) && (src0 and src1 has the same shape)
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type, _type*, DType, DType, \
DType, DType, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr())[0], \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecVecScalar::exec"_hash);
#undef DISPATCH_TERNARY

return;
}
void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 3: shape of src0 and src2 is {1, C, 1, 1}
BroadcastChannelInfo binfo;
is_broadcasted_channel_like(src0.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \
binfo, dst, run](size_t task_id, size_t) { \
size_t offset = task_id * nr_channels_per_thread; \
size_t nr_channels_thread = \
std::min(nr_channels - offset, nr_channels_per_thread); \
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \
static_cast<const _type*>(src2.raw_ptr()) + offset, \
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \
binfo.y * binfo.z); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \
MIDOUT_END(); \
return

size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();

size_t nr_channels = binfo.y;
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash);
#undef DISPATCH_TERNARY

return;
}

void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 3: shape of src0 and src2 is {1, 1, 1, C}
BroadcastChannelInfo binfo;
is_NHWC_broadcasted_channel_like(src0.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, size_t, const _type*, _type*, DType, \
DType, DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, \
BcastType::BCAST111C_VEC_BCAST111C>::run; \
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \
binfo, dst, run](size_t task_id, size_t) { \
size_t offset = task_id * nr_channels_per_thread; \
size_t nr_channels_thread = \
std::min(nr_channels - offset, nr_channels_per_thread); \
size_t src1_offset = \
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()) + \
offset * (binfo.z + src1_offset), \
src1_offset, static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \
binfo.y * binfo.z); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \
MIDOUT_END(); \
return

size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();

size_t nr_channels = binfo.y;
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads;
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
#undef DISPATCH_TERNARY

return;
}

void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

BroadcastChannelInfo binfo;
megdnn_assert(
is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, \
BcastType::BCAST101xX_VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return

size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash);
#undef DISPATCH_TERNARY

return;
}

void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

BroadcastChannelInfo binfo;
megdnn_assert(
is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
batch_size, binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return

size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast101xXVec::exec"_hash);
#undef DISPATCH_TERNARY

return;
}

void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 4: shape of src1 is {1, C, 1, 1}, and src0 and src2 are contig
BroadcastChannelInfo binfo;
is_broadcasted_channel_like(src1.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, const _type*, _type*, DType, DType, \
DType, DType, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast101Vec::exec"_hash);
#undef DISPATCH_TERNARY

return;
}

void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig
BroadcastChannelInfo binfo;
is_NHWC_broadcasted_channel_like(src1.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, size_t, const _type*, const _type*, size_t, _type*, \
DType, DType, DType, DType, size_t, size_t, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \
static_cast<const _type*>(src1.raw_ptr()), \
static_cast<const _type*>(src2.raw_ptr()), \
is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
binfo.x, binfo.y, binfo.z)); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecBcast111CVec::exec"_hash);
#undef DISPATCH_TERNARY

return;
}

void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 5: (src1 is a scalar) && (src0 and src2 has the same shape)
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type, const _type*, _type*, DType, DType, \
DType, DType, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr())[0], \
static_cast<const _type*>(src2.raw_ptr()), \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecScalarVec::exec"_hash);
#undef DISPATCH_TERNARY

return;
}
void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec(
const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

// Case 6: (src1 and src2 is scalar) && (src0 is vector)
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_ternary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type, const _type, _type*, DType, DType, \
DType, DType, size_t)> \
run = OpCallerTernary< \
_op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr()), \
static_cast<const _type*>(src1.raw_ptr())[0], \
static_cast<const _type*>(src2.raw_ptr())[0], \
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
src0.layout.total_nr_elems())); \
} \
MIDOUT_END(); \
return

auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE_FALLBACK("AlgoTernaryFma3VecScalarScalar::exec"_hash);
#undef DISPATCH_TERNARY

return;
}
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen

+ 39
- 0
dnn/src/fallback/elemwise/gi_impl/ternary/algo.h View File

@@ -0,0 +1,39 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/ternary/algo.h
*/
#pragma once
#include "src/fallback/elemwise/opr_impl.h"

namespace megdnn {
namespace fallback {

#define DECL_CB(case) \
class ElemwiseImpl::AlgoTernaryFma3##case final : public ElemwiseImpl::AlgoBase { \
mutable std::string m_name; \
AlgoAttribute attribute() const override { \
return AlgoAttribute::REPRODUCIBLE; \
} \
const char* name() const override { \
if (m_name.empty()) { \
m_name = ssprintf("Elemwise::AlgoTernaryFma3" #case); \
} \
return m_name.c_str(); \
} \
bool is_available(const KernParam&) const override; \
void exec(const KernParam&) const override; \
};

DECL_CB(VecVecVec);
DECL_CB(VecVecScalar);
DECL_CB(Bcast101VecBcast101);
DECL_CB(Bcast111CVecBcast111C);
DECL_CB(Bcast101xXVecBcast101xX);
DECL_CB(VecBcast101Vec);
DECL_CB(VecBcast111CVec);
DECL_CB(VecBcast101xXVec);
DECL_CB(VecScalarVec);
DECL_CB(VecScalarScalar);
#undef DECL_CB
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen

+ 125
- 0
dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp View File

@@ -0,0 +1,125 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.cpp
*/
#include "src/fallback/elemwise/gi_impl/unary/algo.h"
#include "src/fallback/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_unary)

using namespace megdnn;
using namespace fallback;

bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const {
if (BcastType::VEC != kern_param.broad_cast_type)
return false;

if (kern_param.m_dst->layout.dtype.category() != DTypeCategory::FLOAT &&
(kern_param.mode == Mode::EXP || kern_param.mode == Mode::SIGMOID ||
kern_param.mode == Mode::TANH || kern_param.mode == Mode::FAST_TANH ||
kern_param.mode == Mode::H_SWISH)) {
return false;
}
//! As `NEGATE` mode is so simple, that the code generate by compiler is
//! vectorized optimized, while other mode such as `ABS` has branch, the
//! compiler may not generate code as good as user intrinsic.
if (kern_param.mode == Mode::NEGATE) {
return false;
}

auto& elparam = kern_param.unary_elparam;
if (!elparam[0].layout.is_contiguous())
return false;
megdnn_assert(elparam[0].layout.ndim == 1);
auto& src0 = elparam[0];

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::RELU || mode == Mode::ABS || mode == Mode::SIGMOID || \
mode == Mode::EXP || mode == Mode::TANH || mode == Mode::FAST_TANH || \
mode == Mode::H_SWISH) \
return true;

#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
auto mode = kern_param.mode; \
if (mode == Mode::RELU || mode == Mode::ABS) \
return true;

DISPATCH_TYPE_FALLBACK("AlgoUnary::is_available"_hash);
return false;
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
}

void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const {
#define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_fallback_elemwise_unary, midout_iv(_case), \
midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \
thin_function<void(const _type*, _type*, DType, DType, size_t)> run = \
OpCallerUnary<_op<_type, _type>, BcastType::VEC>::run; \
auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, run]( \
size_t task_id, size_t) { \
size_t offset = task_id * nr_elems_per_thread; \
size_t nr_elems_thread = \
std::min(nr_elems - offset, nr_elems_per_thread); \
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
static_cast<_type*>(dst_tensor.raw_ptr()) + offset, \
src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \
kernel); \
} \
MIDOUT_END(); \
return

auto& elparam = kern_param.unary_elparam;
megdnn_assert(elparam[0].layout.ndim == 1);
auto& src0 = elparam[0];
auto& dst_tensor = *(kern_param.m_dst);

size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle)
->megcore_dispatcher()
->nr_threads();

size_t nr_elems = src0.layout.total_nr_elems();
size_t nr_elems_per_thread = (nr_elems + nr_threads - 1) / nr_threads;

#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \
DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \
DISPATCH_UNARY(SIGMOID, _case, _type, _type_midout_id, SigmoidOp); \
DISPATCH_UNARY(EXP, _case, _type, _type_midout_id, ExpOp); \
DISPATCH_UNARY(TANH, _case, _type, _type_midout_id, TanhOp); \
DISPATCH_UNARY(FAST_TANH, _case, _type, _type_midout_id, FastTanhOp); \
DISPATCH_UNARY(H_SWISH, _case, _type, _type_midout_id, HSwishOp); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}

#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \
switch (kern_param.mode) { \
DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \
DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \
default: \
megdnn_throw(ssprintf( \
"No avaiable algo find for: %d", \
static_cast<int>(kern_param.mode))); \
}

DISPATCH_TYPE_FALLBACK("AlgoUnary::exec"_hash);
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
#undef DISPATCH_UNARY
}

// vim: syntax=cpp.doxygen

+ 25
- 0
dnn/src/fallback/elemwise/gi_impl/unary/algo.h View File

@@ -0,0 +1,25 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/unary/algo.h
*/
#pragma once
#include "src/fallback/elemwise/opr_impl.h"
namespace megdnn {
namespace fallback {
class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase {
mutable std::string m_name;

AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override {
if (m_name.empty()) {
m_name = ssprintf("Elemwise::AlgoUnary");
}
return m_name.c_str();
}

bool is_available(const KernParam&) const override;
void exec(const KernParam&) const override;
};

} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen

+ 238
- 4
dnn/src/fallback/elemwise/opr_impl.cpp View File

@@ -12,6 +12,9 @@

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
#include "src/fallback//elemwise/gi_impl/unary/algo.h"
#include "src/fallback/elemwise/gi_impl/binary/algo.h"
#include "src/fallback/elemwise/gi_impl/ternary/algo.h"
#include "src/naive/handle.h"

#include "midout.h"
@@ -21,13 +24,22 @@ MIDOUT_DECL(megdnn_fallback_elemwise_exec_UNARY_FLOAT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_INT)
MIDOUT_DECL(megdnn_fallback_elemwise_exec_BINARY_FLOAT)

namespace megdnn {
namespace fallback {
using namespace megdnn;
using namespace fallback;

void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
if (!dst.layout.is_contiguous()) {
return naive::ElemwiseForwardImpl::exec(srcs, dst);
}
if (!exec_gi_intrinsic(srcs, dst)) {
return exec_fallback(srcs, dst);
}
}

void ElemwiseImpl::exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
if (!dst.layout.is_contiguous()) {
return naive::ElemwiseForwardImpl::exec(srcs, dst);
}

m_src = &srcs;
m_dst = &dst;
@@ -82,7 +94,229 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
naive::ElemwiseForwardImpl::exec(srcs, dst);
}

} // namespace fallback
} // namespace megdnn
class ElemwiseImpl::AlgoPack {
#if !(MEGDNN_AARCH64 || MEGDNN_ARMV7)
AlgoUnary algo_unary;
AlgoBinaryVecVec algo_binary_vec_vec;
AlgoBinaryVecScalar algo_binary_vec_sca;
AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X;
AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110;
AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX;
AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec;
AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
#endif

public:
AlgoPack() {
#if !(MEGDNN_AARCH64 || MEGDNN_ARMV7)
all_algos.emplace_back(&algo_unary);
all_algos.emplace_back(&algo_binary_vec_vec);
all_algos.emplace_back(&algo_binary_vec_sca);
all_algos.emplace_back(&algo_binary_vec_bcast101);
all_algos.emplace_back(&algo_binary_vec_bcastX0X);
all_algos.emplace_back(&algo_binary_vec_bcast110);
all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110);
all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
#endif
}
SmallVector<AlgoBase*> all_algos;
};

bool ElemwiseImpl::exec_gi_intrinsic(
const TensorNDArray& srcs, _megdnn_tensor_out dst) {
m_src = &srcs;
m_dst = &dst;

if (m_dst->layout.dtype == dtype::Float32() ||
m_dst->layout.dtype == dtype::Int32() ||
m_dst->layout.dtype == dtype::Int16() || m_dst->layout.dtype == dtype::Int8()) {
auto kern_param = make_kern_param(this);
kern_param.m_dst = &dst;
static AlgoPack m_algo_pack;
for (auto& m_algo : m_algo_pack.all_algos) {
if (m_algo->is_available(kern_param)) {
m_algo->exec(kern_param);
return true;
}
}
}
return false;
}

ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
KernParam kern_param;
kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE;
kern_param.mode = opr->param().mode;
kern_param.handle = opr->handle();

auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
if (is_vector(l))
return true;
if (l.ndim == 2 && l.stride[1] == 1)
return true;
return false;
};

if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
bool c_is_scalar;
opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar);
auto &src0 = kern_param.ternary_elparam[0],
&src1 = kern_param.ternary_elparam[1],
&src2 = kern_param.ternary_elparam[2];
BroadcastChannelInfo binfo;

if (is_vector(src0.layout) && is_vector(src1.layout) &&
is_vector(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) {
kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101;
return kern_param;
}

if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
return kern_param;
}

if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
return kern_param;
}

if (is_legal_layout_for_nhwc(src0.layout) &&
src2.layout.eq_layout(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
return kern_param;
}

if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_vector(src2.layout) &&
is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
is_broadcasted_scalar(src2.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR;
return kern_param;
}
} else if (opr->m_src->size() == 2) {
kern_param.binary_elparam = opr->make_elemwise_op_param<2>();
auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1];
BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && is_vector(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
kern_param.broad_cast_type = BcastType::VEC_SCALAR;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) {
kern_param.broad_cast_type = BcastType::SCALAR_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST101_VEC;
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src0.layout) &&
is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
return kern_param;
}

if (is_vector(src0.layout) &&
(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
return kern_param;
}

if (is_vector(src1.layout) &&
(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
return kern_param;
}
} else if (opr->m_src->size() == 1) {
kern_param.broad_cast_type = BcastType::VEC;
kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
return kern_param;
}

return kern_param;
}
// vim: syntax=cpp.doxygen

+ 57
- 0
dnn/src/fallback/elemwise/opr_impl.h View File

@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/fallback/elemwise_op.h"
#include "src/naive/elemwise/opr_impl.h"

namespace megdnn {
@@ -33,13 +34,69 @@ class ElemwiseImpl : public naive::ElemwiseForwardImpl {
template <uint32_t mode>
void exec_BINARY_FLOAT();

void exec_fallback(const TensorNDArray& srcs, _megdnn_tensor_out dst);
bool exec_gi_intrinsic(const TensorNDArray& srcs, _megdnn_tensor_out dst);

private:
class AlgoUnary;
class AlgoBinaryVecVec;
class AlgoBinaryVecScalar;
class AlgoBinaryVecBcast101;
class AlgoBinaryVecBcastX0X;
class AlgoBinaryVecBcast111C;
class AlgoBinaryVecBcast101xX;
class AlgoTernaryFma3VecVecVec;
class AlgoTernaryFma3VecVecScalar;
class AlgoTernaryFma3Bcast101VecBcast101;
class AlgoTernaryFma3Bcast111CVecBcast111C;
class AlgoTernaryFma3Bcast101xXVecBcast101xX;
class AlgoTernaryFma3VecBcast101Vec;
class AlgoTernaryFma3VecBcast111CVec;
class AlgoTernaryFma3VecBcast101xXVec;
class AlgoTernaryFma3VecScalarVec;
class AlgoTernaryFma3VecScalarScalar;
class AlgoPack;

public:
class AlgoBase;
struct KernParam {
BcastType broad_cast_type;
Mode mode;
const TensorND* m_dst;
Handle* handle;
ElemwiseOpParamN<3> ternary_elparam;
ElemwiseOpParamN<2> binary_elparam;
ElemwiseOpParamN<1> unary_elparam;
};
KernParam make_kern_param(ElemwiseImpl* opr);
using naive::ElemwiseForwardImpl::ElemwiseForwardImpl;
void exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) override;

const char* get_algorithm_set_name() const { return "FALLBACK ELEMWISE"; }
bool is_thread_safe() const override { return true; }
};

/*!
* \brief base class for Elemwise algo
*/
class ElemwiseImpl::AlgoBase : public detail::Algorithm {
public:
virtual bool is_available(const KernParam&) const = 0;
virtual void exec(const KernParam&) const = 0;
virtual ~AlgoBase() = default;
uint32_t type() const override { return INVALID_ALGO_TYPE; };
};

//! fallback only support float, int32, int8
#define DISPATCH_TYPE_FALLBACK(_case) \
if (src0.layout.dtype == dtype::Float32{}) { \
DISPATCH_MODE_FLOAT(_case, float, 0); \
} else if (src0.layout.dtype == dtype::Int32{}) { \
DISPATCH_MODE_INT(_case, int, 2); \
} else if (src0.layout.dtype == dtype::Int8{}) { \
DISPATCH_MODE_INT(_case, dt_int8, 4); \
}

} // namespace fallback
} // namespace megdnn



+ 78
- 0
dnn/src/fallback/elemwise_helper/kimpl/abs.h View File

@@ -0,0 +1,78 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/abs.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct AbsOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : (-src); }
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct AbsOp;

#define OP(_ctype, _gi_type, _func_suffix, _simd_width) \
template <> \
struct AbsOp<_ctype> : AbsOpBase<_ctype> { \
using AbsOpBase::AbsOpBase; \
using AbsOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _gi_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_gi_type operator()(const _gi_type& src) const { \
auto vitem0 = GiAbs##_func_suffix(src.val[0]); \
auto vitem1 = GiAbs##_func_suffix(src.val[1]); \
return {{vitem0, vitem1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(dt_float32))
OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32))
OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8))
#undef OP

template <>
struct AbsOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint8& src, dt_qint8* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint8& src) const {
float fsrc = src.as_int8() * this->scale;
fsrc = fsrc > 0 ? fsrc : -fsrc;
return QConverter::convert<dt_qint8, float>(fsrc);
}
};

template <>
struct AbsOp<dt_qint8, dt_qint8> : AbsOpBase<dt_qint8, dt_qint8> {
using AbsOpBase::AbsOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using AbsOpBase::operator();
void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiAbsFloat32(vitem0);
vitem1 = GiAbsFloat32(vitem1);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 134
- 0
dnn/src/fallback/elemwise_helper/kimpl/add.h View File

@@ -0,0 +1,134 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/add.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct AddOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}

dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 + src1;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct AddOp;

#define OP(_ctype, _gi_type, _gi_type2, _func_suffix, _simd_width) \
template <> \
struct AddOp<_ctype> : AddOpBase<_ctype> { \
using AddOpBase::AddOpBase; \
using AddOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _gi_type2& src0, const _gi_type2& src1, dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_gi_type2 operator()(const _gi_type2& src0, const _gi_type2& src1) const { \
auto vitem0 = GiAdd##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiAdd##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _gi_type& src0, const _gi_type& src1, dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_gi_type operator()(const _gi_type& src0, const _gi_type& src1) const { \
return GiAdd##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32,
GI_SIMD_LEN_BYTE / sizeof(dt_float32))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(dt_int32))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(dt_int8))
#undef OP

template <>
struct AddOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1);
}
};

template <>
struct AddOp<dt_qint8, dt_qint8> : AddOpBase<dt_qint8, dt_qint8> {
using AddOpBase::AddOpBase;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);
using AddOpBase::operator();

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));

return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

template <>
struct AddOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1);
}
};

template <>
struct AddOp<dt_qint32, dt_qint8> : AddOpBase<dt_qint32, dt_qint8> {
using AddOpBase::AddOpBase;
using AddOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1));
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 49
- 0
dnn/src/fallback/elemwise_helper/kimpl/exp.h View File

@@ -0,0 +1,49 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/exp.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct ExpOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmp = src;
return exp(tmp);
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct ExpOp;

#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct ExpOp<_ctype> : ExpOpBase<_ctype> { \
using ExpOpBase::ExpOpBase; \
using ExpOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src) const { \
auto vitem0 = GiExpPs##_func_suffix(src.val[0]); \
auto vitem1 = GiExpPs##_func_suffix(src.val[1]); \
return {{vitem0, vitem1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 70
- 0
dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h View File

@@ -0,0 +1,70 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fast_tanh.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

//! tanh = x * (27 + x^2) / (27 + 9 * x^2)
template <typename src_ctype, typename dst_ctype = src_ctype>
struct FastTanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float x = src;
return x * (27.f + x * x) / (27.f + 9.f * x * x);
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FastTanhOp;

#define OP(_ctype, _simd_type, _func_suffix, _fix_func_suffix, _simd_width) \
template <> \
struct FastTanhOp<_ctype> : FastTanhOpBase<_ctype> { \
using FastTanhOpBase::FastTanhOpBase; \
using FastTanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src) const { \
auto val_27 = GiBroadcast##_func_suffix(27.f); \
auto val_9 = GiBroadcast##_func_suffix(9.f); \
auto valx = src.val[0]; \
auto valx1 = src.val[1]; \
auto valxp2 = GiMultiply##_fix_func_suffix(valx, valx); \
auto valx1p2 = GiMultiply##_fix_func_suffix(valx1, valx1); \
auto denominator = GiAdd##_fix_func_suffix(valxp2, val_27); \
auto denominator1 = GiAdd##_fix_func_suffix(valx1p2, val_27); \
valx = GiMultiply##_fix_func_suffix(valx, denominator); \
valx1 = GiMultiply##_fix_func_suffix(valx1, denominator1); \
denominator = GiMultiplyAdd##_fix_func_suffix(val_27, valxp2, val_9); \
denominator1 = GiMultiplyAdd##_fix_func_suffix(val_27, valx1p2, val_9); \
auto r_denominator = GiRecpe##_func_suffix(denominator); \
auto r_denominator1 = GiRecpe##_func_suffix(denominator1); \
r_denominator = GiMultiply##_fix_func_suffix( \
GiRecpeS##_func_suffix(denominator, r_denominator), \
r_denominator); \
r_denominator1 = GiMultiply##_fix_func_suffix( \
GiRecpeS##_func_suffix(denominator1, r_denominator1), \
r_denominator1); \
valx = GiMultiply##_fix_func_suffix(valx, r_denominator); \
valx1 = GiMultiply##_fix_func_suffix(valx1, r_denominator1); \
return {{valx, valx1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 118
- 0
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h View File

@@ -0,0 +1,118 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h"
#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddHSwishOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float tmp = src0 + src1;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
return tmp;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddHSwishOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \
using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \
using FuseAddHSwishOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiAdd##_func_suffix(val1, val3); \
val2 = GiAdd##_func_suffix(val2, val4); \
H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \
return {{val1, val2}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0; \
auto val2 = src1; \
val1 = GiAdd##_func_suffix(val1, val2); \
H_SWISH_KERN_N1_FALLBACK(_func_suffix, val1); \
return val1; \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

template <>
struct FuseAddHSwishOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const {
float tmp =
src0.as_int32() * this->scale_src0 + src1.as_int32() * this->scale_src1;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
tmp *= this->scale_dst;
return QConverter::convert<dt_qint8, float>(tmp);
}
};

template <>
struct FuseAddHSwishOp<dt_qint32, dt_qint8> : FuseAddHSwishOpBase<dt_qint32, dt_qint8> {
using FuseAddHSwishOpBase::FuseAddHSwishOpBase;
using FuseAddHSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1));
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
GI_FLOAT32_t vitem0, vitem1;

vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1));
vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1));
H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1);
vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst);
vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

#include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h"

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 162
- 0
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h View File

@@ -0,0 +1,162 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_relu.h
*/
#pragma once

#include "gi_util_impl_helper.h"
#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddReluOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
auto tmp = src0 + src1;
return tmp > 0 ? tmp : 0;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddReluOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \
using FuseAddReluOpBase::FuseAddReluOpBase; \
using FuseAddReluOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, _func_suffix); \
return {{val1, val2}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0; \
auto val2 = src1; \
FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, _func_suffix); \
return val1; \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <typename ctype>
struct FuseAddReluOpCommon;

template <>
struct FuseAddReluOpCommon<float> {
inline static GI_FLOAT32_t vzero() { return GiBroadcastFloat32(0); }
};

template <>
struct FuseAddReluOpCommon<int> {
inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); }
};

template <>
struct FuseAddReluOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(std::max<float>(
src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, 0.f));
}
};

template <>
struct FuseAddReluOp<dt_qint8, dt_qint8> : FuseAddReluOpBase<dt_qint8, dt_qint8>,
FuseAddReluOpCommon<float> {
using FuseAddReluOpBase::FuseAddReluOpBase;
using FuseAddReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int8_t);

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));

vitem0 = GiMaximumFloat32(vitem0, this->vzero());
vitem1 = GiMaximumFloat32(vitem1, this->vzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

template <>
struct FuseAddReluOpBase<dt_qint32, dt_qint8> : BinaryOpBase<dt_qint32, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const {
return QConverter::convert<dt_qint8, float>(std::max<float>(
src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, 0.f));
}
};

template <>
struct FuseAddReluOp<dt_qint32, dt_qint8> : FuseAddReluOpBase<dt_qint32, dt_qint8>,
FuseAddReluOpCommon<float> {
using FuseAddReluOpBase::FuseAddReluOpBase;
using FuseAddReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;
void operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc0, vsrc1));
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiAddFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));

vitem0 = GiMaximumFloat32(vitem0, this->vzero());
vitem1 = GiMaximumFloat32(vitem1, this->vzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 60
- 0
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h View File

@@ -0,0 +1,60 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddSigmoidOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float tmpf = src0 + src1;
tmpf = exp(-tmpf);
tmpf = 1.f / (1.f + tmpf);
return tmpf;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddSigmoidOp;

#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \
using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \
using FuseAddSigmoidOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiAdd##_func_suffix(val1, val3); \
val2 = GiAdd##_func_suffix(val2, val4); \
val1 = GiSigmoidPs##_func_suffix(val1); \
val2 = GiSigmoidPs##_func_suffix(val2); \
return {{val1, val2}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 77
- 0
dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h View File

@@ -0,0 +1,77 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddTanhOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float tmpf = exp(src0 + (src1));
float tmpf2 = 1 / tmpf;
return (tmpf - tmpf2) / (tmpf + tmpf2);
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseAddTanhOp;

#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \
using FuseAddTanhOpBase::FuseAddTanhOpBase; \
using FuseAddTanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiAdd##_func_suffix(val1, val3); \
val2 = GiAdd##_func_suffix(val2, val4); \
auto exp1 = GiExpPs##_func_suffix(val1); \
auto exp2 = GiExpPs##_func_suffix(val2); \
auto rexp1 = GiRecpe##_func_suffix(exp1); \
auto rexp2 = GiRecpe##_func_suffix(exp2); \
rexp1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \
rexp2 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \
val1 = GiSubtract##_func_suffix(exp1, rexp1); \
val2 = GiSubtract##_func_suffix(exp2, rexp2); \
exp1 = GiAdd##_func_suffix(exp1, rexp1); \
exp2 = GiAdd##_func_suffix(exp2, rexp2); \
rexp1 = GiRecpe##_func_suffix(exp1); \
rexp2 = GiRecpe##_func_suffix(exp2); \
rexp1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp1, rexp1), rexp1); \
rexp2 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(exp2, rexp2), rexp2); \
val1 = GiMultiply##_func_suffix(val1, rexp1); \
val2 = GiMultiply##_func_suffix(val2, rexp2); \
return {{val1, val2}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 60
- 0
dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h View File

@@ -0,0 +1,60 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseMulAdd3OpBase : TernaryOpBase<src_ctype, dst_ctype> {
using TernaryOpBase<src_ctype, dst_ctype>::TernaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, const src_ctype src2,
dst_ctype* dst) const {
*dst = operator()(src0, src1, src2);
}

dst_ctype operator()(
const src_ctype& src0, const src_ctype& src1, const src_ctype& src2) const {
return (src0 * src1) + src2;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct FuseMulAdd3Op;

#define OP(_ctype, _simd_type, _func_suffix, _simd_width) \
template <> \
struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \
using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \
using FuseMulAdd3OpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
const _simd_type& src2, dst_ctype* dst) const { \
auto vitem = operator()(src0, src1, src2); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type operator()( \
const _simd_type& src0, const _simd_type& src1, \
const _simd_type& src2) const { \
auto vitem0 = GiMultiplyAdd##_func_suffix( \
src2.val[0], src0.val[0], src1.val[0]); \
auto vitem1 = GiMultiplyAdd##_func_suffix( \
src2.val[1], src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
};
OP(dt_float32, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 31
- 0
dnn/src/fallback/elemwise_helper/kimpl/gi_util_impl_helper.h View File

@@ -0,0 +1,31 @@
/**
* \file dnn/src/fallback/elemwise/gi_impl/gi_util_impl_helper.h
*/

#pragma once

/*!
* \brief compute fuse_add_relu on two simd packs
*
* Compute
*
* val1 = fuse_add_relu(val1, val3)
* val2 = fuse_add_relu(val2, val4)
*
* This algorithm handles int overflow.
*/
#define FUSE_ADD_RELU_SIMD_PACK2_FALLBACK(val1, val2, val3, val4, func_suffix) \
do { \
val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val3)); \
val2 = GiMaximum##func_suffix(val2, GiNeg##func_suffix(val4)); \
val1 = GiAdd##func_suffix(val1, val3); \
val2 = GiAdd##func_suffix(val2, val4); \
} while (0)

#define FUSE_ADD_RELU_SIMD_PACK_FALLBACK(val1, val2, func_suffix) \
do { \
val1 = GiMaximum##func_suffix(val1, GiNeg##func_suffix(val2)); \
val1 = GiAdd##func_suffix(val1, val2); \
} while (0)

// vim: syntax=cpp.doxygen

+ 108
- 0
dnn/src/fallback/elemwise_helper/kimpl/hswish.h View File

@@ -0,0 +1,108 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/hswish.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h"
#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct HSwishOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmp = src;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
return (tmp);
}
};

//! h_swish(x) = x * clip(x + 3, 0, 6) / 6
template <typename src_ctype, typename dst_ctype = src_ctype>
struct HSwishOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \
using HSwishOpBase::HSwishOpBase; \
using HSwishOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
auto val1 = src.val[0]; \
auto val2 = src.val[1]; \
H_SWISH_KERN_FALLBACK(_func_suffix, val1, val2); \
return {{val1, val2}}; \
} \
_simd_type operator()(const _simd_type& src) const { \
auto val_zero = GiBroadcast##_func_suffix(0.f); \
auto val_six = GiBroadcast##_func_suffix(6.f); \
auto val_three = GiBroadcast##_func_suffix(3.f); \
auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
auto clip1 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(src, val_three), val_six), \
val_zero); \
return GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(src, clip1), val_rec_six); \
} \
};

OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

template <>
struct HSwishOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint32& src, dt_qint8* dst) const {
*dst = operator()(src);
}

dt_qint8 operator()(const dt_qint32& src) const {
float tmp = src.as_int32() * this->scale_src;
tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f;
tmp *= this->scale_dst;
return QConverter::convert<dt_qint8, float>(tmp);
}
};

template <>
struct HSwishOp<dt_qint32, dt_qint8> : HSwishOpBase<dt_qint32, dt_qint8> {
using HSwishOpBase::HSwishOpBase;
using HSwishOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;

void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src);

H_SWISH_KERN_FALLBACK(Float32, vitem0, vitem1);
vitem0 = GiMultiplyFloat32(vitem0, this->vscale_dst);
vitem1 = GiMultiplyFloat32(vitem1, this->vscale_dst);

return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

#include "src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h"
// vim: syntax=cpp.doxygen

+ 7
- 0
dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h View File

@@ -0,0 +1,7 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_epilogue.h
*/

#undef H_SWISH_KERN_FALLBACK

// vim: syntax=cpp.doxygen

+ 39
- 0
dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h View File

@@ -0,0 +1,39 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/kern_macro_prologue.h
*/

#define H_SWISH_KERN_FALLBACK(_func_suffix, _val1, _val2) \
do { \
auto val_zero = GiBroadcast##_func_suffix(0.f); \
auto val_six = GiBroadcast##_func_suffix(6.f); \
auto val_three = GiBroadcast##_func_suffix(3.f); \
auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
auto clip1 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(_val1, val_three), val_six), \
val_zero); \
auto clip2 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(_val2, val_three), val_six), \
val_zero); \
_val1 = GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \
_val2 = GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(_val2, clip2), val_rec_six); \
} while (0);

#define H_SWISH_KERN_N1_FALLBACK(_func_suffix, _val1) \
do { \
auto val_zero = GiBroadcast##_func_suffix(0.f); \
auto val_six = GiBroadcast##_func_suffix(6.f); \
auto val_three = GiBroadcast##_func_suffix(3.f); \
auto val_rec_six = GiBroadcast##_func_suffix(1.f / 6.f); \
auto clip1 = GiMaximum##_func_suffix( \
GiMinimum##_func_suffix( \
GiAdd##_func_suffix(_val1, val_three), val_six), \
val_zero); \
_val1 = GiMultiply##_func_suffix( \
GiMultiply##_func_suffix(_val1, clip1), val_rec_six); \
} while (0);

// vim: syntax=cpp.doxygen

+ 102
- 0
dnn/src/fallback/elemwise_helper/kimpl/max.h View File

@@ -0,0 +1,102 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/max.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {
template <typename src_ctype, typename dst_ctype = src_ctype>
struct MaxOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 > src1 ? src0 : src1;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct MaxOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct MaxOp<_ctype> : MaxOpBase<_ctype> { \
using MaxOpBase::MaxOpBase; \
using MaxOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiMaximum##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiMaximum##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiMaximum##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <>
struct MaxOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using src_ctype = dt_qint8;
using dst_ctype = dt_qint8;
using BinaryOpBase::BinaryOpBase;

void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}

dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
float fsrc0 = src0.as_int8() * this->scale0;
float fsrc1 = src1.as_int8() * this->scale1;
return QConverter::convert<dst_ctype, float>(fsrc0 > fsrc1 ? fsrc0 : fsrc1);
}
};

template <>
struct MaxOp<dt_qint8, dt_qint8> : MaxOpBase<dt_qint8, dt_qint8> {
using MaxOpBase::MaxOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using MaxOpBase::operator();

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiMaximumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiMaximumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 99
- 0
dnn/src/fallback/elemwise_helper/kimpl/min.h View File

@@ -0,0 +1,99 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/min.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct MinOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 < src1 ? src0 : src1;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct MinOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct MinOp<_ctype> : MinOpBase<_ctype> { \
using MinOpBase::MinOpBase; \
using MinOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiMinimum##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiMinimum##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiMinimum##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <>
struct MinOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
float fsrc0 = src0.as_int8() * this->scale0;
float fsrc1 = src1.as_int8() * this->scale1;
return QConverter::convert<dt_qint8, float>(fsrc0 < fsrc1 ? fsrc0 : fsrc1);
}
};

template <>
struct MinOp<dt_qint8, dt_qint8> : MinOpBase<dt_qint8, dt_qint8> {
using MinOpBase::MinOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using MinOpBase::operator();

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiMinimumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiMinimumFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 99
- 0
dnn/src/fallback/elemwise_helper/kimpl/mul.h View File

@@ -0,0 +1,99 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/mul.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct MulOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 * src1;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct MulOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct MulOp<_ctype> : MulOpBase<_ctype> { \
using MulOpBase::MulOpBase; \
using MulOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiMultiply##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiMultiply##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiMultiply##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <>
struct MulOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;
void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int8() * scale_src0 * src1.as_int8() * scale1);
}
};

template <>
struct MulOp<dt_qint8, dt_qint8> : MulOpBase<dt_qint8, dt_qint8> {
using MulOpBase::MulOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using MulOpBase::operator();

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiMultiplyFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiMultiplyFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));

return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 77
- 0
dnn/src/fallback/elemwise_helper/kimpl/none.h View File

@@ -0,0 +1,77 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/none.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct NoneOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
dst_ctype operator()(const src_ctype& src) const { return src; }
};

template <typename src_ctype, typename dst_type = src_ctype>
struct NoneOp;
#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct NoneOp<_ctype> : NoneOpBase<_ctype> { \
NoneOp(){}; \
NoneOp(float, float){}; \
using NoneOpBase::NoneOpBase; \
using NoneOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
_simd_type2 operator()(const _simd_type2& src) const { return src; } \
void operator()(const _simd_type2& src, _ctype* dst) const { \
GiStore##_func_suffix(dst, src.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
GiStore##_func_suffix(dst, src); \
} \
_simd_type operator()(const _simd_type& src) const { return src; } \
};

OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <>
struct NoneOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint8& src, dt_qint8* dst) const { *dst = src; }
};

template <>
struct NoneOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint32& src, dt_qint8* dst) const {
*(reinterpret_cast<dt_qint32*>(dst)) = src;
}
};

#pragma GCC diagnostic ignored "-Waddress-of-packed-member"

template <>
struct NoneOp<dt_qint32, dt_qint8> : NoneOpBase<dt_qint32, dt_qint8> {
using NoneOpBase::NoneOpBase;
using NoneOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), vsrc.val[0]);
GiStoreInt32(reinterpret_cast<int32_t*>(dst + 16), vsrc.val[1]);
}
void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
GiStoreInt32(reinterpret_cast<int32_t*>(dst), src);
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 450
- 0
dnn/src/fallback/elemwise_helper/kimpl/op_base.h View File

@@ -0,0 +1,450 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/op_base.h
*/
#pragma once

#include <cmath>
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include "src/fallback/elemwise/gi_impl/gi_mathfun.h"
#include "src/fallback/quantized_converter.h"

#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/general_intrinsic/gi_int.h"

namespace megdnn {
namespace fallback {

////////////////////////// unary //////////////////////////
template <typename _src_ctype, typename _dst_ctype = _src_ctype>
struct OpBase {
using src_ctype = _src_ctype;
using dst_ctype = _dst_ctype;
OpBase() = default;
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct UnaryOpBase : OpBase<src_ctype, dst_ctype> {
using OpBase<src_ctype, dst_ctype>::OpBase;
UnaryOpBase() = default;
UnaryOpBase(DType /*src_dtype*/, DType /*dst_dtype*/) {}
};

#define OPERATOR_UNARY_QINT8_FALLBACK \
GI_INT16_t vsrct0 = GiMoveLowLongInt8(vsrc.val[0]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst), operator()( \
{{GiMoveLowLongInt16(vsrct0), \
GiMoveHighLongInt16(vsrct0)}})); \
GI_INT16_t vsrct1 = GiMoveHighLongInt8(vsrc.val[0]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 8), \
operator()({{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}})); \
GI_INT16_t vsrct2 = GiMoveHighLongInt8(vsrc.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 16), \
operator()({{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
GI_INT16_t vsrct3 = GiMoveHighLongInt8(vsrc.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 24), \
operator()({{GiMoveLowLongInt16(vsrct3), GiMoveHighLongInt16(vsrct3)}}))

//! scale_src = src.scale; scale_dst = 1.f / dst.scale (div -> mul)
//! scale = src.scale / dst.scale
template <>
struct UnaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> {
using OpBase::OpBase;
float scale_src, scale_dst;
GI_FLOAT32_t vscale_src, vscale_dst;
float scale;
GI_FLOAT32_t vscale;

void init(float src_scale, float dst_scale) {
scale_src = src_scale;
vscale_src = GiBroadcastFloat32(scale_src);
scale_dst = 1.f / dst_scale;
vscale_dst = GiBroadcastFloat32(scale_dst);
scale = src_scale / dst_scale;
vscale = GiBroadcastFloat32(scale);
}

UnaryOpBase(DType src_dtype, DType dst_dtype) {
float src_scale = src_dtype.param<dtype::QuantizedS8>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
init(src_scale, dst_scale);
}
UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); }
};

template <>
struct UnaryOpBase<dt_qint32, dt_qint8> : OpBase<dt_qint32, dt_qint8> {
using OpBase::OpBase;
using src_ctype = dt_qint32;
using dst_ctype = dt_qint8;
float scale;
GI_FLOAT32_t vscale;
float scale_src, scale_dst;
GI_FLOAT32_t vscale_src, vscale_dst;

void init(float src_scale, float dst_scale) {
scale_src = src_scale;
vscale_src = GiBroadcastFloat32(src_scale);
scale_dst = 1 / dst_scale;
vscale_dst = GiBroadcastFloat32(scale_dst);
scale = src_scale / dst_scale;
vscale = GiBroadcastFloat32(scale);
}

UnaryOpBase(DType src_dtype, DType dst_dtype) {
float src_scale = src_dtype.param<dtype::QuantizedS32>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
init(src_scale, dst_scale);
}

UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); }
};

////////////////////////// binary //////////////////////////
template <typename src_ctype, typename dst_ctype = src_ctype>
struct BinaryOpBase : OpBase<src_ctype, dst_ctype> {
using OpBase<src_ctype, dst_ctype>::OpBase;
BinaryOpBase() = default;
BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*dst_dtype*/) {}
};

/* ================= binary op for quantized types ================== */

#define OPERATOR_BINARY_QINT8_FALLBACK \
GI_INT16_t vsrct0_0 = GiMoveLowLongInt8(vsrc0.val[0]); \
GI_INT16_t vsrct1_0 = GiMoveLowLongInt8(vsrc1.val[0]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst), \
operator()( \
{{GiMoveLowLongInt16(vsrct0_0), GiMoveHighLongInt16(vsrct0_0)}}, \
{{GiMoveLowLongInt16(vsrct1_0), GiMoveHighLongInt16(vsrct1_0)}})); \
GI_INT16_t vsrct0_1 = GiMoveHighLongInt8(vsrc0.val[0]); \
GI_INT16_t vsrct1_1 = GiMoveHighLongInt8(vsrc1.val[0]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 8), \
operator()( \
{{GiMoveLowLongInt16(vsrct0_1), GiMoveHighLongInt16(vsrct0_1)}}, \
{{GiMoveLowLongInt16(vsrct1_1), GiMoveHighLongInt16(vsrct1_1)}})); \
GI_INT16_t vsrct0_2 = GiMoveLowLongInt8(vsrc0.val[1]); \
GI_INT16_t vsrct1_2 = GiMoveLowLongInt8(vsrc1.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 16), \
operator()( \
{{GiMoveLowLongInt16(vsrct0_2), GiMoveHighLongInt16(vsrct0_2)}}, \
{{GiMoveLowLongInt16(vsrct1_2), GiMoveHighLongInt16(vsrct1_2)}})); \
GI_INT16_t vsrct0_3 = GiMoveHighLongInt8(vsrc0.val[1]); \
GI_INT16_t vsrct1_3 = GiMoveHighLongInt8(vsrc1.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 24), \
operator()( \
{{GiMoveLowLongInt16(vsrct0_3), GiMoveHighLongInt16(vsrct0_3)}}, \
{{GiMoveLowLongInt16(vsrct1_3), GiMoveHighLongInt16(vsrct1_3)}}))

//! scale_src0 = src0.scale; scale_src1 = src1.scale; scale_dst = 1.f /
//! dst.scale scale0 = src0.scale / dst.scale; scale1 = src1.scale / dst.scale
template <>
struct BinaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> {
using OpBase::OpBase;
using src_ctype = dt_qint8;
using dst_ctype = dt_qint8;
float scale_src0, scale_src1, scale_dst;
GI_FLOAT32_t vscale_src0, vscale_src1, vscale_dst;
float scale0, scale1;
GI_FLOAT32_t vscale0, vscale1;

void init(float src0_scale, float src1_scale, float dst_scale) {
scale_src0 = src0_scale;
vscale_src0 = GiBroadcastFloat32(scale_src0);
scale_src1 = src1_scale;
vscale_src1 = GiBroadcastFloat32(scale_src1);
scale_dst = 1.f / dst_scale;
vscale_dst = GiBroadcastFloat32(scale_dst);
scale0 = src0_scale / dst_scale;
vscale0 = GiBroadcastFloat32(scale0);
scale1 = src1_scale / dst_scale;
vscale1 = GiBroadcastFloat32(scale1);
}

BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) {
float src0_scale = src0_dtype.param<dtype::QuantizedS8>().scale;
float src1_scale = src1_dtype.param<dtype::QuantizedS8>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
init(src0_scale, src1_scale, dst_scale);
}

BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) {
init(src0_scale, src1_scale, dst_scale);
}
};

template <>
struct BinaryOpBase<dt_qint32, dt_qint8> : OpBase<dt_qint32, dt_qint8> {
using OpBase::OpBase;
using src_ctype = dt_qint32;
using dst_ctype = dt_qint8;
float scale0, scale1;
GI_FLOAT32_t vscale0, vscale1;
float scale_src0, scale_src1, scale_dst;
GI_FLOAT32_t vscale_src0, vscale_src1, vscale_dst;

void init(float src0_scale, float src1_scale, float dst_scale) {
scale_src0 = src0_scale;
vscale_src0 = GiBroadcastFloat32(src0_scale);
scale_src1 = src1_scale;
vscale_src1 = GiBroadcastFloat32(src1_scale);
scale_dst = 1 / dst_scale;
vscale_dst = GiBroadcastFloat32(scale_dst);
scale0 = src0_scale / dst_scale;
vscale0 = GiBroadcastFloat32(scale0);
scale1 = src1_scale / dst_scale;
vscale1 = GiBroadcastFloat32(scale1);
}

BinaryOpBase(DType src0_dtype, DType src1_dtype, DType dst_dtype) {
float src0_scale = src0_dtype.param<dtype::QuantizedS32>().scale;
float src1_scale = src1_dtype.param<dtype::QuantizedS32>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
init(src0_scale, src1_scale, dst_scale);
}

BinaryOpBase(float src0_scale, float src1_scale, float dst_scale) {
init(src0_scale, src1_scale, dst_scale);
}
};

////////////////////////// ternary //////////////////////////
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TernaryOpBase : OpBase<src_ctype, dst_ctype> {
using OpBase<src_ctype, dst_ctype>::OpBase;
TernaryOpBase() = default;
TernaryOpBase(
DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*src2_dtype*/,
DType /*dst_dtype*/) {}
};

#define OPERATOR_TERNARY_QINT8_FALLBACK \
GI_INT16_t vsrct0 = GiMoveLowLongInt8(vsrc0.val[0]); \
GI_INT16_t vsrct1 = GiMoveLowLongInt8(vsrc1.val[0]); \
GI_INT16_t vsrct2 = GiMoveLowLongInt8(vsrc2.val[0]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst), \
operator()( \
{{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \
{{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \
{{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
vsrct0 = GiMoveHighLongInt8(vsrc0.val[0]); \
vsrct1 = GiMoveHighLongInt8(vsrc1.val[0]); \
vsrct2 = GiMoveHighLongInt8(vsrc2.val[0]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 8), \
operator()( \
{{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \
{{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \
{{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
vsrct0 = GiMoveLowLongInt8(vsrc0.val[1]); \
vsrct1 = GiMoveLowLongInt8(vsrc1.val[1]); \
vsrct2 = GiMoveLowLongInt8(vsrc2.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 16), \
operator()( \
{{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \
{{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \
{{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}})); \
vsrct0 = GiMoveHighLongInt8(vsrc0.val[1]); \
vsrct1 = GiMoveHighLongInt8(vsrc1.val[1]); \
vsrct2 = GiMoveHighLongInt8(vsrc2.val[1]); \
GiStoreLowInt8( \
reinterpret_cast<int8_t*>(dst + 24), \
operator()( \
{{GiMoveLowLongInt16(vsrct0), GiMoveHighLongInt16(vsrct0)}}, \
{{GiMoveLowLongInt16(vsrct1), GiMoveHighLongInt16(vsrct1)}}, \
{{GiMoveLowLongInt16(vsrct2), GiMoveHighLongInt16(vsrct2)}}))

/*========================= ternaty op for quanzited ====================*/
template <>
struct TernaryOpBase<dt_qint8, dt_qint8> : OpBase<dt_qint8, dt_qint8> {
using OpBase::OpBase;
using src_ctype = dt_qint8;
using dst_ctype = dt_qint8;
float scale_src0, scale_src1, scale_src2, scale_dst;
GI_FLOAT32_t vscale_src0, vscale_src1, vscale_src2, vscale_dst;
float scale0, scale1, scale2;
GI_FLOAT32_t vscale0, vscale1, vscale2;
void init(float src0_scale, float src1_scale, float src2_scale, float dst_scale) {
scale_src0 = src0_scale;
scale_src1 = src1_scale;
scale_src2 = src2_scale;
scale_dst = 1.f / dst_scale;
vscale_src0 = GiBroadcastFloat32(scale_src0);
vscale_src1 = GiBroadcastFloat32(scale_src1);
vscale_src2 = GiBroadcastFloat32(scale_src2);
vscale_dst = GiBroadcastFloat32(scale_dst);
scale0 = src0_scale / dst_scale;
scale1 = src1_scale / dst_scale;
scale2 = src2_scale / dst_scale;
vscale0 = GiBroadcastFloat32(scale0);
vscale1 = GiBroadcastFloat32(scale1);
vscale2 = GiBroadcastFloat32(scale2);
}
TernaryOpBase(
DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype) {
float src0_scale = src0_dtype.param<dtype::QuantizedS8>().scale;
float src1_scale = src1_dtype.param<dtype::QuantizedS8>().scale;
float src2_scale = src2_dtype.param<dtype::QuantizedS8>().scale;
float dst_scale = dst_dtype.param<dtype::QuantizedS8>().scale;
init(src0_scale, src1_scale, src2_scale, dst_scale);
}
TernaryOpBase(
float src0_scale, float src1_scale, float src2_scale, float dst_scale) {
init(src0_scale, src1_scale, src2_scale, dst_scale);
}
};

////////////////////////// fixup //////////////////////////
struct FixupBase {
GI_INT32_t vmultiplier, vshift;
FixupBase(float scale) {
//! ignore Fixup if scale >= 0.5, using typecvt instead of shift &
//! multiplier, as it may introduce errors.
if (scale >= 0.5)
return;

int shift = static_cast<int>(::ceilf(::log2f(0.5 / scale)));
scale *= ::powf(2, shift);
//! Using double can get full precision here, but it can be ignored.
vmultiplier = GiBroadcastInt32(
std::round(static_cast<double>(scale) * ((2LL) << 30)));
vshift = GiBroadcastInt32(-shift);
}
};

//////////////////////// quantization common ////////////////////
template <typename src_type, typename dst_type, typename Op>
struct UnaryQuantizationOp;

template <typename Op>
struct UnaryQuantizationOp<dt_qint8, dt_qint8, Op> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase<dt_qint8, dt_qint8>::UnaryOpBase;
constexpr static size_t SIMD_WIDTH = 16;
Op op;

void operator()(const dt_qint8& src, dt_qint8* dst) const {
*dst = operator()(src);
}

dt_qint8 operator()(const dt_qint8& src) const {
float fsrc = src.as_int8() * this->scale_src;
fsrc = op(fsrc);
fsrc = fsrc * this->scale_dst;
return QConverter::convert<dt_qint8, float>(fsrc);
}

void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale_src);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale_src);
auto val = this->op({{vitem0, vitem1}});
val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst);
val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V4_t>(val);
}
};

template <typename src_type, typename dst_type, typename Op>
struct BinaryQuantizationOp;

template <typename Op>
struct BinaryQuantizationOp<dt_qint8, dt_qint8, Op> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase<dt_qint8, dt_qint8>::BinaryOpBase;
constexpr static size_t SIMD_WIDTH = 16;
Op op;

void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}

dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
float fsrc0 = src0.as_int8() * this->scale_src0;
float fsrc1 = src1.as_int8() * this->scale_src1;
float fdst = op(fsrc0, fsrc1);
fdst = fdst * this->scale_dst;
return QConverter::convert<dt_qint8, float>(fdst);
}

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto val0 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0);
auto val1 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0);
auto val2 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1);
auto val3 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1);
auto val = op({{val0, val1}}, {{val2, val3}});
val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst);
val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val);
}
};

template <typename src_type, typename dst_type, typename Op>
struct TernaryQuantizationOp;

template <typename Op>
struct TernaryQuantizationOp<dt_qint8, dt_qint8, Op>
: TernaryOpBase<dt_qint8, dt_qint8> {
using TernaryOpBase<dt_qint8, dt_qint8>::TernaryOpBase;
constexpr static size_t SIMD_WIDTH = 16;
Op op;

void operator()(
const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2,
dt_qint8* dst) const {
*dst = operator()(src0, src1, src2);
}

dt_qint8 operator()(
const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2) const {
float fsrc0 = src0.as_int8() * this->scale_src0;
float fsrc1 = src1.as_int8() * this->scale_src1;
float fsrc2 = src2.as_int8() * this->scale_src2;
float fdst = op(fsrc0, fsrc1, fsrc2);
fdst = fdst * this->scale_dst;
return QConverter::convert<dt_qint8, float>(fdst);
}

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1,
const GI_INT8_V2_t& vsrc2, dt_qint8* dst) const {
OPERATOR_TERNARY_QINT8_FALLBACK;
}

GI_INT8_t operator()(
const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1,
const GI_INT32_V2_t& vsrc2) const {
auto val0 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale_src0);
auto val1 = GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale_src0);
auto val2 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale_src1);
auto val3 = GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale_src1);
auto val4 = GiMultiplyFloat32(GiCastToFloat32(vsrc2.val[0]), this->vscale_src2);
auto val5 = GiMultiplyFloat32(GiCastToFloat32(vsrc2.val[1]), this->vscale_src2);
auto val = op({{val0, val1}}, {{val2, val3}}, {{val4, val5}});
val.val[0] = GiMultiplyFloat32(val.val[0], this->vscale_dst);
val.val[1] = GiMultiplyFloat32(val.val[1], this->vscale_dst);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>(val);
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 28
- 0
dnn/src/fallback/elemwise_helper/kimpl/pow.h View File

@@ -0,0 +1,28 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/pow.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

/////////////////////// POW float only ////////////////////////////
template <typename src_ctype, typename dst_ctype = src_ctype>
struct PowOp : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
constexpr static size_t SIMD_WIDTH = 1;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return powf(src0, src1);
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 188
- 0
dnn/src/fallback/elemwise_helper/kimpl/relu.h View File

@@ -0,0 +1,188 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/relu.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct ReluOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : 0; }
};

template <typename src_ctype, typename dst_type = src_ctype>
struct ReluOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct ReluOp<_ctype> : ReluOpBase<_ctype> { \
using ReluOpBase::ReluOpBase; \
using ReluOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \
auto vitem0 = GiMaximum##_func_suffix(src.val[0], vzero); \
auto vitem1 = GiMaximum##_func_suffix(src.val[1], vzero); \
return {{vitem0, vitem1}}; \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src) const { \
auto vzero = GiBroadcast##_func_suffix(0); \
return GiMaximum##_func_suffix(src, vzero); \
} \
};

OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <>
struct ReluOpBase<dt_qint8, dt_qint8> : UnaryOpBase<dt_qint8, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint8& src, dt_qint8* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint8& src) const {
float fsrc = src.as_int8() * this->scale;
fsrc = std::max<float>(fsrc, 0.f);
return QConverter::convert<dt_qint8, float>(fsrc);
}
};

template <>
struct ReluOp<dt_qint8, dt_qint8> : ReluOpBase<dt_qint8, dt_qint8> {
using ReluOpBase::ReluOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using ReluOpBase::operator();

void operator()(const GI_INT8_V2_t& vsrc, dt_qint8* dst) const {
OPERATOR_UNARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vzero = GiBroadcastFloat32(0.f);
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, vzero);
vitem1 = GiMaximumFloat32(vitem1, vzero);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

template <>
struct ReluOpBase<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
void operator()(const dt_qint32& src, dt_qint8* dst) const {
*dst = operator()(src);
}

dt_qint8 operator()(const dt_qint32& src) const {
float fsrc = src.as_int32() * this->scale;
fsrc = std::max<float>(fsrc, 0.f);
return QConverter::convert<dt_qint8, float>(fsrc);
}
};

//! if old armv7, special define relu with fixup
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8>, FixupBase {
using ReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = 4;

ReluOp(DType src_dtype, DType dst_dtype)
: ReluOpBase(src_dtype, dst_dtype), FixupBase(scale) {}

ReluOp(float src_scale, float dst_scale)
: ReluOpBase(src_scale, dst_scale), FixupBase(scale) {}

void operator()(const int32x4x2_t& vsrc, dt_qint8* dst) const {
vst1_s8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}

int8x8_t operator()(const int32x4x2_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vsrc.val[0], vmultiplier);
int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero());
return vqmovn_s16(vcombine_s16(
vqmovn_s32(vrshlq_s32(vitem0, vshift)),
vqmovn_s32(vrshlq_s32(vitem1, vshift))));
}
int8x8_t operator()(const float32x4_t& vsrc) const {
int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier);
vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero());
vitem0 = vrshlq_s32(vitem0, vshift);
int16x4_t vitem = vqmovn_s32(vitem0);
return vqmovn_s16(vcombine_s16(vitem, vitem));
}
void operator()(const int32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(vcvtq_f32_s32(src), this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
}
void operator()(const float32x4_t& src, dt_qint8* dst) const {
auto vitem0 = vmulq_f32(src, this->vscale);
vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero());
auto result = QConverter::convert<int8x8_t, float32x4_t>(vitem0);
vst1_lane_s32(reinterpret_cast<int32_t*>(dst), (int32x2_t)result, 0);
}
};

#else
template <>
struct ReluOp<dt_qint32, dt_qint8> : ReluOpBase<dt_qint32, dt_qint8> {
using ReluOpBase::ReluOpBase;
using ReluOpBase::operator();
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(int32_t);

void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
void operator()(const GI_INT32_t& src, dt_qint8* dst) const {
GiStoreLane0Int32(
reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(src)));
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
vitem1 = GiMaximumFloat32(vitem1, QConverterBase::vfzero());

return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
GI_INT8_t operator()(const GI_INT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(src, this->vscale);
vitem0 = GiMaximumFloat32(vitem0, QConverterBase::vfzero());
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
};

#endif

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 56
- 0
dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h View File

@@ -0,0 +1,56 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/sigmoid.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct SigmoidOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmpf = src;
tmpf = exp(-tmpf);
tmpf = 1.f / (1.f + tmpf);
return tmpf;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct SigmoidOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \
using SigmoidOpBase::SigmoidOpBase; \
using SigmoidOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
void operator()(const _simd_type& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
return {{operator()(src.val[0]), operator()(src.val[1])}}; \
} \
_simd_type operator()(const _simd_type& src) const { \
return GiSigmoidPs##_func_suffix(src); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 97
- 0
dnn/src/fallback/elemwise_helper/kimpl/sub.h View File

@@ -0,0 +1,97 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/sub.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct SubOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 - src1;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct SubOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct SubOp<_ctype> : SubOpBase<_ctype> { \
using SubOpBase::SubOpBase; \
using SubOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto vitem0 = GiSubtract##_func_suffix(src0.val[0], src1.val[0]); \
auto vitem1 = GiSubtract##_func_suffix(src0.val[1], src1.val[1]); \
return {{vitem0, vitem1}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiSubtract##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
OP(dt_int32, GI_INT32_t, GI_INT32_V2_t, Int32, GI_SIMD_LEN_BYTE / sizeof(int32_t))
OP(dt_int8, GI_INT8_t, GI_INT8_V2_t, Int8, GI_SIMD_LEN_BYTE / sizeof(int8_t))
#undef OP

template <>
struct SubOpBase<dt_qint8, dt_qint8> : BinaryOpBase<dt_qint8, dt_qint8> {
using BinaryOpBase::BinaryOpBase;

void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const {
*dst = operator()(src0, src1);
}
dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const {
return QConverter::convert<dt_qint8, float>(
src0.as_int8() * scale0 - src1.as_int8() * scale1);
}
};

template <>
struct SubOp<dt_qint8, dt_qint8> : SubOpBase<dt_qint8, dt_qint8> {
using SubOpBase::SubOpBase;
constexpr static size_t SIMD_WIDTH = 16;
using SubOpBase::operator();

void operator()(
const GI_INT8_V2_t& vsrc0, const GI_INT8_V2_t& vsrc1, dt_qint8* dst) const {
OPERATOR_BINARY_QINT8_FALLBACK;
}
GI_INT8_t operator()(const GI_INT32_V2_t& vsrc0, const GI_INT32_V2_t& vsrc1) const {
auto vitem0 = GiSubtractFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[0]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[0]), this->vscale1));
auto vitem1 = GiSubtractFloat32(
GiMultiplyFloat32(GiCastToFloat32(vsrc0.val[1]), this->vscale0),
GiMultiplyFloat32(GiCastToFloat32(vsrc1.val[1]), this->vscale1));
return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 81
- 0
dnn/src/fallback/elemwise_helper/kimpl/tanh.h View File

@@ -0,0 +1,81 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/tanh.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
using UnaryOpBase<src_ctype, dst_ctype>::UnaryOpBase;
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dst_ctype operator()(const src_ctype& src) const {
float tmp = src;
return tanh(tmp);
}
};

template <typename src_ctype, typename dst_type = src_ctype>
struct TanhOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct TanhOp<_ctype> : TanhOpBase<_ctype> { \
using TanhOpBase::TanhOpBase; \
using TanhOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()(const _simd_type2& src, _ctype* dst) const { \
auto vitem = operator()(src); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()(const _simd_type2& src) const { \
auto one_val = GiBroadcast##_func_suffix(1.f); \
auto two_val = GiBroadcast##_func_suffix(2.f); \
auto val1 = src.val[0]; \
auto val2 = src.val[1]; \
val1 = GiMultiply##_func_suffix(two_val, val1); \
val2 = GiMultiply##_func_suffix(two_val, val2); \
val1 = GiExpPs##_func_suffix(val1); \
val2 = GiExpPs##_func_suffix(val2); \
val1 = GiAdd##_func_suffix(one_val, val1); \
val2 = GiAdd##_func_suffix(one_val, val2); \
auto rval1 = GiRecpe##_func_suffix(val1); \
auto rval2 = GiRecpe##_func_suffix(val2); \
rval1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(val1, rval1), rval1); \
rval2 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(val2, rval2), rval2); \
val1 = GiMultiply##_func_suffix(two_val, rval1); \
val2 = GiMultiply##_func_suffix(two_val, rval2); \
val1 = GiSubtract##_func_suffix(one_val, val1); \
val2 = GiSubtract##_func_suffix(one_val, val2); \
return {{val1, val2}}; \
} \
_simd_type operator()(const _simd_type& src) const { \
auto one_val = GiBroadcast##_func_suffix(1.f); \
auto two_val = GiBroadcast##_func_suffix(2.f); \
auto val1 = src; \
val1 = GiMultiply##_func_suffix(two_val, val1); \
val1 = GiExpPs##_func_suffix(val1); \
val1 = GiAdd##_func_suffix(one_val, val1); \
auto rval1 = GiRecpe##_func_suffix(val1); \
rval1 = GiMultiply##_func_suffix( \
GiRecpeS##_func_suffix(val1, rval1), rval1); \
val1 = GiMultiply##_func_suffix(two_val, rval1); \
val1 = GiSubtract##_func_suffix(one_val, val1); \
return val1; \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 68
- 0
dnn/src/fallback/elemwise_helper/kimpl/true_div.h View File

@@ -0,0 +1,68 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/true_div.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

//! use a couple Newton-Raphson steps to refine the estimate.
//! A / B => 1. rB = vrecpeq_f32(B) 2. rB= vmulq_f32(vrecpsq_f32(B, rB), rB)
//! 3. A * rB
template <typename src_ctype, typename dst_ctype = src_ctype>
struct TrueDivOpBase : BinaryOpBase<src_ctype, dst_ctype> {
using BinaryOpBase<src_ctype, dst_ctype>::BinaryOpBase;
void operator()(
const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const {
*dst = operator()(src0, src1);
}
dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const {
return src0 / src1;
}
};

template <typename src_ctype, typename dst_ctype = src_ctype>
struct TrueDivOp;

#define OP(_ctype, _simd_type, _simd_type2, _func_suffix, _simd_width) \
template <> \
struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \
using TrueDivOpBase::TrueDivOpBase; \
using TrueDivOpBase::operator(); \
constexpr static size_t SIMD_WIDTH = _simd_width; \
void operator()( \
const _simd_type2& src0, const _simd_type2& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem.val[0]); \
GiStore##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \
_simd_type2 operator()( \
const _simd_type2& src0, const _simd_type2& src1) const { \
auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \
val1 = GiDivide##_func_suffix(val1, val3); \
val2 = GiDivide##_func_suffix(val2, val4); \
return {{val1, val2}}; \
} \
void operator()( \
const _simd_type& src0, const _simd_type& src1, \
dst_ctype* dst) const { \
auto vitem = operator()(src0, src1); \
GiStore##_func_suffix(dst, vitem); \
} \
_simd_type operator()(const _simd_type& src0, const _simd_type& src1) const { \
return GiDivide##_func_suffix(src0, src1); \
} \
};
OP(dt_float32, GI_FLOAT32_t, GI_FLOAT32_V2_t, Float32, GI_SIMD_LEN_BYTE / sizeof(float))
#undef OP

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 53
- 0
dnn/src/fallback/elemwise_helper/kimpl/typecvt.h View File

@@ -0,0 +1,53 @@
/**
* \file dnn/src/fallback/elemwise_helper/kimpl/typecvt.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/op_base.h"

namespace megdnn {
namespace fallback {

template <typename src_ctype, typename dst_ctype = src_ctype>
struct TypeCvtOp;

template <>
struct TypeCvtOp<dt_qint32, dt_qint8> : UnaryOpBase<dt_qint32, dt_qint8> {
using UnaryOpBase::UnaryOpBase;
constexpr static size_t SIMD_WIDTH = GI_SIMD_LEN_BYTE / sizeof(float);

void operator()(const GI_INT32_V2_t& vsrc, dt_qint8* dst) const {
GiStoreLowInt8(reinterpret_cast<int8_t*>(dst), operator()(vsrc));
}
void operator()(const GI_INT32_t& vsrc, dt_qint8* dst) const {
GiStoreLane0Int32(
reinterpret_cast<int32_t*>(dst), (GI_INT32_t)(operator()(vsrc)));
}
void operator()(const src_ctype& src, dst_ctype* dst) const {
*dst = operator()(src);
}
dt_qint8 operator()(const dt_qint32& src) const {
float fsrc = src.as_int32() * this->scale;
return QConverter::convert<dt_qint8, float>(fsrc);
}

GI_INT8_t operator()(const GI_INT32_V2_t& vsrc) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[0]), this->vscale);
auto vitem1 = GiMultiplyFloat32(GiCastToFloat32(vsrc.val[1]), this->vscale);

return QConverter::convert<GI_INT8_t, GI_FLOAT32_V2_t>({{vitem0, vitem1}});
}
GI_INT8_t operator()(const GI_INT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(GiCastToFloat32(src), this->vscale);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
GI_INT8_t operator()(const GI_FLOAT32_t& src) const {
auto vitem0 = GiMultiplyFloat32(src, this->vscale);
return QConverter::convert<GI_INT8_t, GI_FLOAT32_t>(vitem0);
}
};

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 39
- 0
dnn/src/fallback/elemwise_helper/op_binary.h View File

@@ -0,0 +1,39 @@
/**
* \file dnn/src/fallback/elemwise_helper/op_binary.h
*/

#pragma once

#include "src/fallback/elemwise_helper/kimpl/add.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_h_swish.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_relu.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_sigmoid.h"
#include "src/fallback/elemwise_helper/kimpl/fuse_add_tanh.h"
#include "src/fallback/elemwise_helper/kimpl/max.h"
#include "src/fallback/elemwise_helper/kimpl/min.h"
#include "src/fallback/elemwise_helper/kimpl/mul.h"
#include "src/fallback/elemwise_helper/kimpl/pow.h"
#include "src/fallback/elemwise_helper/kimpl/sub.h"
#include "src/fallback/elemwise_helper/kimpl/true_div.h"

//////////////////// quantization //////////////////////////////
namespace megdnn {
namespace fallback {
#define cb(op) \
template <> \
struct op<dt_qint8, dt_qint8> \
: BinaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \
using BinaryQuantizationOp< \
dt_qint8, dt_qint8, op<float, float>>::BinaryQuantizationOp; \
};

cb(TrueDivOp);
cb(FuseAddSigmoidOp);
cb(FuseAddTanhOp);
cb(FuseAddHSwishOp);

#undef cb
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 39
- 0
dnn/src/fallback/elemwise_helper/op_common.h View File

@@ -0,0 +1,39 @@
/**
* \file dnn/src/fallback/elemwise_helper/op_common.h
*/
#pragma once

namespace megdnn {
/*!
* \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i]
*/
enum BcastType {
VEC,
VEC_VEC,
VEC_BCAST101,
VEC_BCASTX0X,
VEC_BCAST111C,
VEC_BCAST101xX,
VEC_SCALAR,
SCALAR_VEC,
BCAST101_VEC,
BCASTX0X_VEC,
BCAST111C_VEC,
BCAST101xX_VEC,
VEC_VEC_VEC,
VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101,
BCAST111C_VEC_BCAST111C,
BCAST101xX_VEC_BCAST101xX,
VEC_BCAST101_VEC,
VEC_BCAST111C_VEC,
VEC_BCAST101xX_VEC,
VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR,
UNKNOWN_BCAST_TYPE
};

} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 24
- 0
dnn/src/fallback/elemwise_helper/op_ternary.h View File

@@ -0,0 +1,24 @@
/**
* \file dnn/src/fallback/elemwise_helper/op_ternary.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/fuse_mul_add3.h"

//////////////////// quantization //////////////////////////////
namespace megdnn {
namespace fallback {
#define cb(op) \
template <> \
struct op<dt_qint8, dt_qint8> \
: TernaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \
using TernaryQuantizationOp< \
dt_qint8, dt_qint8, op<float, float>>::TernaryQuantizationOp; \
};

cb(FuseMulAdd3Op);
#undef cb
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 36
- 0
dnn/src/fallback/elemwise_helper/op_unary.h View File

@@ -0,0 +1,36 @@
/**
* \file dnn/src/fallback/elemwise_helper/op_unary.h
*/
#pragma once

#include "src/fallback/elemwise_helper/kimpl/abs.h"
#include "src/fallback/elemwise_helper/kimpl/exp.h"
#include "src/fallback/elemwise_helper/kimpl/fast_tanh.h"
#include "src/fallback/elemwise_helper/kimpl/hswish.h"
#include "src/fallback/elemwise_helper/kimpl/none.h"
#include "src/fallback/elemwise_helper/kimpl/relu.h"
#include "src/fallback/elemwise_helper/kimpl/sigmoid.h"
#include "src/fallback/elemwise_helper/kimpl/tanh.h"
#include "src/fallback/elemwise_helper/kimpl/typecvt.h"

//////////////////// quantization //////////////////////////////
namespace megdnn {
namespace fallback {
#define cb(op) \
template <> \
struct op<dt_qint8, dt_qint8> \
: UnaryQuantizationOp<dt_qint8, dt_qint8, op<float, float>> { \
using UnaryQuantizationOp< \
dt_qint8, dt_qint8, op<float, float>>::UnaryQuantizationOp; \
};

cb(SigmoidOp);
cb(ExpOp);
cb(TanhOp);
cb(FastTanhOp);
cb(HSwishOp);
#undef cb
} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1432
- 0
dnn/src/fallback/elemwise_op.h
File diff suppressed because it is too large
View File


+ 1
- 1
dnn/src/fallback/general_intrinsic/gi_common.h View File

@@ -19,7 +19,7 @@
#include <windows.h>
#else
#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
#include "src/arm_common/simd_macro/marm_neon.h"
#endif
#if defined(__x86_64__) || defined(__i386__)
#include <cpuid.h>


+ 10
- 10
dnn/src/fallback/quantized_converter.h View File

@@ -21,13 +21,13 @@ namespace megdnn {
namespace fallback {

struct QConverterBase {
inline static GI_INT32 vzero() { return GiBroadcastInt32(0); }
inline static GI_INT32_t vzero() { return GiBroadcastInt32(0); }

inline static GI_FLOAT32 vfzero() { return GiBroadcastFloat32(0.f); }
inline static GI_FLOAT32_t vfzero() { return GiBroadcastFloat32(0.f); }

inline static GI_FLOAT32 vfhalf() { return GiBroadcastFloat32(0.5f); }
inline static GI_FLOAT32_t vfhalf() { return GiBroadcastFloat32(0.5f); }

inline static GI_FLOAT32 vfneg_half() { return GiBroadcastFloat32(-0.5f); }
inline static GI_FLOAT32_t vfneg_half() { return GiBroadcastFloat32(-0.5f); }
};

struct QConverter {
@@ -56,23 +56,23 @@ inline dt_qint32 QConverter::convert(const float& src) {
}

template <>
inline GI_FLOAT32_V2 QConverter::convert(const GI_INT16& vsrc) {
GI_INT32 vhi = GiMoveHighLongInt16(vsrc);
GI_INT32 vlo = GiMoveLowLongInt16(vsrc);
inline GI_FLOAT32_V2_t QConverter::convert(const GI_INT16_t& vsrc) {
GI_INT32_t vhi = GiMoveHighLongInt16(vsrc);
GI_INT32_t vlo = GiMoveLowLongInt16(vsrc);
return {{GiCastToFloat32(vlo), GiCastToFloat32(vhi)}};
}

template <>
inline GI_INT8 QConverter::convert(const GI_FLOAT32_V2& vsrc) {
inline GI_INT8_t QConverter::convert(const GI_FLOAT32_V2_t& vsrc) {
return GiCvtFromFloat32V2ToInt8(vsrc);
}
template <>
inline GI_INT8 QConverter::convert(const GI_FLOAT32& src) {
inline GI_INT8_t QConverter::convert(const GI_FLOAT32_t& src) {
return GiCvtFromFloat32ToInt8(src);
}

template <>
inline GI_INT32 QConverter::round(const GI_FLOAT32& vsrc) {
inline GI_INT32_t QConverter::round(const GI_FLOAT32_t& vsrc) {
return GiRoundAsInt32(vsrc);
}
} // namespace fallback


+ 303
- 0
dnn/test/fallback/elemwise.cpp View File

@@ -38,6 +38,309 @@ TEST_F(FALLBACK, ELEMWISE_RECORD) {
checker.set_rng(2, &rng);
checker.execs({{10, 10, 32}, {10, 10, 32}, {}});
}


TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
checker.set_param(Mode::FUSE_MUL_ADD3);

auto run = [&] {
//! nchw44
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});

//! nchw44
checker.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});

//! nchw88
checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});

//! nchw88
checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});

checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
checker.execs({{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}, {}});
checker.execs({{1, 7}, {1, 7}, {1, 7}, {}});
checker.execs({{1, 2, 1}, {1, 2, 2}, {1, 2, 1}, {}});
checker.execs({{1, 2, 2}, {1, 2, 2}, {1, 1, 1}, {}});
checker.execs({{3, 4, 1}, {3, 4, 1}, {3, 4, 1}, {}});
checker.execs({{3, 4, 5}, {1}, {1}, {}});
checker.execs({{1}, {3, 4, 5}, {1}, {}});
};

// case int
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
checker.set_dtype(2, dtype::Int8());
run();

checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Int16());
checker.set_dtype(2, dtype::Int16());
run();

checker.set_dtype(0, dtype::Int32());
checker.set_dtype(1, dtype::Int32());
checker.set_dtype(2, dtype::Int32());
run();

// case float
UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
run();
}

TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());

auto run = [&]() {
// VEC_BCAST101x not PowOp
checker.set_param(Mode::ADD).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::ADD).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::ADD).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::ADD).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
checker.set_param(Mode::RMULH).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::RMULH).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::RMULH).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.set_param(Mode::RMULH).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::RMULH).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
// BCAST101x_VEC not PowOp
checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.set_param(Mode::ADD).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.set_param(Mode::ADD).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.set_param(Mode::ADD).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::ADD).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
};
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
run();
checker.set_dtype(0, dtype::Int16());
checker.set_dtype(1, dtype::Int16());
run();
checker.set_dtype(0, dtype::Int32());
checker.set_dtype(1, dtype::Int32());
run();
}

TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW44_FP32) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());

UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());

checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});

auto run = [&](Mode mode) {
// VEC_BCAST101x
checker.set_param(mode).execs({{1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(mode).execs({{2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.set_param(mode).execs({{3, 8, 5, 3, 4}, {1, 8, 1, 1, 4}, {}});
checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(mode).execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {}});
// BCAST101x_VEC not powOp
checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {}});
checker.set_param(mode).execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {}});
checker.set_param(mode).execs({{1, 8, 1, 1, 4}, {3, 8, 5, 3, 4}, {}});
checker.set_param(mode).execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.set_param(mode).execs({{1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
};
run(Mode::ADD);
run(Mode::FUSE_ADD_H_SWISH);
run(Mode::FUSE_ADD_RELU);
run(Mode::MAX);
run(Mode::MIN);
run(Mode::MUL);
run(Mode::SUB);
run(Mode::TRUE_DIV);
run(Mode::POW);
}

TEST_F(FALLBACK, ELEMWISE_FORWARD_NCHW88_FP) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());

checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});

auto run = [&](Mode mode) {
// VEC_BCAST101x
checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
// BCAST101x_VEC not powOp
checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
};
auto run_all = [&]() {
run(Mode::ADD);
run(Mode::FUSE_ADD_H_SWISH);
run(Mode::FUSE_ADD_RELU);
run(Mode::MAX);
run(Mode::MIN);
run(Mode::MUL);
run(Mode::SUB);
run(Mode::TRUE_DIV);
run(Mode::POW);
};

{
UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
run_all();
}
}

TEST_F(FALLBACK, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());

UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());

//! 2 dim
auto run = [&](Mode mode) {
// VEC_BCASTX0X
checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
// BCASTX0X_VEC
checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
};
run(Mode::ADD);
run(Mode::MUL);
run(Mode::SUB);
}

TEST_F(FALLBACK, ELEMWISE_FORWARD_TERNARY_RECORD) {
using Mode = ElemwiseForward::Param::Mode;
TaskRecordChecker<ElemwiseForward> checker(0);
checker.set_param(Mode::FUSE_MUL_ADD3);

auto run = [&] {
//! nchw44
checker.execs({{1, 3, 1, 1, 4}, {1, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});
checker.execs({{1, 3, 1, 1, 4}, {2, 3, 2, 2, 4}, {1, 3, 1, 1, 4}, {}});

//! nchw88
checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});

checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
};

// case int
checker.set_dtype(0, dtype::Int32());
checker.set_dtype(1, dtype::Int32());
checker.set_dtype(2, dtype::Int32());
run();

// case float
UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
checker.set_dtype(2, dtype::Float32());
run();
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
auto naive_handle = create_cpu_handle(2);


+ 145
- 0
dnn/test/x86/elemwise_bmark.cpp View File

@@ -164,3 +164,148 @@ TEST_F(X86, BENCHMARK_ELEM_EVERY_DTYPE) {
// B.set_dtype(2, dtype::Int8());
// BENCHMARK_CASE_INT(1556011)
}

#if MEGDNN_WITH_BENCHMARK
namespace {
void run_elemwise_benchmark(
const TensorShapeArray& shapes, param::Elemwise::Mode mode,
const char* mode_str, DType type, Handle* handle_bench) {
auto handle_fallback = create_cpu_handle(1);
Benchmarker<Elemwise> benchmarker_bench(handle_bench);
Benchmarker<Elemwise> benchmarker_fallback(handle_fallback.get());

float throughput = 0;
SmallVector<TensorLayout> layouts;
std::string src_strs;
for (size_t i = 0; i < shapes.size(); i++) {
layouts.emplace_back(shapes[i], type);
throughput += layouts.back().span().dist_byte();
src_strs += layouts.back().to_string();
if (i != shapes.size() - 1) {
src_strs += ",";
}
}
constexpr size_t RUN = 50;
benchmarker_fallback.set_times(RUN).set_display(false);
benchmarker_bench.set_times(RUN).set_display(false);

benchmarker_fallback.set_param(mode);
benchmarker_bench.set_param(mode);

TensorLayout dst_layout;
auto opr = handle_bench->create_operator<Elemwise>();
opr->param() = mode;
opr->deduce_layout(layouts, dst_layout);

float computations =
dst_layout.total_nr_elems() * (std::max<size_t>(shapes.size(), 2) - 1);
throughput += dst_layout.span().dist_byte();
computations *= (1e3 / (1024.0 * 1024));
throughput *= (1e3 / (1024.0 * 1024));

layouts.emplace_back(dst_layout);
auto fallback_time = benchmarker_fallback.execl(layouts) / RUN;
auto bench_time = benchmarker_bench.execl(layouts) / RUN;

float fallback_flops = computations / fallback_time;
float bench_flops = computations / bench_time;
float fallback_thr = throughput / fallback_time;
float bench_thr = throughput / bench_time;

printf("%s = %s (type: %s, mode: %s) cpu=%fMFLOPS %fMB/s, bench=%fMFLOPS "
"%fMB/s "
"computations: %fx, throughput: %fx\n",
src_strs.c_str(), dst_layout.to_string().c_str(), type.name(), mode_str,
fallback_flops, fallback_thr, bench_flops, bench_thr,
bench_flops / fallback_flops, bench_thr / fallback_thr);
}
} // namespace

#define INT_RUN(shape, mode) \
run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle()); \
run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle()); \
run_elemwise_benchmark(shape, mode, #mode, dtype::Int32{}, handle());

#define FLOAT_RUN(shape, mode) \
run_elemwise_benchmark(shape, mode, #mode, dtype::Float32{}, handle()); \
run_elemwise_benchmark(shape, mode, #mode, dtype::Float16{}, handle());

#define BENCHMARK_CASES(shape) \
INT_BENCHMARK_CASES(shape) \
FLOAT_BENCHMARK_CASES(shape)

TEST_F(X86, BENCHMARK_UNARY) {
#define INT_BENCHMARK_CASES(shape) \
INT_RUN(shape, Mode::RELU); \
INT_RUN(shape, Mode::ABS);

#define FLOAT_BENCHMARK_CASES(shape) \
FLOAT_RUN(shape, Mode::RELU); \
FLOAT_RUN(shape, Mode::ABS); \
FLOAT_RUN(shape, Mode::SIGMOID); \
FLOAT_RUN(shape, Mode::EXP); \
FLOAT_RUN(shape, Mode::TANH); \
FLOAT_RUN(shape, Mode::FAST_TANH);

using Mode = param::Elemwise::Mode;
BENCHMARK_CASES({{10000}});
BENCHMARK_CASES({{50000}});

#undef INT_BENCHMARK_CASES
#undef FLOAT_BENCHMARK_CASES
}

TEST_F(X86, BENCHMARK_BINARY) {
#define INT_BENCHMARK_CASES(shape) \
INT_RUN(shape, Mode::MIN); \
INT_RUN(shape, Mode::MAX); \
INT_RUN(shape, Mode::ADD); \
INT_RUN(shape, Mode::SUB); \
INT_RUN(shape, Mode::MUL); \
INT_RUN(shape, Mode::RMULH); \
INT_RUN(shape, Mode::FUSE_ADD_RELU);

#define FLOAT_BENCHMARK_CASES(shape) \
FLOAT_RUN(shape, Mode::MIN); \
FLOAT_RUN(shape, Mode::MAX); \
FLOAT_RUN(shape, Mode::ADD); \
FLOAT_RUN(shape, Mode::SUB); \
FLOAT_RUN(shape, Mode::MUL); \
FLOAT_RUN(shape, Mode::POW); \
FLOAT_RUN(shape, Mode::TRUE_DIV); \
FLOAT_RUN(shape, Mode::FUSE_ADD_RELU);

using Mode = param::Elemwise::Mode;
TensorShapeArray shapes = {{1, 112, 28, 28}, {1, 112, 28, 28}};
BENCHMARK_CASES(shapes);
shapes = {{1, 16, 1, 1}, {1, 16, 112, 112}};
BENCHMARK_CASES(shapes);
shapes = {{1, 448, 7, 7}, {1, 448, 7, 7}};
BENCHMARK_CASES(shapes);

#undef INT_BENCHMARK_CASES
#undef FLOAT_BENCHMARK_CASES
}

TEST_F(X86, BENCHMARK_TERNARY_FMA3) {
#define INT_BENCHMARK_CASES(shape) INT_RUN(shape, Mode::FUSE_MUL_ADD3);

#define FLOAT_BENCHMARK_CASES(shape) FLOAT_RUN(shape, Mode::FUSE_MUL_ADD3);

using Mode = param::Elemwise::Mode;
TensorShapeArray shapes = {{30, 40, 70}, {30, 40, 70}, {30, 40, 70}};
BENCHMARK_CASES(shapes);
shapes = {{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}};
BENCHMARK_CASES(shapes);
shapes = {{3, 4, 5, 7}, {3, 4, 5, 7}, {1, 1, 1, 1}};
BENCHMARK_CASES(shapes);

#undef INT_BENCHMARK_CASES
#undef FLOAT_BENCHMARK_CASES
}

#undef BENCHMARK_CASES
#undef INT_RUN
#undef FLOAT_RUN

#endif

Loading…
Cancel
Save