Browse Source

feat(dnn/arm_common): add N1HW like elemwise broadcast mode

GitOrigin-RevId: 2895135801
tags/v1.7.0
Megvii Engine Team 4 years ago
parent
commit
c48d58daa8
8 changed files with 277 additions and 0 deletions
  1. +81
    -0
      dnn/src/arm_common/elemwise/binary/algo.cpp
  2. +1
    -0
      dnn/src/arm_common/elemwise/binary/algo.h
  3. +12
    -0
      dnn/src/arm_common/elemwise/opr_impl.cpp
  4. +1
    -0
      dnn/src/arm_common/elemwise/opr_impl.h
  5. +136
    -0
      dnn/src/arm_common/elemwise_op.h
  6. +14
    -0
      dnn/src/common/elemwise/opr_impl_helper.cpp
  7. +8
    -0
      dnn/src/common/elemwise/opr_impl_helper.h
  8. +24
    -0
      dnn/test/arm_common/elemwise.cpp

+ 81
- 0
dnn/src/arm_common/elemwise/binary/algo.cpp View File

@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcastX0X::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCASTX0X != kern_param.broad_cast_type) &&
(BcastType::BCASTX0X_VEC != kern_param.broad_cast_type)))
return false;

auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0];

DISPATCH_TYPE("AlgoBinaryVecBcastX0X::is_available"_hash);

return false;
}

bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) ||
@@ -348,6 +363,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
return;
}

void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo;

// Case: BcastType::VEC + BCAST_X0X
if (BcastType::VEC_BCASTX0X == kern_param.broad_cast_type &&
is_broadcasted_3dim_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_vec_b"_hash);

#undef DISPATCH_BINARY
}

// BCAST_X0X + BcastType::VEC
if (BcastType::BCASTX0X_VEC == kern_param.broad_cast_type &&
is_broadcasted_3dim_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return

DISPATCH_TYPE("AlgoBinaryVecBcastX0X::exec_b_vec"_hash);

#undef DISPATCH_BINARY
}
return;
}

void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1];


+ 1
- 0
dnn/src/arm_common/elemwise/binary/algo.h View File

@@ -33,6 +33,7 @@ namespace arm_common {
DECL_CB(VecVec);
DECL_CB(VecScalar);
DECL_CB(VecBcast101);
DECL_CB(VecBcastX0X);
DECL_CB(VecBcast111C);
DECL_CB(VecBcast101xX);
#undef DECL_CB


+ 12
- 0
dnn/src/arm_common/elemwise/opr_impl.cpp View File

@@ -27,6 +27,7 @@ class ElemwiseImpl::AlgoPack {
AlgoBinaryVecVec algo_binary_vec_vec;
AlgoBinaryVecScalar algo_binary_vec_sca;
AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
AlgoBinaryVecBcastX0X algo_binary_vec_bcastX0X;
AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
@@ -46,6 +47,7 @@ public:
all_algos.emplace_back(&algo_binary_vec_vec);
all_algos.emplace_back(&algo_binary_vec_sca);
all_algos.emplace_back(&algo_binary_vec_bcast101);
all_algos.emplace_back(&algo_binary_vec_bcastX0X);
all_algos.emplace_back(&algo_binary_vec_bcast110);
all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
@@ -202,6 +204,16 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
return kern_param;
}

if (is_vector(src0.layout) && is_broadcasted_3dim_like(src1.layout, binfo)) {
kern_param.broad_cast_type = BcastType::VEC_BCASTX0X;
return kern_param;
}

if (is_vector(src1.layout) && is_broadcasted_3dim_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCASTX0X_VEC;
return kern_param;
}

if (is_legal_layout_for_nhwc(src1.layout) &&
is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
kern_param.broad_cast_type = BcastType::BCAST111C_VEC;


+ 1
- 0
dnn/src/arm_common/elemwise/opr_impl.h View File

@@ -38,6 +38,7 @@ private:
class AlgoBinaryVecVec;
class AlgoBinaryVecScalar;
class AlgoBinaryVecBcast101;
class AlgoBinaryVecBcastX0X;
class AlgoBinaryVecBcast111C;
class AlgoBinaryVecBcast101xX;
class AlgoTernaryFma3VecVecVec;


+ 136
- 0
dnn/src/arm_common/elemwise_op.h View File

@@ -107,11 +107,13 @@ enum BcastType {
VEC,
VEC_VEC,
VEC_BCAST101,
VEC_BCASTX0X,
VEC_BCAST111C,
VEC_BCAST101xX,
VEC_SCALAR,
SCALAR_VEC,
BCAST101_VEC,
BCASTX0X_VEC,
BCAST111C_VEC,
BCAST101xX_VEC,
VEC_VEC_VEC,
@@ -230,6 +232,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> {
}
};

template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCASTX0X> {
using Op = PowOp<ctype, ctype>;
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
auto src1_ptr = src1_ptr_base;
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0, *src1_ptr, dst);
src0++;
src1_ptr++;
dst++;
}
}
}
}
};

template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> {
using Op = PowOp<ctype, ctype>;
@@ -332,6 +362,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101_VEC> {
}
};

template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, BCASTX0X_VEC> {
using Op = PowOp<ctype, ctype>;
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
for (size_t b = 0; b < batch; b++) {
auto src0_ptr_base = src0 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
auto src0_ptr = src0_ptr_base;
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0_ptr, *src1, dst);
src0_ptr++;
src1++;
dst++;
}
}
}
}
};

template <typename Op>
struct OpCallerBinary<Op, VEC_VEC> {
static void run(
@@ -398,6 +456,45 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
}
};

template <typename Op>
struct OpCallerBinary<Op, VEC_BCASTX0X> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t b = 0; b < batch; b++) {
const typename Op::src_ctype* src1_ptr_base = src1 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
size_t i = 0;
auto src1_ptr = src1_ptr_base;
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
i += Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0);
auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1_ptr);
auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH);
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
src0 += Op::SIMD_WIDTH * 2;
src1_ptr += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0, *src1_ptr, dst);
src0++;
src1_ptr++;
dst++;
}
}
}
}
};

template <typename Op>
struct OpCallerBinary<Op, VEC_BCAST111C> {
static void run(
@@ -844,6 +941,45 @@ struct OpCallerBinary<Op, BCAST101_VEC> {
}
};

template <typename Op>
struct OpCallerBinary<Op, BCASTX0X_VEC> {
static void run(
const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
Op op(src0_dtype, src1_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis;
for (size_t b = 0; b < batch; b++) {
auto src0_ptr_base = src0 + b * channel_stride;
for (size_t c = 0; c < channel; c++) {
auto src0_ptr = src0_ptr_base;
size_t i = 0;
for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
i += Op::SIMD_WIDTH * 2) {
auto src0_neon0 = vis(src0_ptr);
auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
auto src1_neon0 = vis(src1);
auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
src0_ptr += Op::SIMD_WIDTH * 2;
src1 += Op::SIMD_WIDTH * 2;
dst += Op::SIMD_WIDTH * 2;
}
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
for (; i < channel_stride; i++) {
op(*src0_ptr, *src1, dst);
src0_ptr++;
src1++;
dst++;
}
}
}
}
};

template <typename Op, BcastType bcast_type>
struct OpCallerTernary;



+ 14
- 0
dnn/src/common/elemwise/opr_impl_helper.cpp View File

@@ -150,6 +150,20 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
return false;
}

bool ElemwiseLayoutHelper::is_broadcasted_3dim_like(
const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT) {
if (layout.ndim == 3 && (layout.stride[0] - layout.shape[2]) == 0 &&
layout.stride[1] == 0 && layout.stride[2] == 1) {
info.x = layout.shape[0];
info.y = layout.shape[1];
info.z = layout.shape[2];
return true;
}
}
return false;
}

bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info) {
if (layout.format.type() == TensorFormat::Type::DEFAULT) {


+ 8
- 0
dnn/src/common/elemwise/opr_impl_helper.h View File

@@ -80,6 +80,14 @@ public:
static bool is_broadcasted_channel_like(
const TensorLayout& layout, BroadcastChannelInfo& info);

/*!
* \brief check whether layout matches BroadcastChannelInfo like N1HW
*
* Note layout should be [N, 1, H*W] like
*/
static bool is_broadcasted_3dim_like(
const TensorLayout& layout, BroadcastChannelInfo& info);

/*!
* \brief check whether layout matches BroadcastChannelInfo under NHWC
* layout


+ 24
- 0
dnn/test/arm_common/elemwise.cpp View File

@@ -356,6 +356,30 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
run_3d_incontig(Mode::FUSE_MUL_ADD3);
}

TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());

UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());

//! 2 dim
auto run = [&](Mode mode) {
// VEC_BCASTX0X
checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});
// BCASTX0X_VEC
checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
};
run(Mode::ADD);
run(Mode::MUL);
run(Mode::SUB);
}

#if MEGDNN_WITH_BENCHMARK
namespace {
void run_elemwise_benchmark(


Loading…
Cancel
Save