/** * \file dnn/src/arm_common/utils.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "src/common/utils.h" #include #include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; namespace { template void transpose_naive(const dtype *src, dtype *dst, int lda, int ldb, int n, int m) { rep(i, n) rep(j, m) { dst[i*ldb + j] = src[j*lda + i]; } } void transpose_4x4_neon(const float *src, float *dst, int lda, int ldb) { float32x4x2_t a0, a1; a0.val[0] = vld1q_f32(src + 0*lda); a0.val[1] = vld1q_f32(src + 1*lda); a1.val[0] = vld1q_f32(src + 2*lda); a1.val[1] = vld1q_f32(src + 3*lda); float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]); float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]); float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]); float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]); vst1q_f32(dst + 0*ldb, c0.val[0]); vst1q_f32(dst + 1*ldb, c0.val[1]); vst1q_f32(dst + 2*ldb, c1.val[0]); vst1q_f32(dst + 3*ldb, c1.val[1]); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC void transpose_8x8_neon(const dt_float16 *src, dt_float16 *dst, int lda, int ldb) { const __fp16* src_ptr = reinterpret_cast(src); __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst); float16x8x4_t a0, a1; a0.val[0] = vld1q_f16(src_ptr + 0*lda); // A0A1A2A3A4A5A6A7 a0.val[1] = vld1q_f16(src_ptr + 1*lda); // B0B1B2B3B4B5B6B7 a0.val[2] = vld1q_f16(src_ptr + 2*lda); // C0C1C2C3C4C5C6C7 a0.val[3] = vld1q_f16(src_ptr + 3*lda); // D0D1D2D3D4D5D6D7 a1.val[0] = vld1q_f16(src_ptr + 4*lda); // E0E1E2E3E4E5E6E7 a1.val[1] = vld1q_f16(src_ptr + 5*lda); // F0F1F2F3F4F5F6F7 a1.val[2] = vld1q_f16(src_ptr + 6*lda); // G0G1G2G3G4G5G6G7 a1.val[3] = vld1q_f16(src_ptr + 7*lda); // H0H1H2H3H4H5H6H7 float16x8x2_t b0 = vzipq_f16(a0.val[0], a1.val[0]); // A0E0A1E1A2E2A3E3 A4E4A5E5A6E6A7E7 float16x8x2_t b1 = vzipq_f16(a0.val[2], a1.val[2]); // C0G0C1G1C2G2C3G3 C4G4C5G5C6G6C7G7 float16x8x2_t c0 = vzipq_f16(a0.val[1], a1.val[1]); // B0F0B1F1B2F2B3F3 B4F4B5F5B6F6B7F7 float16x8x2_t c1 = vzipq_f16(a0.val[3], a1.val[3]); // D0H0D1H1D2H2D3H3 D4H4D5H5D6H6D7H7 float16x8x2_t d0 = vzipq_f16(b0.val[0], b1.val[0]); // A0C0E0G0A1C1E1G1 A2C2E2G2A3C3E3G3 float16x8x2_t d1 = vzipq_f16(c0.val[0], c1.val[0]); // B0D0F0H0B1D1F1H1 B2D2F2H2B3D3F3H3 float16x8x2_t e0 = vzipq_f16(d0.val[0], d1.val[0]); // A0B0C0D0E0F0G0H0 A1B1C1D1E1F1G1H1 float16x8x2_t e1 = vzipq_f16(d0.val[1], d1.val[1]); // A2B2C2D2E2F2G2H2 A3B3C3D3E3F3G3H3 float16x8x2_t f0 = vzipq_f16(b0.val[1], b1.val[1]); // A4C4E4G4A5C5E5G5 A6C6E6G6A7C7E7G7 float16x8x2_t f1 = vzipq_f16(c0.val[1], c1.val[1]); // B4D4F4H4B5D5F5H5 B6D6E6G6B7D7E7H7 float16x8x2_t g0 = vzipq_f16(f0.val[0], f1.val[0]); // A4B4C4D4E4F4G4H4 A5B5C5D5E5F5G5H5 float16x8x2_t g1 = vzipq_f16(f0.val[1], f1.val[1]); // A6B6C6D6E6F6G6H6 A7B7C7D7E7F7G7H7 vst1q_f16(dst_ptr + 0*ldb, e0.val[0]); vst1q_f16(dst_ptr + 1*ldb, e0.val[1]); vst1q_f16(dst_ptr + 2*ldb, e1.val[0]); vst1q_f16(dst_ptr + 3*ldb, e1.val[1]); vst1q_f16(dst_ptr + 4*ldb, g0.val[0]); vst1q_f16(dst_ptr + 5*ldb, g0.val[1]); vst1q_f16(dst_ptr + 6*ldb, g1.val[0]); vst1q_f16(dst_ptr + 7*ldb, g1.val[1]); } #endif } // anonymous namespace namespace megdnn { template <> void transpose(const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds, ptrdiff_t ldd) { if (lds == -1) { lds = n; } if (ldd == -1) { ldd = m; } for (size_t is = 0; is < n; is += 16) { for (size_t js = 0; js < m; js += 16) { auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is; for (; i + 4 <= ie; i += 4) { auto j = js; for (; j + 4 <= je; j += 4) { transpose_4x4_neon(src + j * lds + i, dst + i * ldd + j, lds, ldd); } if (j < je) { transpose_naive(src + j * lds + i, dst + i * ldd + j, lds, ldd, 4, je - j); } } if (i < ie) { transpose_naive(src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i, je - js); } } } } template void transpose_knc2nsck_helper(const dtype *src, dtype *dst, size_t k, size_t n, size_t c, size_t n_stride) { if (n_stride == k * c) { // dst is contiguous transpose(src, dst, k, n * c); } else { for (size_t i = 0; i < n; ++ i) { transpose(src + i * c, dst + i * n_stride, k, c, n * c); } } } template <> void transpose_knc2nsck(const float *src, float *dst, size_t k, size_t n, size_t c, size_t n_stride) { transpose_knc2nsck_helper(src, dst, k, n, c, n_stride); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC template <> void transpose(const dt_float16* src, dt_float16* dst, size_t m, size_t n, ptrdiff_t lds, ptrdiff_t ldd) { if (lds == -1) { lds = n; } if (ldd == -1) { ldd = m; } for (size_t is = 0; is < n; is += 16) { for (size_t js = 0; js < m; js += 16) { auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is; for (; i + 8 <= ie; i += 8) { auto j = js; for (; j + 8 <= je; j += 8) { transpose_8x8_neon(src + j * lds + i, dst + i * ldd + j, lds, ldd); } if (j < je) { transpose_naive(src + j * lds + i, dst + i * ldd + j, lds, ldd, 8, je - j); } } if (i < ie) { transpose_naive(src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i, je - js); } } } } template <> void transpose_knc2nsck(const dt_float16* src, dt_float16* dst, size_t k, size_t n, size_t c, size_t n_stride) { transpose_knc2nsck_helper(src, dst, k, n, c, n_stride); } #endif } // namespace megdnn // vim: syntax=cpp.doxygen