
refactor(dnn/arm): split arm direct kernel to cut compile time

GitOrigin-RevId: b06fba83eb
tags/v1.0.0-rc1
Megvii Engine Team 5 years ago
parent commit 4d56371e0b
48 changed files with 6492 additions and 5552 deletions
1. +173 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
2. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp
3. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp
4. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp
5. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp
6. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp
7. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp
8. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp
9. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp
10. +81 -152 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
11. +89 -251 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
12. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp
13. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp
14. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp
15. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp
16. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp
17. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp
18. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp
19. +14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp
20. +443 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
21. +10 -32 dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
22. +34 -0 dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h
23. +4 -2 dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
24. +12 -395 dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h
25. +0 -40 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h
26. +0 -40 dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h
27. +4 -228 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
28. +10 -16 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
29. +5 -9 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
30. +0 -435 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
31. +245 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
32. +320 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
33. +322 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
34. +448 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
35. +437 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
36. +743 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
37. +778 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp
38. +47 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h
39. +561 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
40. +1412 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
41. +26 -57 dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
42. +7 -1337 dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
43. +10 -9 dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
44. +3 -1854 dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
45. +5 -4 dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
46. +14 -662 dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
47. +2 -2 dnn/test/arm_common/conv_bias_multi_thread.cpp
48. +23 -27 dnn/test/arm_common/matrix_mul.cpp
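
The pattern behind the refactor is visible in the file list: each heavy templated kernel body moves into a common header that ends in an instantiation macro, and every (filter size, stride) combination gets a ~14-line .cpp that includes the header and expands the macro once. The expensive template expansion is thereby spread over many small translation units that compile independently and in parallel, instead of being concentrated in one huge file. A minimal sketch of the idea, with hypothetical names (the real headers and macros appear in the diffs below):

// conv_kern_common.h -- hypothetical illustration of the layout
#pragma once

template <int filter_size, int stride>
void conv_direct(const float* src, const float* filter, float* dst, int n) {
    // stand-in for the heavy, fully unrolled NEON kernel body
    for (int i = 0; i < n; ++i) {
        dst[i] = src[i] * filter[0];
    }
}

// expanded once per shape to force one explicit instantiation per TU
#define INSTANTIATE_CONV(filter_size, stride) \
    template void conv_direct<filter_size, stride>(const float*, const float*, float*, int);

// conv_kern_3x3s1.cpp -- one tiny translation unit per shape
#include "conv_kern_common.h"
INSTANTIATE_CONV(3, 1)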

+173 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp

@@ -0,0 +1,173 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
template <>
void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin,
                             const int, const int pw, const int pad_right,
                             const int ih, const int iw, const int iw2,
                             const int pad_top, const int pad_bottom,
                             const int ic, const int ic_stride) {
    constexpr int ic_step = 4;
    rep_step(ic_idx, ic, ic_step) {
        const float* sptr = sptr_origin + ic_idx * ic_stride;
        memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
        sptr_base += iw2 * pad_top * ic_step;
        rep(ih_idx, ih) {
            memset(sptr_base, 0, sizeof(float) * pw * ic_step);
            sptr_base += pw * ic_step;
            memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step);
            sptr_base += iw * ic_step;
            sptr += iw * ic_step;
            memset(sptr_base, 0, sizeof(float) * pad_right * ic_step);
            sptr_base += pad_right * ic_step;
        }
        memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
        sptr_base += iw2 * pad_bottom * ic_step;
    }
}

namespace {

static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr,
                                           const int odd_start,
                                           const int src_idx,
                                           const int iw_idx) {
    constexpr int ic_step = 4;
    const int src_offset = src_idx * ic_step;
    const int even_offset = iw_idx / 2 * ic_step;
    const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
    float32x4_t temp[8];
    temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
    temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
    temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
    temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
    temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
    temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
    temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
    temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
    vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]);
    vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]);
    vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]);
    vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]);
    vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
    vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
    vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
    vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
}

static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
                                          const int odd_start,
                                          const int src_idx, const int iw_idx) {
    constexpr int ic_step = 4;
    const int src_offset = src_idx * ic_step;
    const int even_offset = (iw_idx + 1) / 2 * ic_step;
    const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
    float32x4_t temp[8];
    temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
    temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
    temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
    temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
    temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
    temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
    temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
    temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
    vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
    vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
    vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
    vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
    vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]);
    vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]);
    vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]);
    vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]);
}
}  // namespace

template <>
void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin,
                             const int ph, const int pw, const int pad_right,
                             const int ih, const int iw, const int iw2,
                             const int pad_top, const int pad_bottom,
                             const int ic, const int ic_stride) {
    constexpr int ic_step = 4;
    int odd_start = megdnn::div_ceil(iw2, 2);
    float32x4_t zero_v = vdupq_n_f32(0.f);
    MEGDNN_MARK_USED_VAR(ph);
    bool even_start = pw % 2 == 0;
    rep_step(ic_idx, ic, ic_step) {
        const float* sptr = sptr_origin + ic_idx * ic_stride;
        memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
        sptr_base += iw2 * pad_top * ic_step;
        rep(ih_idx, ih) {
            int iw_idx = 0;
            rep(idx, pw) {
                if (iw_idx % 2 == 0) {
                    vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
                } else {
                    vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
                              zero_v);
                }
                ++iw_idx;
            }
            int src_idx = 0;
            if (even_start) {
                for (; src_idx + 7 < iw; src_idx += 8) {
                    odd_even_split_iw8_even(sptr_base, sptr, odd_start,
                                            src_idx, iw_idx);
                    iw_idx += 8;
                }
            } else {
                for (; src_idx + 7 < iw; src_idx += 8) {
                    odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx,
                                           iw_idx);
                    iw_idx += 8;
                }
            }
            for (; src_idx < iw; ++src_idx) {
                if (iw_idx % 2 == 0) {
                    vst1q_f32(sptr_base + iw_idx / 2 * ic_step,
                              vld1q_f32(sptr + src_idx * ic_step));
                } else {
                    vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
                              vld1q_f32(sptr + src_idx * ic_step));
                }
                ++iw_idx;
            }
            rep(idx, pad_right) {
                if (iw_idx % 2 == 0) {
                    vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
                } else {
                    vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
                              zero_v);
                }
                ++iw_idx;
            }
            sptr_base += iw2 * ic_step;
            sptr += iw * ic_step;
        }
        memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
        sptr_base += iw2 * pad_bottom * ic_step;
    }
}

} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen
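
The stride-2 packing above stores even input columns contiguously at the front of each packed row and odd columns starting at odd_start = div_ceil(iw2, 2), so the stride-2 kernel can read both phases with unit-stride loads. A scalar model of the same reordering (illustrative only; the real code above moves ic_step = 4 floats per column with NEON and also handles padding):

// Scalar sketch of the odd/even split: x0 x1 x2 x3 x4 ... becomes
// x0 x2 x4 ... | x1 x3 x5 ...
void odd_even_split(float* out, const float* in, int n) {
    int odd_start = (n + 1) / 2;  // mirrors megdnn::div_ceil(iw2, 2)
    for (int i = 0; i < n; ++i) {
        if (i % 2 == 0) {
            out[i / 2] = in[i];              // even phase, front half
        } else {
            out[odd_start + i / 2] = in[i];  // odd phase, back half
        }
    }
}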

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(2);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(2);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(3);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(3);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(5);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(5);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h"
INSTANTIATION_CONV_S1(7);

+14 -0 dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(7);

dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h

@@ -1,6 +1,6 @@
 /**
  * \file
- * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp
+ * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -12,7 +12,7 @@
  */

 #include "megdnn/arch.h"
-#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
+#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
 #include "src/arm_common/conv_bias/intrinsic_helper.h"
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
@@ -24,21 +24,21 @@ using namespace megdnn;
 using namespace arm_common;
 namespace {

-template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
-          typename T, typename T2, typename T3, typename T4>
+template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
+          typename T2, typename T3, typename T4>
 struct ShiftCalHelper {
     static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
 };

-template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
-          typename T3, typename T4>
-struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
+template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
+          typename T4>
+struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> {
     static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
-#define cb(step, lane)                                                   \
-    c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane],  \
-                                           src[(step + src_idx) % 8]);   \
-    c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane],  \
-                                           src[(step + src_idx) % 8]);
+#define cb(step, lane)                                             \
+    c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane],      \
+                                 src[(step + src_idx) % 8], lane); \
+    c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane],      \
+                                 src[(step + src_idx) % 8], lane);

     UNROLL_CALL_RAW(8, cb, 0);
     UNROLL_CALL_RAW(8, cb, 1);
@@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
 #undef cb
     }
 };
-template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
-          typename T3, typename T4>
-struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
+template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
+          typename T4>
+struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> {
     static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
-#define cb(step, lane)                                                   \
-    c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane],  \
-                                           src[(step + src_idx) % 4]);   \
-    c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane],  \
-                                           src[(step + src_idx) % 4]);
+#define cb(step, lane)                                             \
+    c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane],      \
+                                 src[(step + src_idx) % 4], lane); \
+    c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane],      \
+                                 src[(step + src_idx) % 4], lane);

     UNROLL_CALL_RAW(4, cb, 0);
     UNROLL_CALL_RAW(4, cb, 1);
@@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
 #undef cb
     }
 };
-template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
-          typename T3, typename T4>
-struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
+template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
+          typename T4>
+struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> {
     static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
-#define cb(step, lane)                                                   \
-    c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane],  \
-                                           src[(step + src_idx) % 8]);
+#define cb(step, lane)                                        \
+    c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
+                                 src[(step + src_idx) % 8], lane);

     UNROLL_CALL_RAW(8, cb, 0);
     UNROLL_CALL_RAW(8, cb, 1);
@@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
 #undef cb
     }
 };
-template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
-          typename T3, typename T4>
-struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
+template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
+          typename T4>
+struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> {
     static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
-#define cb(step, lane)                                                   \
-    c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane],  \
-                                           src[(step + src_idx) % 4]);
+#define cb(step, lane)                                        \
+    c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
+                                 src[(step + src_idx) % 4], lane);

     UNROLL_CALL_RAW(4, cb, 0);
     UNROLL_CALL_RAW(4, cb, 1);
@@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
     }
 };

-template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
-          typename T, typename T2, typename T3>
+template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
+          typename T2, typename T3>
 MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
-    ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
-                   int>::impl(c, src, weight);
+    ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl(
+            c, src, weight);
 };
 template <int oc>
 struct OCHelper {
@@ -162,13 +162,11 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
                     0);
             load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
             src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
             load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
             src_ptr += ld_src_iw;
             weight_ptr += ld_weight_fh;
         }
@@ -209,18 +207,15 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
                     0);
             load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
             src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
             load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
             src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
             load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
             src_ptr += ld_src_iw;
             weight_ptr += ld_weight_fh;
         }
@@ -260,32 +255,27 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
                     0);
             load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

             src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
             load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

             src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
             load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<2, 0, c_dim, ow_block>(c, src, weight);

             src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
             load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<3, 0, c_dim, ow_block>(c, src, weight);

             src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
             load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<4, 0, c_dim, ow_block>(c, src, weight);
             src_ptr += ld_src_iw;
             weight_ptr += ld_weight_fh;
         }
@@ -326,44 +316,37 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
                     0);
             load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<0, 0, c_dim, ow_block>(c, src, weight);

             src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step);
             load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<1, 0, c_dim, ow_block>(c, src, weight);

             src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step);
             load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<2, 0, c_dim, ow_block>(c, src, weight);

             src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step);
             load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<3, 0, c_dim, ow_block>(c, src, weight);

             src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step);
             load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<4, 0, c_dim, ow_block>(c, src, weight);

             src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step);
             load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<5, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<5, 0, c_dim, ow_block>(c, src, weight);

             src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step);
             load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
                     weight, weight_ptr, ld_weight_oc);
-            cal_helper<6, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
-                                                               weight);
+            cal_helper<6, 0, c_dim, ow_block>(c, src, weight);
             src_ptr += ld_src_iw;
             weight_ptr += ld_weight_fh;
         }
@@ -375,36 +358,14 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {

 }  // namespace

-void conv_bias::pack_src_fp32_nchw44_stride1(
-        float* sptr_base, const float* sptr_origin, const int, const int pw,
-        const int pad_right, const int ih, const int iw, const int iw2,
-        const int pad_top, const int pad_bottom, const int ic,
-        const int ic_stride) {
-    constexpr int ic_step = 4;
-    rep_step(ic_idx, ic, ic_step) {
-        const float* sptr = sptr_origin + ic_idx * ic_stride;
-        memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
-        sptr_base += iw2 * pad_top * ic_step;
-        rep(ih_idx, ih) {
-            memset(sptr_base, 0, sizeof(float) * pw * ic_step);
-            sptr_base += pw * ic_step;
-            memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step);
-            sptr_base += iw * ic_step;
-            sptr += iw * ic_step;
-            memset(sptr_base, 0, sizeof(float) * pad_right * ic_step);
-            sptr_base += pad_right * ic_step;
-        }
-        memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
-        sptr_base += iw2 * pad_bottom * ic_step;
-    }
-}
-
-template <BiasMode bias_mode, typename Op, int filter_size>
-static void conv_direct_stride1_fp32_nchw44(
-        const float32_t* src, const float32_t* filter, const float32_t* bias,
-        float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
-        const int iw, const int oh, const int oh_block, const int ow,
-        const Op& op, const int, const int) {
+template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter,
+                                        const float* bias, float*, float* dst,
+                                        const int oc, const int ic,
+                                        const int ih, const int iw,
+                                        const int oh, const int oh_block,
+                                        const int ow, const Op& op, const int,
+                                        const int) {
     constexpr int fh = filter_size;
     constexpr int fw = filter_size;
     constexpr int ic_step = 4;
@@ -518,55 +479,23 @@ static void conv_direct_stride1_fp32_nchw44(
     }
 }

-#define CONSTRUCT_FUNC(filter_size)                                          \
-    template <BiasMode bias_mode, typename Op>                               \
-    void conv_bias::                                                         \
-            conv_direct_stride1_##filter_size##x##filter_size##_fp32_nchw44( \
-                    const float32_t* src, const float32_t* filter,           \
-                    const float32_t* bias, float32_t* temp, float32_t* dst,  \
-                    const int oc, const int ic, const int ih, const int iw,  \
-                    const int oh, const int oh_block, const int ow,          \
-                    const Op& op, const int ph, const int pw) {              \
-        conv_direct_stride1_fp32_nchw44<bias_mode, Op, filter_size>(         \
-                src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block,  \
-                ow, op, ph, pw);                                             \
-    }
-CONSTRUCT_FUNC(2);
-CONSTRUCT_FUNC(3);
-CONSTRUCT_FUNC(5);
-CONSTRUCT_FUNC(7);
-#undef CONSTRUCT_FUNC
-
-#define INSTANTIATION(stride, i, bias, Op)                                    \
-    template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44<  \
-            bias, Op>(const float32_t*, const float32_t*, const float32_t*,   \
-                      float32_t*, float32_t*, const int, const int, const int,\
-                      const int, const int, const int, const int, const Op&,  \
-                      const int, const int);
-
-#define FOR_OP(stride, i, bias)                           \
-    INSTANTIATION(stride, i, bias, NoneOp<dt_float32>)    \
-    INSTANTIATION(stride, i, bias, ReluOp<dt_float32>)    \
-    INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>)  \
-    INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>)
-
-#define FOR_BIAS(stride, i)                              \
-    FOR_OP(stride, i, BiasMode::NO_BIAS)                 \
-    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)  \
-    FOR_OP(stride, i, BiasMode::BIAS)
-
-#define FOR_FILTER(stride) \
-    FOR_BIAS(stride, 2)    \
-    FOR_BIAS(stride, 3)    \
-    FOR_BIAS(stride, 5)    \
-    FOR_BIAS(stride, 7)
-
-FOR_FILTER(stride1)
-
-#undef FOR_STRIDE
-#undef FOR_FILTER
-#undef FOR_IC
-#undef FOR_BIAS
-#undef FOR_NONLINEAR
-#undef FOR_REMAIN
-#undef INSTANTIATION
+#define INSTANTIATION(filter_size, bias_mode, Op)                          \
+    template void                                                          \
+    conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 1>(     \
+            const float* src, const float* filter, const float* bias,      \
+            float*, float* dst, const int oc, const int ic, const int ih,  \
+            const int iw, const int oh, const int oh_block, const int ow,  \
+            const Op& op, const int, const int);
+
+#define FOR_OP(filter_size, bias)                          \
+    INSTANTIATION(filter_size, bias, NoneOp<dt_float32>)   \
+    INSTANTIATION(filter_size, bias, ReluOp<dt_float32>)   \
+    INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
+    INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)
+
+#define INSTANTIATION_CONV_S1(filter_size)                 \
+    FOR_OP(filter_size, BiasMode::NO_BIAS)                 \
+    FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS)  \
+    FOR_OP(filter_size, BiasMode::BIAS)

 // vim: syntax=cpp.doxygen
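
With the stride folded into the template signature as conv_direct_fp32_nchw44<bias_mode, Op, filter_size, stride>, the algo layer needs only one dispatch point over the explicitly instantiated shapes. A hedged sketch of such a dispatch (simplified, not the literal f32_direct_nchw44_algo.cpp code; kern_ptr, filter_size and stride are assumed surrounding variables):

// Hypothetical: map a runtime (filter_size, stride) pair onto one of the
// explicit instantiations produced by the per-shape .cpp files above.
#define DISPATCH_CONV(fs, st)                                      \
    if (filter_size == fs && stride == st) {                       \
        kern_ptr = conv_direct_fp32_nchw44<bias_mode, Op, fs, st>; \
    }
DISPATCH_CONV(2, 1) DISPATCH_CONV(3, 1) DISPATCH_CONV(5, 1) DISPATCH_CONV(7, 1)
DISPATCH_CONV(2, 2) DISPATCH_CONV(3, 2) DISPATCH_CONV(5, 2) DISPATCH_CONV(7, 2)
#undef DISPATCH_CONV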

dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp → dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h

@@ -1,6 +1,6 @@
/** /**
* \file * \file
* dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -12,7 +12,7 @@
*/ */


#include "megdnn/arch.h" #include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h" #include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
@@ -24,21 +24,21 @@ using namespace megdnn;
using namespace arm_common; using namespace arm_common;
namespace { namespace {


template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
typename T, typename T2, typename T3, typename T4>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper { struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
}; };


template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 8], lane);


UNROLL_CALL_RAW(8, cb, 0); UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1); UNROLL_CALL_RAW(8, cb, 1);
@@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> {
#undef cb #undef cb
} }
}; };
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]); \
c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \
src[(step + src_idx) % 4], lane);


UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1); UNROLL_CALL_RAW(4, cb, 1);
@@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> {
#undef cb #undef cb
} }
}; };
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 8], lane);


UNROLL_CALL_RAW(8, cb, 0); UNROLL_CALL_RAW(8, cb, 0);
UNROLL_CALL_RAW(8, cb, 1); UNROLL_CALL_RAW(8, cb, 1);
@@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> {
#undef cb #undef cb
} }
}; };
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step, lane) \
c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4]);
#define cb(step, lane) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \
src[(step + src_idx) % 4], lane);


UNROLL_CALL_RAW(4, cb, 0); UNROLL_CALL_RAW(4, cb, 0);
UNROLL_CALL_RAW(4, cb, 1); UNROLL_CALL_RAW(4, cb, 1);
@@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> {
} }
}; };


template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
typename T, typename T2, typename T3>
template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3,
int>::impl(c, src, weight);
ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl(
c, src, weight);
}; };
template <int oc> template <int oc>
struct OCHelper { struct OCHelper {
@@ -163,13 +163,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);


load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0); 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh; weight_ptr += ld_weight_fh;
@@ -177,13 +177,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);


load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0); 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh; weight_ptr += ld_weight_fh;
@@ -224,18 +224,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);


src[0] = vld1q_f32(src_ptr + ow_block * simd_len); src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);


load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0); 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh; weight_ptr += ld_weight_fh;
@@ -243,17 +243,17 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);


load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0); 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh; weight_ptr += ld_weight_fh;
@@ -261,18 +261,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> {
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); src[0] = vld1q_f32(src_ptr + ow_block * simd_len);


load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);


load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd,
0); 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
weight_ptr += ld_weight_fh; weight_ptr += ld_weight_fh;
@@ -316,30 +316,25 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> {
0); 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
// odd element // odd element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr_odd, 0); src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);


src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
@@ -390,40 +385,33 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
0); 0);
load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr,
ld_weight_oc); ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr + ow_block * simd_len); src[0] = vld1q_f32(src_ptr + ow_block * simd_len);
load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len);
load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);
src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len);
load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<3, 0, c_dim, ow_block>(c, src, weight);
// odd element // odd element
load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr_odd, 0); src, src_ptr_odd, 0);
load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<0, 0, c_dim, ow_block>(c, src, weight);
src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len);
load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<1, 0, c_dim, ow_block>(c, src, weight);
src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len);
load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc); weight, weight_ptr, ld_weight_oc);
cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src,
weight);
cal_helper<2, 0, c_dim, ow_block>(c, src, weight);


src_ptr += ld_src_iw; src_ptr += ld_src_iw;
src_ptr_odd += ld_src_iw; src_ptr_odd += ld_src_iw;
@@ -436,133 +424,15 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> {
}; };


} // namespace } // namespace
namespace {

inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr,
const int odd_start, const int src_idx,
const int iw_idx) {
constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step;
const int even_offset = iw_idx / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]);
}

inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr,
const int odd_start, const int src_idx,
const int iw_idx) {
constexpr int ic_step = 4;
const int src_offset = src_idx * ic_step;
const int even_offset = (iw_idx + 1) / 2 * ic_step;
const int odd_offset = (odd_start + iw_idx / 2) * ic_step;
float32x4_t temp[8];
temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step);
temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step);
temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step);
temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step);
temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step);
temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step);
temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step);
temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step);
vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]);
vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]);
vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]);
vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]);
vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]);
vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]);
vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]);
vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]);
}
} // namespace

void conv_bias::pack_src_fp32_nchw44_stride2(
float* sptr_base, const float* sptr_origin, const int ph, const int pw,
const int pad_right, const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom, const int ic,
const int ic_stride) {
constexpr int ic_step = 4;
int odd_start = megdnn::div_ceil(iw2, 2);
float32x4_t zero_v = vdupq_n_f32(0.f);
MEGDNN_MARK_USED_VAR(ph);
bool even_start = pw % 2 == 0;
rep_step(ic_idx, ic, ic_step) {
const float* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step);
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
int iw_idx = 0;
rep(idx, pw) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
zero_v);
}
++iw_idx;
}
int src_idx = 0;
if (even_start) {
for (; src_idx + 7 < iw; src_idx += 8) {
odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx,
iw_idx);
iw_idx += 8;
}
} else {
for (; src_idx + 7 < iw; src_idx += 8) {
odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx,
iw_idx);
iw_idx += 8;
}
}
for (; src_idx < iw; ++src_idx) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
vld1q_f32(sptr + src_idx * ic_step));
}
++iw_idx;
}
rep(idx, pad_right) {
if (iw_idx % 2 == 0) {
vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v);
} else {
vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step,
zero_v);
}
++iw_idx;
}
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step);
sptr_base += iw2 * pad_bottom * ic_step;
}
}


template <BiasMode bias_mode, typename Op, int filter_size>
static void conv_direct_stride2_fp32_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic,
const int ih, const int iw,
const int oh, const int oh_block,
const int ow, const Op& op, const int,
const int) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
@@ -697,55 +567,23 @@ static void conv_direct_stride2_fp32_nchw44(
    }
}


#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw44( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t* temp, float32_t* dst, \
const int oc, const int ic, const int ih, const int iw, \
const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw) { \
conv_direct_stride2_fp32_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \
ow, op, ph, pw); \
}
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC

#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \
bias, Op>(const float32_t*, const float32_t*, const float32_t*, \
float32_t*, float32_t*, const int, const int, const int, \
const int, const int, const int, const int, const Op&, \
const int, const int);

#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \
INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, i, BiasMode::BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(stride2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
#define INSTANTIATION(filter_size, bias_mode, Op) \
template void \
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 2>( \
const float* src, const float* filter, const float* bias, float*, \
float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int, const int);

#define FOR_OP(filter_size, bias) \
INSTANTIATION(filter_size, bias, NoneOp<dt_float32>) \
INSTANTIATION(filter_size, bias, ReluOp<dt_float32>) \
INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \
INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>)

#define INSTANTIATION_CONV_S2(filter_size) \
FOR_OP(filter_size, BiasMode::NO_BIAS) \
FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(filter_size, BiasMode::BIAS)
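Each stride-2 per-filter translation unit is expected to expand this macro exactly once; e.g. f32_direct_nchw44_kern_2x2s2.cpp presumably reduces to (a sketch mirroring the nchw_nchw44 stubs below):

#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h"
INSTANTIATION_CONV_S2(2);  // all bias-mode x op instantiations for 2x2, stride 2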

// vim: syntax=cpp.doxygen

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(2, 1);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(2, 2);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(3, 1);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(3, 2);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(5, 1);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(5, 2);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(7, 1);

+14 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp

@@ -0,0 +1,14 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h"
INSTANCE_CONV(7, 2);

+443 -0  dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h

@@ -0,0 +1,443 @@
/**
* \file
* dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

using namespace megdnn;
using namespace arm_common;

namespace {
/**
 * \brief ShiftCalHelper implements the core multiply-accumulate computation.
 * \tparam src_idx offset into the src registers
 * \tparam weight_idx offset into the weight registers
 * \tparam T type of the output registers
 * \tparam T2 type of the src registers
 * \tparam T3 type of the weight registers
 */
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, int stride, typename T, typename T2,
typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 2, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4); \
c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][weight_idx], \
src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, int stride, typename T, typename T2,
typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 1, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4], \
(step * stride + src_idx) % 4);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3>::impl(c, src,
weight);
}
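For a concrete sense of the expansion, one cb(step) instance with stride = 2, src_idx = 0, weight_idx = 0 and step = 3 (illustrative values) reduces to:

// element index = step * stride + src_idx = 6 -> register src[6 / 4], lane 6 % 4:
c[0][3] = vfmaq_laneq_f32(c[0][3], weight[0][0], src[1], 2);
// i.e. output column 3 accumulates weight column 0 scaled by input element 6.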
template <int oc>
struct OCHelper {
public:
static const int val = -1;
};

template <>
struct OCHelper<4> {
public:
static const int val = 1;
};

template <>
struct OCHelper<8> {
public:
static const int val = 2;
};
/**
 * GEMM-like kernels oc8_ow8 (m = 8, n = 8) and oc4_ow8 (m = 4, n = 8):
 * c[c_dim][8] holds an (oc_block / 4) x 8 tile of float32x4 accumulators.
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32 {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 7;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;
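        // e.g. ow_block = 8, stride = 2: (8 * 2 + 7 - 2 + 3) / 4 = 6
        // float32x4 loads cover one input row of this tile.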

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, stride>(c, src, weight); \
cal_helper<5, 5, c_dim, stride>(c, src, weight); \
cal_helper<6, 6, c_dim, stride>(c, src, weight);

UNROLL_CALL_RAW(7, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 5;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, stride>(c, src, weight);
UNROLL_CALL_RAW(5, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 3;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);
cal_helper<2, 2, c_dim, stride>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);
cal_helper<2, 2, c_dim, stride>(c, src, weight);

// row 2
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + 2 * iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);
cal_helper<2, 2, c_dim, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 2;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, stride>(c, src, weight);
cal_helper<1, 1, c_dim, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

} // namespace

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 1;
constexpr int big_oc_step = 8;
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = 1;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;
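    //! tile the output as [big_oc_step = 8 oc] x [oh_step = 1 oh] x
    //! [ow_step = 8 ow]; the switch below pre-selects remainder kernels for
    //! the ow tail, and the oc tail falls back to the oc_step = 4 kernels.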

using remain_fun = std::function<void(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step, stride, ow_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step, stride, ow_step>::impl; \
break;

        UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
big_oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}

#define INSTANTIATION(stride, filter_size, bias_mode, Op) \
template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< \
bias_mode, Op, filter_size, stride>( \
const float32_t* src, const float32_t* filter, \
const float32_t* bias, float32_t*, float32_t* dst, const int oc, \
const int ic, const int ih, const int iw, const int oh, \
const int oh_block, const int ow, const Op& op, const int, \
const int);

#define FOR_OP(stride, filter, bias) \
INSTANTIATION(stride, filter, bias, NoneOp<dt_float32>) \
INSTANTIATION(stride, filter, bias, ReluOp<dt_float32>) \
INSTANTIATION(stride, filter, bias, HSwishOp<dt_float32>)

#define INSTANCE_CONV(filter, stride) \
FOR_OP(stride, filter, BiasMode::NO_BIAS) \
FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
FOR_OP(stride, filter, BiasMode::BIAS)
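Each INSTANCE_CONV(filter, stride) expansion therefore emits nine explicit instantiations (three bias modes x three ops); INSTANCE_CONV(2, 1), for example, begins with:

template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44<
        BiasMode::NO_BIAS, NoneOp<dt_float32>, 2, 1>(
        const float32_t* src, const float32_t* filter, const float32_t* bias,
        float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
        const int iw, const int oh, const int oh_block, const int ow,
        const NoneOp<dt_float32>& op, const int, const int);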

// vim: syntax=cpp.doxygen

+10 -32  dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp

@@ -13,8 +13,8 @@
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/block_helper.h"
#include "src/arm_common/conv_bias/fp32/algos.h" #include "src/arm_common/conv_bias/fp32/algos.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"


#include "midout.h" #include "midout.h"
@@ -112,17 +112,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
    const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2);
    float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) +
                                           ncb_index.thread_id * src_size);
if (stride == 1) {
conv_bias::pack_src_fp32_nchw44_stride1(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
} else {
conv_bias::pack_src_fp32_nchw44_stride2(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);
}

conv_bias::pack_src_fp32_nchw44<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw);


    const float* fptr =
            kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
@@ -135,25 +129,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset; kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset;


    Op op;
if (stride == 1) {
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride1_##filter##x##filter##_fp32_nchw44< \
\
bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \
ih_real, iw2, oh, oh_block_real, ow, op, ph, pw)

DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
} else {
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \
\
bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \
ih_real, iw2, oh, oh_block_real, ow, op, ph, pw)

DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}
conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op, ph, pw);
}


} // namespace


+34 -0  dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h

@@ -0,0 +1,34 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_fp32_nchw44(const float* src, const float* filter,
const float* bias, float*, float* dst,
const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block,
const int ow, const Op& op, const int, const int);
template <int stride>
void pack_src_fp32_nchw44(float* sptr_base, const float* sptr_origin, const int,
const int pw, const int pad_right, const int ih,
const int iw, const int iw2, const int pad_top,
const int pad_bottom, const int ic,
const int ic_stride);

} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn
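A minimal call sequence against this header (a sketch with illustrative variable names; op is a default-constructed NoneOp<dt_float32>, and the real dispatch lives in the algo below):

// Pack the padded source, then run the direct kernel: 3x3 filter, stride 2.
conv_bias::pack_src_fp32_nchw44<2>(packed_src, src, ph, pw, pad_right, ih_real,
                                   iw, iw2, pad_top, pad_bottom, ic, ih * iw);
conv_bias::conv_direct_fp32_nchw44<BiasMode::NO_BIAS, NoneOp<dt_float32>, 3, 2>(
        packed_src, filter, bias, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
        oh_block, ow, op, ph, pw);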

+4 -2  dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp

@@ -120,7 +120,8 @@ static void pack_weight(const WorkspaceBundle& bundle,
kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic; kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) + auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) +
group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw;
pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic);
fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44(fptr, packed_weight,
oc_block, fh, fw, ic);
} }


template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride> template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride>
@@ -180,7 +181,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx; kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx;
    Op op;


conv_direct_fp32_nchw_nchw44<bias_mode, Op, filter_size, stride>(
fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44<bias_mode, Op,
filter_size, stride>(
            sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2,
            oh, oh_block_real, ow, op, ph, pw);
}


+12 -395  dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h

@@ -20,295 +20,12 @@
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
namespace {
/**
*\brief ShiftCalHelper is core calculate code
*\tparam src_idx is offset for src regs
*\tparam weight_idx is offset for weight regs
*\tparam T is type of output regs
*\tparam T2 is type of src regs
*\tparam T3 is type of weight regs
*/
template <int src_idx, int weight_idx, int c_dim, typename Func, int stride,
typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]); \
c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[1][step], weight[1][weight_idx], \
src[(step * stride + src_idx) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \
c[0][step], weight[0][weight_idx], \
src[(step * stride + src_idx) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride,
typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3>::impl(
c, src, weight);
};
template <int oc>
struct OCHelper {
public:
static const int val = -1;
};

template <>
struct OCHelper<4> {
public:
static const int val = 1;
};

template <>
struct OCHelper<8> {
public:
static const int val = 2;
};
/**
* oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
**/
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32 {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op);
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 7;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<5, 5, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<6, 6, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

UNROLL_CALL_RAW(7, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 5;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];

#define KERNEL_CB(step) \
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \
src, src_ptr + step * iw, 0); \
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \
weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \
cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
UNROLL_CALL_RAW(5, KERNEL_CB)
#undef KERNEL_CB

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 3;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
namespace fp32_direct_nchw_nchw44 {


// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

// row 2
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + 2 * iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride, int ow_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride,
ow_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 2;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size =
(ow_block * stride + filter_size - stride + simd_len - 1) /
simd_len;

constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight);

src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr,
const int oc, const int kh, const int kw,
const int ic) {
static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr,
float32_t* dst_ptr,
const int oc, const int kh,
const int kw, const int ic) {
    constexpr int oc_step = 4;
    const int filter_oc_stride = kh * kw * ic;
    const int filter_ic_stride = kh * kw * oc_step;
@@ -327,115 +44,15 @@ void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr,
        }
    }
}

template <BiasMode bias_mode, typename Op, int filter_size, int stride> template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_fp32_nchw_nchw44(
const float32_t* src, const float32_t* filter, const float32_t* bias,
float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op, const int, const int) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 1;
constexpr int big_oc_step = 8;
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = 1;
void conv_direct_fp32_nchw_nchw44(const float32_t* src, const float32_t* filter,
const float32_t* bias, float32_t*,
float32_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op, const int, const int);
} // namespace fp32_direct_nchw_nchw44


const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun = std::function<void(
const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
big_oc_step, stride, ow_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \
oc_step, stride, ow_step>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
big_oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
oc_step, stride,
ow_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen

+0 -40  dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h

@@ -1,40 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);

KERN(stride1, 2, nchw44)
KERN(stride1, 3, nchw44)
KERN(stride1, 5, nchw44)
KERN(stride1, 7, nchw44)
#undef KERN

void pack_src_fp32_nchw44_stride1(float* sptr_base, const float* sptr_origin,
const int ph, const int pw,
const int pad_right, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

+0 -40  dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h

@@ -1,40 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"
namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_fp32_##layout( \
const float* src, const float* filter, const float* bias, \
float* temp, float* dst, const int oc, const int ic, const int ih, \
const int iw, const int oh, const int oh_block, const int ow, \
const Op& op, const int ph, const int pw);

KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN

void pack_src_fp32_nchw44_stride2(float* sptr_base, const float* sptr_origin,
const int ph, const int pw,
const int pad_right, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

+4 -228  dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/


#ifdef __ARM_FEATURE_DOTPROD
@@ -17,7 +18,7 @@
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
#include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
@@ -139,234 +140,9 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step,
    }
}


template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
constexpr int OC_BIG_INTERVAL = 12;
constexpr int OC_MID_INTERVAL = 8;
constexpr int OC_SMA_INTERVAL = 4;
#else
constexpr int OC_BIG_INTERVAL = 4;
constexpr int OC_MID_INTERVAL = 4;
constexpr int OC_SMA_INTERVAL = 4;
#endif

constexpr int OW_INTERVAL = 8;
constexpr int SH = stride;

const int dst_numbers_per_channel = oh * ow;
const int ow_remain = ow % OW_INTERVAL;
const int ow_end_idx = ow - ow_remain;
const int oc_remain =
oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8
const int oc_end_idx = oc - oc_remain;
const int dst_numbers_4channel_packed =
dst_numbers_per_channel * OC_PACK_SIZE;

using remain_fun = std::function<void(
dst_type * dst, const int dst_step, const int8_t* src, const int ih,
const int iw, const int8_t* filter, const int32_t* bias,
const int ic, const Op& op)>;

remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_mid_oc_remain = nullptr;
remain_fun kern_sma_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
//! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
//! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates
//! [OW_INTERVAL, 1, OC_INTERVAL] each time
for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
kern_big_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}

#ifdef MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if (oc_remain) {
int oc_idx = oc_end_idx;
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_MID_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
} else {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_SMA_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
}
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
kern_mid_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
} else {
kern_sma_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}
}
#endif
}

#define CONSTRUCT_FUNC(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op) { \
conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, \
filter_size>( \
dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \
}

CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC

#define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \
template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \
stride>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(1)
FOR_FILTER(2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
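//! Illustration only (not from this commit): FOR_FILTER(1) and FOR_FILTER(2)
//! expand to 2 strides x 4 filter sizes x 2 bias modes x 4 ops = 64 explicit
//! instantiations. One of them, written out:
//!
//!   template void conv_direct_2x2_int8_nchw44<
//!           dt_int8, BiasMode::NO_BIAS, TypeCvtOp<dt_qint32, dt_qint8>, 1>(
//!           dt_int8* dst, const int oh, const int ow, const int8_t* src,
//!           const int ih, const int iw, const int8_t* weight,
//!           const int32_t* bias, const int oh_size, const int oc,
//!           const int ic, const TypeCvtOp<dt_qint32, dt_qint8>& op);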

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif


// vim: syntax=cpp.doxygen

+10 -16 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h

@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */


#if __ARM_FEATURE_DOTPROD
@@ -42,20 +43,13 @@ using BiasMode = ConvBiasForward::BiasMode;
 * @return none
 */


#define KERN(filter_size) \
template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \
void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \
dst_type* dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op)

KERN(2);
KERN(3);
KERN(5);
KERN(7);

#undef KERN
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op);
/**
 * @brief : copy data from src to dst for direct conv with no side effect
 * @param : [output ptr] dst
@@ -84,4 +78,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step,


#endif


// vim: syntax=cpp.doxygen

+5 -9 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp

@@ -148,14 +148,10 @@ static void conv_kern(const WorkspaceBundle& bundle,
float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}

#define KERN1_NCHW44_CONV(filter) \
direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \
dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \
ih_real_size, iw2, weights, bias, \
oh_real_size, OC, IC, op);
DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44<
dst_type, stride, bias_mode, Op, filter_size>(
dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias,
oh_real_size, OC, IC, op);
}


} // namespace
@@ -342,4 +338,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns(


#endif


// vim: syntax=cpp.doxygen

+0 -435 dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h

@@ -1,435 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#ifdef __ARM_FEATURE_DOTPROD

#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {

constexpr int SIMD_LEN = 16;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;
constexpr int filter_next_col =
IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

template <int row, BiasMode bias_mode>
MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8],
const int32_t* bias_ptr, int oc_step) {
static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number.");
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
} else {
#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
}
}

#define cb11(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8));

#define cb21(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8));

#define cb31(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \
op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \
ld_dst_oc + col / 2 * 8));

#define cb12(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8));

#define cb22(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8));

#define cb32(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \
op({{res[2][2 * step], res[2][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8));

template <int row, int ow_remain, typename Op, typename T>
struct StoreOCxOWx {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc);
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> {

static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc);
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb12);
break;
case 7:
cb11(6);
case 6:
UNROLL_CALL_RAW(3, cb12);
break;
case 5:
cb11(4);
case 4:
UNROLL_CALL_RAW(2, cb12);
break;
case 3:
cb11(2);
case 2:
UNROLL_CALL_RAW(1, cb12);
break;
case 1:
cb11(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<2, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb22);
break;
case 7:
cb21(6);
case 6:
UNROLL_CALL_RAW(3, cb22);
break;
case 5:
cb21(4);
case 4:
UNROLL_CALL_RAW(2, cb22);
break;
case 3:
cb21(2);
case 2:
UNROLL_CALL_RAW(1, cb22);
break;
case 1:
cb21(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<3, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb32);
break;
case 7:
cb31(6);
case 6:
UNROLL_CALL_RAW(3, cb32);
break;
case 5:
cb31(4);
case 4:
UNROLL_CALL_RAW(2, cb32);
break;
case 3:
cb31(2);
case 2:
UNROLL_CALL_RAW(1, cb32);
break;
case 1:
cb31(0);
default:
break;
}
}
};

#undef cb11
#undef cb21
#undef cb31
#undef cb12
#undef cb22
#undef cb32

template <int row, int ow_remain, typename Op, typename T>
MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
const Op& op, T* dst_ptr,
const int ld_dst_oc) {
StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc);
}

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
#define cb(step) \
res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \
res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4]);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename FUNC, typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2,
T3>::impl(res, src, weight);
};

/**
* GEMM-like kernels: oc12_owx (m = 12, n = x), oc8_owx (m = 8, n = x)
* and oc4_owx (m = 4, n = x)
*/
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int ow_remain, int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44 {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op);
};

template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
constexpr int LOOP = oc_interval / 4;

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[1][4];
int8x16_t weight[3];

load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);

//! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
constexpr int LOOP = oc_interval / 4;

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[2][3];
int8x16_t weight[3];
const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;

load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);

//! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \
weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \
weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \
weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn

#endif

// vim: syntax=cpp.doxygen

+245 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h

@@ -0,0 +1,245 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#if __ARM_FEATURE_DOTPROD
#include "megdnn/arch.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/common/unroll_macro.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {

constexpr int SIMD_LEN = 16;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;
constexpr int filter_next_col =
IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

template <int row, BiasMode bias_mode>
MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8],
const int32_t* bias_ptr, int oc_step) {
static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number.");
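//! note: both switches below fall through on purpose, so row == 3 also
//! initializes rows 2, 1 and 0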
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
#define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
} else {
#define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0);
switch (row) {
case 3:
UNROLL_CALL_RAW(8, BIAS_INIT, 2);
case 2:
UNROLL_CALL_RAW(8, BIAS_INIT, 1);
default:
UNROLL_CALL_RAW(8, BIAS_INIT, 0);
}
#undef BIAS_INIT
}
}

#define cb11(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8));

#define cb21(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8));

#define cb31(col) \
op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \
op(res[1][col], \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \
op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \
ld_dst_oc + col / 2 * 8));

#define cb12(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8));

#define cb22(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8));

#define cb32(step) \
op({{res[0][2 * step], res[0][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \
op({{res[1][2 * step], res[1][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \
op({{res[2][2 * step], res[2][2 * step + 1]}}, \
reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8));

template <int row, int ow_remain, typename Op, typename T>
struct StoreOCxOWx {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc);
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<1, ow_remain, Op, T> {
static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr,
const int ld_dst_oc) {
MEGDNN_MARK_USED_VAR(ld_dst_oc);
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb12);
break;
case 7:
cb11(6);
case 6:
UNROLL_CALL_RAW(3, cb12);
break;
case 5:
cb11(4);
case 4:
UNROLL_CALL_RAW(2, cb12);
break;
case 3:
cb11(2);
case 2:
UNROLL_CALL_RAW(1, cb12);
break;
case 1:
cb11(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<2, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb22);
break;
case 7:
cb21(6);
case 6:
UNROLL_CALL_RAW(3, cb22);
break;
case 5:
cb21(4);
case 4:
UNROLL_CALL_RAW(2, cb22);
break;
case 3:
cb21(2);
case 2:
UNROLL_CALL_RAW(1, cb22);
break;
case 1:
cb21(0);
default:
break;
}
}
};

template <int ow_remain, typename Op, typename T>
struct StoreOCxOWx<3, ow_remain, Op, T> {
static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op,
T* dst_ptr, const int ld_dst_oc) {
switch (ow_remain) {
case 8:
UNROLL_CALL_RAW(4, cb32);
break;
case 7:
cb31(6);
case 6:
UNROLL_CALL_RAW(3, cb32);
break;
case 5:
cb31(4);
case 4:
UNROLL_CALL_RAW(2, cb32);
break;
case 3:
cb31(2);
case 2:
UNROLL_CALL_RAW(1, cb32);
break;
case 1:
cb31(0);
default:
break;
}
}
};

#undef cb11
#undef cb21
#undef cb31
#undef cb12
#undef cb22
#undef cb32

template <int row, int ow_remain, typename Op, typename T>
MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
const Op& op, T* dst_ptr,
const int ld_dst_oc) {
StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc);
}

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename T, typename T2, typename T3>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
#define cb(step) \
res[res_row][step] = \
vdotq_laneq_s32(res[res_row][step], weight[weight_idx], \
src[src_row][(src_start_idx + step) / 4], \
(src_start_idx + step) % 4);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int res_row, int src_row, int src_start_idx, int weight_idx,
typename T, typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2,
T3>::impl(res, src, weight);
};

/**
* GEMM-like kernels: oc12_owx (m = 12, n = x), oc8_owx (m = 8, n = x)
* and oc4_owx (m = 4, n = x)
*/
template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int ow_remain, int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44 {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op);
};

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
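Aside (illustration only, not from this commit): every kernel above bottoms out
in vdotq_laneq_s32, which multiplies sixteen int8 pairs and accumulates into
four int32 lanes. A minimal self-contained sketch of one accumulation step,
assuming a toolchain with __ARM_FEATURE_DOTPROD; the helper name dot_4oc_4ic is
invented:

#include <arm_neon.h>
//! one KernNeonSdotNCHW44 step: weight16 holds the [4OC][4IC] int8 block of a
//! single filter tap, src16 holds at least 4 packed source bytes, and lane 0
//! selects the 4 input channels of the current pixel
static inline int32x4_t dot_4oc_4ic(int32x4_t acc, const int8_t* weight16,
                                    const int8_t* src16) {
    int8x16_t w = vld1q_s8(weight16);
    int8x16_t s = vld1q_s8(src16);
    //! int32 lane i of acc += dot(w[4*i .. 4*i+3], s[0 .. 3])
    return vdotq_laneq_s32(acc, w, s, 0);
}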

+320 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp

@@ -0,0 +1,320 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
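//! NSRC: number of 16-byte source vectors that cover the
//! ow_interval + filter_size - 1 input pixels needed per row (4 packed bytes
//! per pixel)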
constexpr int LOOP = oc_interval / 4;

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[1][4];
int8x16_t weight[3];

load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);

//! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, 0, step, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, 0, step, 1>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, 0, step, 2>(res, src, weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
constexpr int OC_BIG_INTERVAL = 12;
constexpr int OC_MID_INTERVAL = 8;
constexpr int OC_SMA_INTERVAL = 4;
#else
constexpr int OC_BIG_INTERVAL = 4;
constexpr int OC_MID_INTERVAL = 4;
constexpr int OC_SMA_INTERVAL = 4;
#endif

constexpr int OW_INTERVAL = 8;
constexpr int SH = stride;

const int dst_numbers_per_channel = oh * ow;
const int ow_remain = ow % OW_INTERVAL;
const int ow_end_idx = ow - ow_remain;
const int oc_remain =
oc % OC_BIG_INTERVAL;  //! in NCHW44 oc is a multiple of 4, so oc_remain is 0, 4 or 8
const int oc_end_idx = oc - oc_remain;
const int dst_numbers_4channel_packed =
dst_numbers_per_channel * OC_PACK_SIZE;

using remain_fun = std::function<void(
dst_type * dst, const int dst_step, const int8_t* src, const int ih,
const int iw, const int8_t* filter, const int32_t* bias,
const int ic, const Op& op)>;

remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_mid_oc_remain = nullptr;
remain_fun kern_sma_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
//! [oc, oh, ow] is cut into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
//! oh, OC_INTERVAL] blocks; each KernNeonSdotNCHW44 call computes one
//! [OW_INTERVAL, 1, OC_INTERVAL] block
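//! example: with oc = 20 and ow = 19 on aarch64, oc_end_idx = 12 and
//! oc_remain = 8 (the tail takes the OC_MID_INTERVAL kernel), while
//! ow_end_idx = 16 and ow_remain = 3 (the tail goes through
//! kern_*_oc_remain)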
for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
kern_big_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}

#ifdef MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if (oc_remain) {
int oc_idx = oc_end_idx;
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_MID_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
} else {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_SMA_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
}
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
kern_mid_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
} else {
kern_sma_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}
}
#endif
}

#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \
template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, \
Op, filter_size>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(1)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+322 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp

@@ -0,0 +1,322 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD

#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
oc_interval, ow_interval> {
static void impl(dst_type* dst, const int dst_step, const int8_t* src,
const int ih, const int iw, const int8_t* filter,
const int32_t* bias, const int ic, const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int filter_next_row =
FW * OC_PACK_SIZE *
IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]

const int filter_next_4oc =
FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
const int src_next_ic = ih * iw;
const int src_next_row = iw * IC_PACK_SIZE;

constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
constexpr int LOOP = oc_interval / 4;

int32x4_t res[3][ow_interval];
init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);

for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
const int8_t* i_src = src + ic_idx * src_next_ic;
const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
int8x16_t src[2][3];
int8x16_t weight[3];
const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;
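//! note (assumption): the packed stride-2 layout splits each input row into
//! its even and odd columns; offset is the distance between the two halves,
//! so src[0] sees the even pixels and src[1] the odd ones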

load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);

//! do not reorder the switch cases to 3,2,1: it slows the kernel down.
#define CALC_PART(step) \
switch (LOOP) { \
case 1: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
break; \
case 2: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \
break; \
case 3: \
weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
filter_next_col * step); \
cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \
weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
filter_next_col * step); \
cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \
weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
filter_next_col * step); \
cal_helper<2, step % 2, step / 2, 2>(res, src, weight); \
break; \
default: \
break; \
}

switch (filter_size) {
case 2:
UNROLL_CALL_RAW(2, CALC_PART);
break;
case 3:
UNROLL_CALL_RAW(3, CALC_PART);
break;
case 5:
UNROLL_CALL_RAW(5, CALC_PART);
break;
case 7:
UNROLL_CALL_RAW(7, CALC_PART);
break;
default:
break;
}
#undef CALC_PART

i_filter += filter_next_row;
i_src += src_next_row;
}
}
store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
dst_step);
}
};

template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
int filter_size>
void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
const int8_t* src, const int ih, const int iw,
const int8_t* filter, const int32_t* bias,
const int oh_size, const int oc, const int ic,
const Op& op) {
constexpr int FH = filter_size;
constexpr int FW = filter_size;
constexpr int IC_PACK_SIZE = 4;
constexpr int OC_PACK_SIZE = 4;

#if MEGDNN_AARCH64
constexpr int OC_BIG_INTERVAL = 12;
constexpr int OC_MID_INTERVAL = 8;
constexpr int OC_SMA_INTERVAL = 4;
#else
constexpr int OC_BIG_INTERVAL = 4;
constexpr int OC_MID_INTERVAL = 4;
constexpr int OC_SMA_INTERVAL = 4;
#endif

constexpr int OW_INTERVAL = 8;
constexpr int SH = stride;

const int dst_numbers_per_channel = oh * ow;
const int ow_remain = ow % OW_INTERVAL;
const int ow_end_idx = ow - ow_remain;
const int oc_remain =
oc % OC_BIG_INTERVAL;  //! in NCHW44 oc is a multiple of 4, so oc_remain is 0, 4 or 8
const int oc_end_idx = oc - oc_remain;
const int dst_numbers_4channel_packed =
dst_numbers_per_channel * OC_PACK_SIZE;

using remain_fun = std::function<void(
dst_type * dst, const int dst_step, const int8_t* src, const int ih,
const int iw, const int8_t* filter, const int32_t* bias,
const int ic, const Op& op)>;

remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_mid_oc_remain = nullptr;
remain_fun kern_sma_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_BIG_INTERVAL, \
OW_INTERVAL>::impl; \
kern_mid_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_MID_INTERVAL, \
OW_INTERVAL>::impl; \
kern_sma_oc_remain = \
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \
filter_size, OC_SMA_INTERVAL, \
OW_INTERVAL>::impl; \
break;
UNROLL_CALL_RAW(8, cb);
#undef cb
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

//! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC]
//! [oc, oh, ow] is cut into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL,
//! oh, OC_INTERVAL] blocks; each KernNeonSdotNCHW44 call computes one
//! [OW_INTERVAL, 1, OC_INTERVAL] block
for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) {
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_BIG_INTERVAL, OW_INTERVAL>::
impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
kern_big_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}

#ifdef MEGDNN_AARCH64
//! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32
if (oc_remain) {
int oc_idx = oc_end_idx;
const int filter_offset_in_element = oc_idx * ic * FH * FW;
for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) {
for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_MID_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
} else {
KernNeonSdotNCHW44<
dst_type, stride, bias_mode, Op, OW_INTERVAL,
filter_size, OC_SMA_INTERVAL,
OW_INTERVAL>::impl(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih,
iw,
filter +
filter_offset_in_element,
bias + bias_offset_in_element,
ic, op);
}
}
if (ow_remain) {
const int src_offset_in_element =
(oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE;
const int dst_offset_in_element =
oc_idx * dst_numbers_per_channel +
(oh_idx * ow + ow_end_idx) * OC_PACK_SIZE;
const int bias_offset_in_element = oc_idx;
if (oc_remain == 8) {
kern_mid_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
} else {
kern_sma_oc_remain(dst + dst_offset_in_element,
dst_numbers_4channel_packed,
src + src_offset_in_element, ih, iw,
filter + filter_offset_in_element,
bias + bias_offset_in_element, ic, op);
}
}
}
}
#endif
}

#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \
template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, \
Op, filter_size>( \
dst_type * dst, const int oh, const int ow, const int8_t* src, \
const int ih, const int iw, const int8_t* weight, \
const int32_t* bias, const int oh_size, const int oc, \
const int ic, const Op& op);

#define FOR_OP(stride, i, bias_mode) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int32, stride, i, bias_mode, \
NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(dt_int8, stride, i, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION

} // namespace direct_dotprod_nchw44
} // namespace arm_common
} // namespace megdnn

#endif
// vim: syntax=cpp.doxygen

+448 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp

@@ -0,0 +1,448 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace dot_direct_nchw_nchw44 {

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
c[1][step] = Func::template impl<(src_idx + step) % 4>( \
c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};
////////////////////stride 1///////////////////
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 3;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];

#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <>
void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
const int8_t* sptr_origin, const int,
const int pw, const int, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
2, 3, 4, 5, 3, 4, 5, 6};
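//! each group of 4 indices is a 4-byte sliding window: packed pixel i holds
//! src[i .. i + 3], which realizes the pack_iw_len == 4 layout for stride 1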
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);

constexpr int iw_step = 16;
constexpr int pack_iw_len = 4;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
*(sptr_base + iw_idx * pack_iw_len + 0) =
*(temp_ptr + iw_idx + 0);
*(sptr_base + iw_idx * pack_iw_len + 1) =
*(temp_ptr + iw_idx + 1);
*(sptr_base + iw_idx * pack_iw_len + 2) =
*(temp_ptr + iw_idx + 2);
*(sptr_base + iw_idx * pack_iw_len + 3) =
*(temp_ptr + iw_idx + 3);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
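//! round the filter width up to a multiple of 4 (2,3 -> 4; 5,7 -> 8) to match
//! the padded weight layout the dot kernels consume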
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}

for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template void \
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter_size, stride>( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const int oc, const int ic, \
const int ih, const int iw, const int oh, const int oh_block, \
const int ow, const Op& op);

#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(1);

} // namespace dot_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen
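Note: DISPATCH_CONV_KERN(1) above emits 4 filter sizes x 2 bias modes x 3 ops
= 24 explicit instantiations of conv_direct_int8_nchw_nchw44_dot; the stride-2
set lives in the sibling file below, so each translation unit compiles only
half of the kernels.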

+437 -0 dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp

@@ -0,0 +1,437 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#if __ARM_FEATURE_DOTPROD
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace dot_direct_nchw_nchw44 {

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2], weight[1][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]); \
c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2 + 1], weight[1][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};
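
// Note on the two specializations above: for ow_block == 8 with stride 2, the
// eight outputs are split into an even half and an odd half. c[oc][step * 2]
// always consumes src[0] and c[oc][step * 2 + 1] always consumes src[1], with
// the dot lane picked by (src_idx + step) % 4, so the even/odd interleaving
// of the strided input is resolved entirely at compile time.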

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 2 * iw, stride);
load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

/**
 * oc = 8, ow = 8
 * dot 4 elements at a time; the last filter column is zero-padded to width 8,
 * so every filter row is covered by exactly two dot products, laid out as
 * below
 * --------------------------
 * |x, x, x, x,| x, x, x, 0 |
 * --------------------------
 **/
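// A worked sketch of the padding above: with fw = 7 padded to
// filter_width = 8, each output element consumes one filter row as
//     out += dot4(w[0..3], s[0..3]) + dot4(w[4..7], s[4..7])
// where w[7] == 0, so the zero lane contributes nothing and every row is
// covered by exactly two 4-element dot products (the cal_helper<0, 0> and
// cal_helper<1, 1> calls below) instead of a scalar tail.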
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
2> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 2;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <>
void pack_src_int8_nchw_nchw44_dot<2>(
int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw,
const int, const int ih, const int iw, const int iw2, const int pad_top,
const int pad_bottom, const int ic, const int ic_stride, int8_t*) {
constexpr int ic_step = 1;
rep_step(ic_idx, ic, ic_step) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom));
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memcpy(sptr_base + pw * ic_step, sptr,
sizeof(int8_t) * iw * ic_step);
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
sptr_base += iw2 * pad_bottom * ic_step;
}
}
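
// A layout sketch of the packing above, under the assumption
// iw2 = iw + 2 * pw (values a..f are an illustrative 2x3 input with pw = 1,
// pad_top = pad_bottom = 1):
//     0 0 0 0 0      <- pad_top row (from the memset)
//     0 a b c 0      <- src row 0, left/right padding baked in at offset pw
//     0 d e f 0      <- src row 1
//     0 0 0 0 0      <- pad_bottom row
// so the stride-2 kernels can read whole rows without any border branching.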

template <BiasMode bias_mode, typename Op, int filter_size, int stride>
void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
        default:
            megdnn_assert(0, "no remain %d for kern", ow_remain);
    }
#undef cb
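// ow_remain is in [0, 7], so UNROLL_CALL_RAW(8, cb) emits cases 0..7; each
// case binds the compile-time remain_w parameter of the tail kernels to the
// runtime remainder once, up front, keeping the hot loops below free of
// per-pixel branching.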

for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}

#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template void \
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter_size, stride>( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const int oc, const int ic, \
const int ih, const int iw, const int oh, const int oh_block, \
const int ow, const Op& op);

#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(2);

} // namespace dot_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn

#endif
// vim: syntax=cpp.doxygen

+ 743
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp

@@ -0,0 +1,743 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace {
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;

int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];

init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);

c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);

c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);

c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[1][8];
int8x16_t weight[1][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);

c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]);

c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]);

c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
struct KerNeonDirectStride1Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc);
};
/**
dot-like impl: dot 4 ic into 1 oc, accumulating into c<ow, oc>
example (format: weight<oc, ic>)
packed weight
low 64 bit  <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
//! TODO: can try oh = 2 impl, oc = 8 impl
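// A scalar sketch of the accumulation used below, assuming vdotq_s32_h (the
// emulated dot helper from marm_neon.h) widens int8 products through an
// int16x8 temp and pairwise-accumulates into int32x4: for each int32 lane i,
//     c[i] += w[2i] * s[2i] + w[2i + 1] * s[2i + 1]          // low 64 bits
//           + w[2i + 8] * s[2i + 8] + w[2i + 9] * s[2i + 9]; // high 64 bits
// With the packed-weight order shown above, lane 0 thus covers all 4 input
// channels of oc 0, matching (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>.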
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);

c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);

c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);

c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));

c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));

c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]);

src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));

c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, typename DstType>
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
DstType* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_oc = oh * ow * oc_step;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, step, \
filter_size, 2, DstType>; \
kern_small_oc_remain = \
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, step, \
filter_size, 1, DstType>; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb
for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size,
2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_oc, op);
}
}
}
if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size,
1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_oc, op);
}
}
}
}
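// The loop nest above is the blocking scheme shared by these kernels: output
// channels advance in big_oc_step (8) blocks with a 4-channel tail, output
// width advances in ow_step (8) blocks with the precomputed
// kern_big_oc_remain / kern_small_oc_remain handling the tail columns, so
// every inner call works on a fully unrolled oc x ow register tile.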
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src,
const int8_t* filter,
const int32_t* bias, int32_t* temp,
DstType* dst, const size_t oc,
const size_t ic, const size_t ih,
const size_t iw, const size_t oh,
const size_t ow, const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const int ld_dst_oc = oh * ow * oc_step;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
const Op& op, int ld_dst_oc)>;

remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_small_oc_remain = \
KerNeonDirectStride1Int8<bias_mode, Op, step, filter_size, 1, \
DstType>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb

for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride1Int8<bias_mode, Op, ow_step, filter_size, 1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * iw + ow_end) * ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, op, ld_dst_oc);
}
}
}
}
} // namespace

namespace int8_direct_nchw44 {
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, DstType, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride1_int8_nchw44_kern<bias_mode, Op, filter_size,
DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};

template <BiasMode bias_mode, typename Op, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, 2, DstType, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride1_2x2_int8_nchw44<bias_mode, Op, DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};

#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, \
DstType, stride>;

#define GET_OP_PARAM(stride, filter, bias_mode)                      \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,            \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>)     \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,            \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)        \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode,            \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)      \
    DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp<dt_int32>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(1);

} // namespace int8_direct_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 778
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp

@@ -0,0 +1,778 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct.h"
#include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace {
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
struct KerNeonDirectStride2Int8 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc);
};

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int ic_step = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc4 = oc_step * fh * fw * ic;

int32x4_t c[2][8];
int8x16_t weight[2][2];
int8x16_t src[8 + 1];
int16x8_t temp_c[4];

init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0][0] = vld1q_s8(read_weight_ptr);
weight[0][1] = vld1q_s8(read_weight_ptr + 16);
weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);

c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);

c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);

src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);

src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
const int8_t* weight_ptr,
const int32_t* bias_ptr,
DstType* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc,
const Op& op) {
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[2];
int8x16_t src[8 + 1];
int16x8_t temp_c[2];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]);

src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
/**
dot-like impl: dot 4 ic into 1 oc, accumulating into c<ow, oc>
example (format: weight<oc, ic>)
packed weight
low 64 bit  <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
// TODO: can try oh = 2 impl, oc = 8 impl
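// Same vdotq_s32_h accumulation scheme as in int8_direct_nchw44_s1.cpp, but
// with stride 2 output column ow pairs weight[fw_idx] with the packed vector
// src[2 * ow + fw_idx], which is why one weight vector walks src[0], src[2],
// src[4], ... in the bodies below.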
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 3;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[3];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);

c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);

src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 5;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[5];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8((src_ic_0_3 + 16));
src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);

src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);

src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
int iw, const Op& op, int ld_dst_oc) {
constexpr int filter_size = 7;
constexpr int fh = filter_size;
constexpr int fw = filter_size;
constexpr int oc_step = 4;
constexpr int ic_step = 4;
constexpr int loop_ic_step = 4;
constexpr int ld_weight_ic4 = 16;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;

int32x4_t c[c_dim][8];
int8x16_t weight[7];
int8x16_t src[8 + 2];
int16x8_t temp_c[4];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;

src[0] = vld1q_s8(src_ic_0_3);
src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);

// oc == 0
const int8_t* read_weight_ptr =
weight_ptr + fh_idx * fw * ld_weight_ic4;

weight[0] = vld1q_s8(read_weight_ptr);
weight[1] = vld1q_s8(read_weight_ptr + 16);
weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);

c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]);
c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]);
c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]);

src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]);
c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]);
c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]);
c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]);

src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]);
c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]);
c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]);
c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]);

src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]);
c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]);
c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]);
c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]);
}
weight_ptr += fh * fw * ld_weight_ic4;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
c, op, dst_ptr, ld_dst_oc);
}
};
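
// Driver for the 2x2 stride-2 case: one output row at a time, 8 output
// columns per kernel call, oc in blocks of 8 with a 4-channel tail; the
// ow % 8 remainder is dispatched once to a remain-w kernel via the
// UNROLL_CALL_RAW switch below.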

template <BiasMode bias_mode, typename Op, typename DstType>
void conv_direct_stride2_2x2_int8_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
DstType* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow, const Op& op) {
constexpr size_t filter_size = 2;
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t big_oc_step = 8;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;

const size_t out_img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oh * ow * oc_step;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;

switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, step, \
filter_size, 2, DstType>; \
kern_small_oc_remain = \
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, step, \
filter_size, 1, DstType>; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb

for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, ow_step,
filter_size, 2, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}

if (oc_remain > 0) {
const size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_idx) * oc_step;
ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, ow_step,
filter_size, 1, DstType>(
src + src_offset, filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset = oc_idx * out_img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}

template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
void conv_direct_stride2_int8_nchw44_kern(
const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
DstType* dst, const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow, const Op& op) {
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 4;
constexpr size_t oc_step = 4;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const int ld_dst_oc = oh * ow * oc_step;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
const Op& op, int ld_dst_oc)>;

remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_small_oc_remain = \
KerNeonDirectStride2Int8<bias_mode, Op, step, filter_size, 1, \
DstType>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}
#undef cb

for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDirectStride2Int8<bias_mode, Op, ow_step, filter_size, 1,
DstType>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, op, ld_dst_oc);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, op, ld_dst_oc);
}
}
}
}
} // namespace

namespace int8_direct_nchw44 {
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, DstType, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride2_int8_nchw44_kern<bias_mode, Op, filter_size,
DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};

template <BiasMode bias_mode, typename Op, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, 2, DstType, 2> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, DstType* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
conv_direct_stride2_2x2_int8_nchw44<bias_mode, Op, DstType>(
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
}
};
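
// Explicit instantiations pin every (bias, op, filter) combination of the
// stride-2 kernels into this translation unit, so this .cpp carries only the
// stride-2 slice and per-file compile time stays down.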

#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \
template struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, \
DstType, stride>;

#define GET_OP_PARAM(stride, filter, bias_mode) \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
\
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
\
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
\
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp<dt_int32>)

#define GET_BIAS_MODE_PARAM(stride, filter) \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \
GET_BIAS_MODE_PARAM(stride, 7)

DISPATCH_CONV_KERN(2);

} // namespace int8_direct_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 47
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h

@@ -0,0 +1,47 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace {

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block, int stride>
struct KerNeonXXs2NchwNchw44 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op);
};
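
// OCHelper maps an oc block size to the number of accumulator rows (c_dim):
// a 4-channel block needs one row of int32x4 accumulators, an 8-channel
// block needs two.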

template <int oc>
struct OCHelper {
public:
static const int val = 0;
};
template <>
struct OCHelper<4> {
public:
static const int val = 1;
};
template <>
struct OCHelper<8> {
public:
static const int val = 2;
};

} // namespace
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 561
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp

@@ -0,0 +1,561 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
namespace megdnn {
namespace arm_common {
namespace {
/**
* @brief core code for the calculation pattern
*
* @tparam src_idx is the offset into the src regs
* @tparam weight_idx is the offset into the weight regs
* @tparam c_dim is the number of output channel blocks
* @tparam Func mla operation function
* @tparam stride
* @tparam T output regs type
* @tparam T2 src regs type
* @tparam T3 weight regs type
* @tparam T4 temp regs type
*/

template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3, typename T4>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl(
c, src, weight, temp);
}
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
typename T2, typename T3>
MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl(
c, src, weight);
}
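// cal_helper<src_idx, weight_idx, c_dim, stride> expands to one vdotq_s32_h
// per (oc block, ow) pair; the (i + src_idx) % 8 indexing lets callers rotate
// the 8-register src window instead of reloading it for each filter row.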
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]);
c[1][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[1][weight_idx],
c[1][0], temp[1]);
c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx],
c[0][1], temp[2]);
c[1][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[1][weight_idx],
c[1][1], temp[3]);
c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx],
c[0][2], temp[0]);
c[1][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[1][weight_idx],
c[1][2], temp[1]);
c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx],
c[0][3], temp[2]);
c[1][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[1][weight_idx],
c[1][3], temp[3]);

c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx],
c[0][4], temp[0]);
c[1][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[1][weight_idx],
c[1][4], temp[1]);
c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx],
c[0][5], temp[2]);
c[1][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[1][weight_idx],
c[1][5], temp[3]);
c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx],
c[0][6], temp[0]);
c[1][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[1][weight_idx],
c[1][6], temp[1]);
c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[2]);
c[1][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[1][weight_idx],
c[1][7], temp[3]);
}
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx],
c[0][0], temp[0]);
c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx],
c[0][1], temp[1]);
c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx],
c[0][2], temp[2]);
c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx],
c[0][3], temp[3]);
c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx],
c[0][4], temp[0]);
c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx],
c[0][5], temp[1]);
c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx],
c[0][6], temp[2]);
c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx],
c[0][7], temp[3]);
}
static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 2;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;
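
// src was pre-packed with pack_iw_len = 16, so src[x] already carries the
// four input pixels output column x needs, and a single 16-byte weight
// vector covers a full 4-wide filter row for the 4 oc lanes.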

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 1 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

weight_ptr += oc_step * filter_height * filter_width;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 3;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);

load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 1 * filter_width * oc_step,
ld_weight_oc);

load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr + 2 * filter_width * oc_step,
ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 5;
constexpr int filter_width = 8;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 2;
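
// 5-wide filter rows are zero-padded to filter_width = 8 by the weight
// packing and consumed as two 4-wide dot groups per row (weight_reg = 2,
// cal_helper<0, 0> then cal_helper<4, 1>).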

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);
UNROLL_CALL_RAW(5, cb);
#undef cb
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 7;
constexpr int filter_width = 8;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 2;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
#define cb(step) \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
dot4_weight, weight_ptr + step * filter_width * oc_step, \
ld_weight_oc); \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, nchw_src_ptr + step * iw * pack_iw_len, 0); \
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \
load_helper<4, 0, simd_len, 0, Vld1q_s8>( \
src, \
nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
0); \
cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);

UNROLL_CALL_RAW(7, cb);
#undef cb
weight_ptr += oc_step * filter_height * filter_width;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
} // namespace

namespace int8_direct_nchw_nchw44 {
/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh, fw/4, 4(oc)*4(fw)}
* packing interleaves two adjacent elements of each filter row into one row
*/
template <>
void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr,
const int ic, const int fh,
const int fw, const int oc) {
constexpr int oc_step = 4;
const int fw2 = round_up(fw, 4);
const int fw_remain = fw2 - fw;
const int dst_ic_stride = fh * fw2;
const int oc_step_stride = fh * fw2 * ic * oc_step;
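// viewed as 4 fw positions x 4 oc bytes per 16-byte block, this table
// interleaves the oc bytes of two adjacent fw positions:
// 0,4,1,5,... -> (fw0.oc0, fw1.oc0, fw0.oc1, fw1.oc1, ...)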
static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7,
8, 12, 9, 13, 10, 14, 11, 15};
uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]);
rep_step(oc_idx, oc, oc_step) {
int32_t* dst_temp_ptr =
reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2);
const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>(
src_ptr + oc_idx * ic * fh * fw);
// transpose ic and pad
rep(fh_idx, fh) {
rep(fw_idx, fw) {
rep(ic_idx, ic) {
*(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr;
src_temp_ptr++;
}
dst_temp_ptr++;
}
rep(ic_idx, ic) {
memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0,
sizeof(int8_t) * oc_step * fw_remain);
}
dst_temp_ptr += fw_remain;
}
// transpose fw oc
int8_t* trans_dst_temp_ptr =
reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2);

rep_step(idx, oc_step_stride, 16) {
int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx);
vst1q_s8(trans_dst_temp_ptr + idx,
vqtbl1q_s8(temp, tbl_transpose_4x4));
}
}
}

/**
* pack (ic, h, w) to (ic, h, w * 16)
* packing interleaves two adjacent elements of each src row and repeats the
* pair 4 times, storing the result into one row
*/
template <>
void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin,
int8_t* sptr_base, const int ic,
const int pad_top, const int pad_bottom,
const int, const int, const int ih,
const int iw, const int iw2, const int pw,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1,
2, 3, 2, 3, 2, 3, 2, 3};
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
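// with this table a 16-byte load of pixels p0..p15 is reshuffled to
// (p0,p1)x4, (p2,p3)x4: each adjacent pixel pair is broadcast four times to
// match the pairwise layout the vdotq_s32_h kernels consume.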

constexpr int iw_step = 4;
constexpr int pack_iw_len = 16;
const int ic_stride = ih * iw;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 1);
src[2] = vld1q_s8(temp_ptr + iw_idx + 2);
src[3] = vld1q_s8(temp_ptr + iw_idx + 3);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
int8x16_t src = vld1q_s8(temp_ptr + iw_idx);
int8x16_t dst = vqtbl1q_s8(src, tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}

template <BiasMode bias_mode, typename Op, size_t filter_size>
struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
static void impl(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp, int8_t* dst,
const size_t oc, const size_t ic, const size_t ih,
const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int stride = 1;
constexpr size_t fh = filter_size;
constexpr size_t fw = (filter_size + 3) / 4 * 4;
constexpr size_t ic_step = 1;
constexpr size_t big_oc_step = 8;
constexpr size_t oc_step = 4;
constexpr size_t ih_step = 1;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 8;
constexpr size_t stride_h = stride;
constexpr size_t stride_w = stride;
constexpr int pack_iw_len = 16;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun = std::function<void(
const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step, stride>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}

for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;

KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size,
big_oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}

if (oc_remain > 0) {
size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size,
oc_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset = (oh_idx * stride_h * iw +
ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset = oc_idx * img_stride +
(oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset,
filter + weight_offset, bias + oc_idx,
dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
}
};
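
// Same instantiation scheme as the stride-2 file: pinning every (op, bias,
// filter) combination of the stride-1 nchw_nchw44 kernels here keeps the
// heavy templates in their own translation unit to cut compile time.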

#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \
template struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, \
stride>;

#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \
INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define INSTANCE_BIAS_MODE_PARAM(stride, filter) \
INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \
INSTANCE_BIAS_MODE_PARAM(stride, 7)

INSTANCE_CONV_KERN(1);

} // namespace int8_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 1412
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp
File diff suppressed because it is too large


+ 26
- 57
dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp

@@ -114,7 +114,7 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
rep(ih_idx, IH) {
std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
sptr_base += nr_pad_w;
nchw44_pack_src(sptr, sptr_base, IW);
int8_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW);
sptr_base += IW * pack_ic * expend_element;
sptr += IW * pack_ic;
std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t));
@@ -125,8 +125,8 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
}
}


template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain,
typename DstType, int stride>
template <size_t filter, BiasMode bias_mode, typename Op, typename DstType,
int stride>
static void do_conv_kern(const WorkspaceBundle& bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -182,8 +182,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW);
conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>(
int8_direct_nchw44::nchw44_pack_filter(fptr, packed_weight,
oc_block / 4 * IC / 4 * FH * FW);
int8_direct_nchw44::conv_direct_int8_nchw44<bias_mode, Op, filter, DstType,
stride>(
sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst),
oc_block, IC, IH2, IW2, OH, OW, op);
}
@@ -233,40 +235,38 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
size_t N = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t OW = param.osz[1];
size_t group = fm.group;
size_t fh = fm.spatial[0];
size_t fw = fm.spatial[1];
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
int ow_remain = OW % 8;
bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \
#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \
midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \
stride>; \
do_conv_fun = do_conv_kern<filter, bias_mode, op, dst_type, stride>; \
} \
MIDOUT_END();


#define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \
#define GET_OP_PARAM(stride, filter, bias_mode) \
if (need_post_process) { \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
remain_w, \
\
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
@@ -277,7 +277,7 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \
remain_w, NoneOp<dt_int32>) \
NoneOp<dt_int32>) \
break; \
default: \
megdnn_assert( \
@@ -287,48 +287,17 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns(
} \
}


#define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \
switch (ow_remain) { \
case 0: \
GET_OP_PARAM(stride, filter, bias_mode, 0); \
break; \
case 1: \
GET_OP_PARAM(stride, filter, bias_mode, 1); \
break; \
case 2: \
GET_OP_PARAM(stride, filter, bias_mode, 2); \
break; \
case 3: \
GET_OP_PARAM(stride, filter, bias_mode, 3); \
break; \
case 4: \
GET_OP_PARAM(stride, filter, bias_mode, 4); \
break; \
case 5: \
GET_OP_PARAM(stride, filter, bias_mode, 5); \
break; \
case 6: \
GET_OP_PARAM(stride, filter, bias_mode, 6); \
break; \
case 7: \
GET_OP_PARAM(stride, filter, bias_mode, 7); \
break; \
default: \
megdnn_assert(0); \
}

#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_REMAIN_W_PARAM(stride, filter, \
BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}


#define DISPATCH_CONV_KERN(stride) \


+ 7
- 1337
dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h
File diff suppressed because it is too large


+ 10
- 9
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp

@@ -117,11 +117,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle,
const size_t tmp_size = get_temp_bytes(iw, pw);
int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, tmp_ptr);
int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<1>(
sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, tmp_ptr);
} else {
pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, nullptr);
int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<2>(
sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, nullptr);
}
}
static void pack_weight(const WorkspaceBundle& bundle,
@@ -142,11 +142,11 @@ static void pack_weight(const WorkspaceBundle& bundle,
group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2;


if (stride_h == 1) {
pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw,
oc_block);
int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<1>(
fptr, packed_weight, ic, fh, fw, oc_block);
} else {
pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw,
oc_block);
int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<2>(
fptr, packed_weight, ic, fh, fw, oc_block);
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
@@ -208,7 +208,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 +
oc_idx * ic * fh * fw2;
conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>(
int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44<bias_mode, Op, filter,
stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh,
ow, op);
}


+ 3
- 1854
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
File diff suppressed because it is too large


+ 5
- 4
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp

@@ -93,8 +93,8 @@ void do_weight_trans(const WorkspaceBundle& bundle,
const int fw2 = round_up(fw, 4);
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1));
auto origin_weight = kern_param.filter<dt_int8>();
pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh,
fw, fw2);
dot_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44_dot(
packed_weight, origin_weight, oc, ic, fh, fw, fw2);
}


template <size_t filter, BiasMode bias_mode, typename Op, int stride>
@@ -147,7 +147,7 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
}
pack_src_int8_nchw_nchw44_dot<stride>(
dot_direct_nchw_nchw44::pack_src_int8_nchw_nchw44_dot<stride>(
sptr, origin_sptr, ph, pw, remain_right_pad,
ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad,
src_bottom_pad, ic, ih * iw, tmp_ptr);
@@ -164,7 +164,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
Op op(scale_bias, scale_dst);
conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>(
dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot<bias_mode, Op,
filter, stride>(
sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh,
oh_block_real, ow, op);
}


+ 14
- 662
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h

@@ -20,83 +20,15 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"


using namespace megdnn;
using namespace arm_common;
namespace {
namespace megdnn {
namespace arm_common {
namespace dot_direct_nchw_nchw44 {
template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block,
int stride, typename T, typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight);
};


template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, stride, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2], weight[1][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]); \
c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[1][step * 2 + 1], weight[1][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, int stride, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, stride, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2], weight[0][weight_idx], \
src[0][(src_idx + step) / 4]); \
c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
c[0][step * 2 + 1], weight[0][weight_idx], \
src[1][(src_idx + step) / 4]);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
c[1][step] = Func::template impl<(src_idx + step) % 4>( \
c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight) {
#define cb(step) \
c[0][step] = Func::template impl<(src_idx + step) % 4>( \
c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);

UNROLL_CALL_RAW(8, cb);
#undef cb
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block,
int stride, typename T, typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
@@ -133,490 +65,12 @@ struct KerNeonDotXXs2Nchw44Int8 {
int iw, int ld_dst_oc, const Op& op);
};


template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 1;
constexpr int src_reg = 1;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 0 * iw, stride);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 1 * iw, stride);
load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
src, src_ptr + 2 * iw, stride);
load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

/**
 * oc = 8, ow = 8
 * each dot product covers 4 elements, so the last 4-wide chunk of every
 * filter row is zero-padded and each row is computed with two dot
 * operations; a 7-wide filter row is laid out as below
 * --------------------------
 * |x, x, x, x,| x, x, x, 0 |
 * --------------------------
 **/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block, int stride>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
stride> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 2;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 1;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[2][src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
stride); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
////////////////////stride 1///////////////////
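// Note on the stride-1 specializations below: pack_iw_len == 4 here because
// pack_src_int8_nchw_nchw44_dot<1> (defined further down) has already expanded
// every input pixel into the 4-pixel window starting at that pixel. Each
// filter row therefore needs only one bank of source registers
// (int8x16_t src[src_reg]) loaded contiguously, instead of the two banks
// (int8x16_t src[2][src_reg]) used by the stride-2 kernels above.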
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 2;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 3;
constexpr int filter_width = 4;
constexpr int weight_reg = 3;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 1
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 1 * iw * pack_iw_len, 0);
cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);
// row 2
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 2 * iw * pack_iw_len, 0);
cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 5;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(5, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 7;
constexpr int filter_width = 8;
constexpr int src_reg = 3;
constexpr int weight_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
#define cb(step) \
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \
src, src_ptr + step * iw * pack_iw_len, 0); \
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \
weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \
weight); \
cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);

UNROLL_CALL_RAW(7, cb);
#undef cb
src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};

template <int stride>
void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin,
const int, const int pw, const int,
const int ih, const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride, int8_t*) {
constexpr int ic_step = 1;
rep_step(ic_idx, ic, ic_step) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom));
sptr_base += iw2 * pad_top * ic_step;
rep(ih_idx, ih) {
memcpy(sptr_base + pw * ic_step, sptr,
sizeof(int8_t) * iw * ic_step);
sptr_base += iw2 * ic_step;
sptr += iw * ic_step;
}
sptr_base += iw2 * pad_bottom * ic_step;
}
}

template <>
void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
const int8_t* sptr_origin, const int,
const int pw, const int, const int ih,
const int iw, const int iw2,
const int pad_top, const int pad_bottom,
const int ic, const int ic_stride,
int8_t* temp_ptr) {
static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
2, 3, 4, 5, 3, 4, 5, 6};
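// reorder_idx encodes the four overlapping 4-byte windows [i, i+1, i+2, i+3]
// for i = 0..3; vqtbl1q_s8 applies it below so that each packed 4-byte group
// holds the window starting at that pixel (pack_iw_len == 4).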
uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);

constexpr int iw_step = 16;
constexpr int pack_iw_len = 4;
const int iw_with_pad = iw + 2 * pw;
const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
rep(ic_idx, ic) {
const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
memset(sptr_base, 0,
sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
pack_iw_len);
sptr_base += iw2 * pad_top * pack_iw_len;
rep(ih_idx, ih) {
memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
int8x16_t src[4];
int8x16_t dst[4];
src[0] = vld1q_s8(temp_ptr + iw_idx);
src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
dst[0] = vqtbl1q_s8(src[0], tbl_idx);
dst[1] = vqtbl1q_s8(src[1], tbl_idx);
dst[2] = vqtbl1q_s8(src[2], tbl_idx);
dst[3] = vqtbl1q_s8(src[3], tbl_idx);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
}
for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
*(sptr_base + iw_idx * pack_iw_len + 0) =
*(temp_ptr + iw_idx + 0);
*(sptr_base + iw_idx * pack_iw_len + 1) =
*(temp_ptr + iw_idx + 1);
*(sptr_base + iw_idx * pack_iw_len + 2) =
*(temp_ptr + iw_idx + 2);
*(sptr_base + iw_idx * pack_iw_len + 3) =
*(temp_ptr + iw_idx + 3);
}
sptr_base += iw2 * pack_iw_len;
sptr += iw;
}
sptr_base += iw2 * pad_bottom * pack_iw_len;
}
}
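// For orientation: ignoring the vqtbl1q_s8 fast path, the stride-1 packing
// above is equivalent to its own scalar tail loop, roughly (illustrative
// sketch only, not part of the kernel):
//
//     for (int iw_idx = 0; iw_idx < iw_with_pad; ++iw_idx)
//         for (int k = 0; k < 4; ++k)
//             sptr_base[iw_idx * pack_iw_len + k] = temp_ptr[iw_idx + k];
//
// i.e. every padded pixel is replaced by the 4-pixel window starting at it,
// which is exactly the layout the stride-1 dot kernels consume.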


static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr,
const int8_t* src_ptr,
@@ -663,117 +117,15 @@ static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr,
}


template <BiasMode bias_mode, typename Op, int filter_size, int stride>
static void conv_direct_int8_nchw_nchw44_dot(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih,
const int iw, const int oh, const int oh_block, const int ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr int fh = filter_size;
constexpr int fw = (filter_size + 3) / 4 * 4;
#if MEGDNN_AARCH64
constexpr int big_oc_step = 8;
#else
constexpr int big_oc_step = 4;
#endif
constexpr int oc_step = 4;
constexpr int ih_step = 1;
constexpr int oh_step = 1;
constexpr int ow_step = 8;
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int pack_iw_len = stride == 2 ? 1 : 4;
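// pack_iw_len must stay in sync with the pack_src_int8_nchw_nchw44_dot paths
// above: the stride-1 packer expands every pixel fourfold while the stride-2
// packer only zero-pads, so all source offsets below are scaled by it.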

const int img_stride = oh * ow;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
const int oc_end = oc / big_oc_step * big_oc_step;
const int oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
big_oc_step, ow_step, stride>::impl; \
kern_small_oc_remain = \
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \
oc_step, ow_step, stride>::impl; \
break;

UNROLL_CALL_RAW(8, cb);
default:
megdnn_assert(0, "no remain %d for kern", ow_remain);
}
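// The switch above binds kern_big_oc_remain / kern_small_oc_remain to the
// KerNeonDotXXs2Nchw44Int8 instantiation whose remain_w equals ow_remain
// (UNROLL_CALL_RAW(8, cb) emits cases 0..7), so the tail of each output row
// runs a kernel compiled for exactly that many remaining columns.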

for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
big_oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
int oc_idx = oc_end;
const int weight_offset = oc_idx * ic * fh * fw;
for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size,
oc_step, ow_step,
stride>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const int src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
pack_iw_len;
const int dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
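// Blocking summary: output channels are walked in big_oc_step chunks (8 on
// aarch64, which has twice as many 128-bit NEON registers; 4 elsewhere), with
// any 4-channel tail taking the oc_step path; each output row is walked in
// ow_step == 8 columns, with the remainder kernels selected above handling
// the final partial block.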

} // namespace
void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
const int32_t* bias, int32_t* temp,
int8_t* dst, const int oc, const int ic,
const int ih, const int iw, const int oh,
const int oh_block, const int ow,
const Op& op);

} // namespace dot_direct_nchw_nchw44
} // namespace arm_common
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen

+ 2
- 2
dnn/test/arm_common/conv_bias_multi_thread.cpp

@@ -2344,7 +2344,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
 #endif
     std::vector<conv_bias::TestArg> gemv_args;
     for (auto&& arg : args)
-        if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
+        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
             gemv_args.emplace_back(arg);
         }
     check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
@@ -2361,7 +2361,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
 #endif
     std::vector<conv_bias::TestArg> gemv_args;
     for (auto&& arg : args)
-        if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
+        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
             gemv_args.emplace_back(arg);
         }
     check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");


+ 23
- 27
dnn/test/arm_common/matrix_mul.cpp

@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "test/arm_common/fixture.h"

@@ -30,8 +31,7 @@ TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x16) {

 TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) {
     matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127),
-                                 dtype::Quantized8Asymm(1.3f, (uint8_t)129),
-                                 {},
+                                 dtype::Quantized8Asymm(1.3f, (uint8_t)129), {},
                                  handle());
 }

@@ -232,8 +232,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
     Checker<MatrixMul> checker(handle());
     using Param = MatrixMul::Param;

-    checker.set_before_exec_callback(
-            AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
+    checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));

     std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
     checker.set_rng(0, rng.get()).set_rng(1, rng.get());
@@ -251,7 +250,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
             .set_dtype(2, dtype::QuantizedS32(6.25f))
             .execs({A, B, {}});
     };
     // M = 1
     for (size_t N : {1, 10, 16, 33, 64})
         for (size_t K : {7, 512, 1024})
@@ -263,8 +262,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
     Checker<MatrixMul> checker(handle());
     using Param = MatrixMul::Param;

-    checker.set_before_exec_callback(
-            AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));
+    checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM"));

     checker.set_epsilon(1e-2);
     auto run = [&](size_t M, size_t K, size_t N) {
@@ -276,7 +274,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
         B = TensorShape{N, K};
         checker.set_param(param).execs({A, B, {}});
     };
     // M = 1
     for (size_t M : {1})
         for (size_t K : {1000, 4096, 25088})
@@ -298,15 +296,15 @@ TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
         param.transposeA = false;
         param.transposeB = false;
         TensorShape A, B;
-        A = TensorShape{M/4, K/4, 4, 4};
-        B = TensorShape{K/4, 1, 4};
+        A = TensorShape{M / 4, K / 4, 4, 4};
+        B = TensorShape{K / 4, 1, 4};
         checker.set_param(param).execs({A, B, {}});
     };
     // N = 1
     for (size_t M : {4, 16, 128, 1024})
         for (size_t K : {4, 8, 12, 128, 256, 4096})
-             run(M, K);
+            run(M, K);
 }

 #if MEGDNN_WITH_BENCHMARK
@@ -343,7 +341,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {

     for (size_t M : {4, 64, 1024, 4096})
         for (size_t K : {128, 256, 1024, 4096})
-             run(M, K, 1);
+            run(M, K, 1);
 }

 TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
@@ -372,7 +370,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
                 .exec({{2, 1024}, {1024, 512}, {}});
         benchmarker.set_display(true);
     }
     // run gemv
     run(12, 48, 1);
     run(48, 12, 1);
@@ -396,14 +394,14 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
     Benchmarker<MatrixMul> benchmarker(handle());
     benchmarker.set_times(exec_times);
     benchmarker.set_dtype(0, dtype::Float32())
-        .set_dtype(1, dtype::Float32())
-        .set_param(param);
+            .set_dtype(1, dtype::Float32())
+            .set_param(param);

     auto run = [&](size_t M, size_t K) {
-        printf("SGEMV_MK4: (%zu, %zu, %zu)\n", M, K, N);
+        printf("SGEMV_MK4: (%zu, %zu)\n", M, K);
         TensorShape A, B;
-        A = TensorShape{M/4, K/4, 4, 4};
-        B = TensorShape{K/4, 1, 4};
+        A = TensorShape{M / 4, K / 4, 4, 4};
+        B = TensorShape{K / 4, 1, 4};
         auto time = benchmarker.exec({A, B, {}}) / exec_times;
         auto computations = 2.f * M * K * 1e-6;
         auto perf = computations / time;
@@ -422,7 +420,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
     // run gemv mk4
     for (size_t M : {4, 64, 1024, 4096})
         for (size_t K : {128, 1024, 4096})
-             run(M, K);
+            run(M, K);
 }

 TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
@@ -490,7 +488,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
     //////////////////////// gemv //////////////////////////
     for (size_t M : {8, 64, 112, 256}) {
         for (size_t K : {8, 64, 112, 256}) {
-            run (M, 1, K);
+            run(M, 1, K);
         }
     }

@@ -502,10 +500,8 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
             }
         }
     }
-
 }

 TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
     constexpr size_t RUNS = 50;
     param::MatrixMul param;
@@ -514,7 +510,8 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
             .set_dtype(0, dtype::Int8{})
             .set_dtype(1, dtype::Int8{})
             .set_dtype(2, dtype::Int32{})
-            .set_param(param).set_display(false);
+            .set_param(param)
+            .set_display(false);
     Benchmarker<MatrixMul> benchmarker_float(handle());
     benchmarker_float.set_display(false).set_times(RUNS);

@@ -533,7 +530,7 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) {
     //////////////////////// gemv //////////////////////////
     for (size_t M : {8, 64, 112, 256}) {
         for (size_t K : {8, 64, 112, 256}) {
-            run (M, 1, K);
+            run(M, 1, K);
         }
     }

@@ -618,5 +615,4 @@

 #endif

-
 // vim: syntax=cpp.doxygen
