/**
 * \file dnn/src/arm_common/resize/upsample2_nchwxx.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/arm_common/resize/upsample2_nchwxx.h"
#include "src/arm_common/resize/helper.h"
#include "src/arm_common/simd_macro/marm_neon.h"

using namespace megdnn;
using namespace arm_common;
using namespace resize;

namespace {

/**
 * \brief Weighted sum of four source vectors for one output position.
 *
 * \tparam simd_helper helper type providing simd_type, dup() and fma()
 * \tparam fh when 1, swaps the two rows of \p alpha (vertical flip of weights)
 * \tparam fw when 1, swaps the two columns of \p alpha (horizontal flip)
 * \param src the four neighbouring source vectors, ordered
 *        {top-left, top-right, bottom-left, bottom-right}
 * \param alpha 2x2 table of broadcast weights
 * \return src[0] * a00 + src[1] * a01 + src[2] * a10 + src[3] * a11, where
 *         a_ij = alpha[i ^ fh][j ^ fw]
 */
template <typename simd_helper, size_t fh, size_t fw>
static inline typename simd_helper::simd_type compute_linear_element(
        const typename simd_helper::simd_type src[4],
        const typename simd_helper::simd_type alpha[2][2]) {
    typename simd_helper::simd_type c = simd_helper::dup(0);
    // XOR-ing the index with fh/fw selects the mirrored weight, so the same
    // alpha table serves all four output positions of a 2x2 output patch.
    c = simd_helper::fma(c, src[0], alpha[0 ^ fh][0 ^ fw]);
    c = simd_helper::fma(c, src[1], alpha[0 ^ fh][1 ^ fw]);
    c = simd_helper::fma(c, src[2], alpha[1 ^ fh][0 ^ fw]);
    c = simd_helper::fma(c, src[3], alpha[1 ^ fh][1 ^ fw]);
    return c;
}

/**
 * \brief Produce up to a 2x2 patch of output vectors from up to a 2x2 patch
 *        of source vectors.
 *
 * \tparam has_right  whether a source vector exists one position to the right
 *         of \p src; when false the right neighbour aliases \p src and the
 *         right-hand output column is not written
 * \tparam has_bottom whether a source vector exists one row below \p src;
 *         when false the bottom neighbour aliases \p src and the bottom
 *         output row is not written
 * \param src top-left source vector of the patch
 * \param dst top-left output vector of the patch
 * \param IW  source row length in vectors (row stride is IW * simd_width)
 * \param OW  output row length in vectors (row stride is OW * simd_width)
 * \param alpha 2x2 table of broadcast weights
 */
template <typename simd_helper, bool has_right, bool has_bottom>
static inline void compute_linear_2x2_element(
        const typename simd_helper::ctype* src, typename simd_helper::ctype* dst,
        size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) {
    constexpr size_t PC = simd_helper::simd_width;

    // Start with all four neighbours aliasing src, then move the ones that
    // really exist; missing neighbours therefore replicate the edge value.
    const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src};

    if (has_right) {
        src_ptr[1] += PC;
        src_ptr[3] += PC;
    }
    if (has_bottom) {
        src_ptr[2] += IW * PC;
        src_ptr[3] += IW * PC;
    }

    typename simd_helper::simd_type rsrc[4];
    rsrc[0] = simd_helper::load(src_ptr[0]);
    rsrc[1] = simd_helper::load(src_ptr[1]);
    rsrc[2] = simd_helper::load(src_ptr[2]);
    rsrc[3] = simd_helper::load(src_ptr[3]);

    // rdst[k] is stored at (row k / 2, column k % 2) of the output patch, so
    // its weights are flipped vertically for the bottom row (fh = 1) and
    // horizontally for the right column (fw = 1).
    typename simd_helper::simd_type rdst[4];
    rdst[0] = compute_linear_element<simd_helper, 0, 0>(rsrc, alpha);
    rdst[1] = compute_linear_element<simd_helper, 0, 1>(rsrc, alpha);
    rdst[2] = compute_linear_element<simd_helper, 1, 0>(rsrc, alpha);
    rdst[3] = compute_linear_element<simd_helper, 1, 1>(rsrc, alpha);

    simd_helper::store(dst, rdst[0]);
    if (has_right) {
        simd_helper::store(dst + PC, rdst[1]);
    }
    if (has_bottom) {
        simd_helper::store(dst + OW * PC, rdst[2]);
    }
    if (has_right && has_bottom) {
        simd_helper::store(dst + (OW + 1) * PC, rdst[3]);
    }
}

/**
 * \brief 2x linear upsampling of N planes stored as IH x IW vectors of
 *        simd_width elements each.
 *
 * Each plane produces 2 * IH rows of 2 * IW vectors. The first and last
 * output rows and columns are computed from a single source row/column
 * (edge replication); all other outputs blend a 2x2 source neighbourhood
 * with weights built from 0.75 and 0.25.
 *
 * \param src_ptr input, N * IH * IW * simd_width elements
 * \param dst_ptr output, N * (2 * IH) * (2 * IW) * simd_width elements
 * \param N  number of planes
 * \param IH source height
 * \param IW source width
 */
template <typename ctype>
void linear_upsample2_nchwxx(
        const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
    using simd_helper = SIMDHelper<ctype>;
    size_t OW = IW * 2;
    constexpr size_t PC = simd_helper::simd_width;

    // An empty plane has nothing to write. Without this guard (IW - 1) would
    // wrap around and the edge handling below would access memory out of
    // bounds.
    if (IH == 0 || IW == 0) {
        return;
    }

    typename simd_helper::simd_type alpha[2][2];
    alpha[0][0] = simd_helper::dup(0.75 * 0.75);
    alpha[0][1] = simd_helper::dup(0.75 * 0.25);
    alpha[1][0] = simd_helper::dup(0.25 * 0.75);
    alpha[1][1] = simd_helper::dup(0.25 * 0.25);

    for (size_t i = 0; i < N; ++i) {
        // First output row: depends on source row 0 only.
        compute_linear_2x2_element<simd_helper, false, false>(
                src_ptr, dst_ptr, IW, OW, alpha);
        {
            for (size_t iw = 0; iw + 1 < IW; ++iw) {
                compute_linear_2x2_element<simd_helper, true, false>(
                        src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW,
                        alpha);
            }
        }
        compute_linear_2x2_element<simd_helper, false, false>(
                src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
        dst_ptr += OW * PC;

        // Interior: source rows ih and ih + 1 yield output rows 2 * ih + 1
        // and 2 * ih + 2.
        for (size_t ih = 0; ih + 1 < IH; ++ih) {
            compute_linear_2x2_element<simd_helper, false, true>(
                    src_ptr, dst_ptr, IW, OW, alpha);
            for (size_t iw = 0; iw + 1 < IW; ++iw) {
                compute_linear_2x2_element<simd_helper, true, true>(
                        src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW,
                        alpha);
            }
            compute_linear_2x2_element<simd_helper, false, true>(
                    src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW,
                    alpha);

            src_ptr += IW * PC;
            dst_ptr += 2 * OW * PC;
        }

        // Last output row: depends on the last source row only.
        compute_linear_2x2_element<simd_helper, false, false>(
                src_ptr, dst_ptr, IW, OW, alpha);
        {
            for (size_t iw = 0; iw + 1 < IW; ++iw) {
                compute_linear_2x2_element<simd_helper, true, false>(
                        src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW,
                        alpha);
            }
        }
        compute_linear_2x2_element<simd_helper, false, false>(
                src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha);
        src_ptr += IW * PC;
        dst_ptr += OW * PC;
    }
}

/**
 * \brief 2x nearest-neighbour upsampling of N planes stored as IH x IW
 *        vectors of simd_width elements each.
 *
 * Every source vector is copied to the 2x2 block of output vectors at rows
 * {2 * ih, 2 * ih + 1} and columns {2 * iw, 2 * iw + 1}.
 *
 * \param src_ptr input, N * IH * IW * simd_width elements
 * \param dst_ptr output, N * (2 * IH) * (2 * IW) * simd_width elements
 * \param N  number of planes
 * \param IH source height
 * \param IW source width
 */
template <typename ctype>
void nearest_upsample2_nchwxx(
        const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) {
    using simd_helper = SIMDHelper<ctype>;
    size_t OW = IW * 2;
    constexpr size_t PC = simd_helper::simd_width;

    for (size_t i = 0; i < N; ++i) {
        for (size_t ih = 0; ih < IH; ++ih) {
            for (size_t iw = 0; iw < IW; ++iw) {
                typename simd_helper::simd_type r0 =
                        simd_helper::load(src_ptr + iw * PC);

                simd_helper::store(dst_ptr + (iw * 2) * PC, r0);
                simd_helper::store(dst_ptr + (iw * 2 + 1) * PC, r0);
                simd_helper::store(dst_ptr + (OW + iw * 2) * PC, r0);
                simd_helper::store(dst_ptr + (OW + iw * 2 + 1) * PC, r0);
            }
            // One source row consumed, two output rows produced.
            src_ptr += IW * PC;
            dst_ptr += 2 * OW * PC;
        }
    }
}

}  // namespace

// NOTE(review): the KernParam parameter types below are reproduced exactly as
// they appear in this file. If ResizeImpl::KernParam is a class template, its
// argument list is missing here and must be restored to match the
// declarations in upsample2_nchwxx.h -- TODO confirm.

/**
 * \brief fp32 entry point for 2x linear upsampling; treats the tensor as
 *        n * c / 4 planes (4 channels packed per vector).
 */
void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32(
        const ResizeImpl::KernParam& kern_param) {
    linear_upsample2_nchwxx(
            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
            kern_param.ih, kern_param.iw);
}

/**
 * \brief fp32 entry point for 2x nearest upsampling; treats the tensor as
 *        n * c / 4 planes (4 channels packed per vector).
 */
void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32(
        const ResizeImpl::KernParam& kern_param) {
    nearest_upsample2_nchwxx(
            kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4,
            kern_param.ih, kern_param.iw);
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

/**
 * \brief fp16 entry point for 2x linear upsampling; treats the tensor as
 *        n * c / 8 planes (8 channels packed per vector).
 */
void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16(
        const ResizeImpl::KernParam& kern_param) {
    // The raw buffers are reinterpreted as the native __fp16 type so the
    // kernel template is instantiated with ctype = __fp16.
    auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr());
    auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr());
    linear_upsample2_nchwxx(
            sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih,
            kern_param.iw);
}

/**
 * \brief fp16 entry point for 2x nearest upsampling; treats the tensor as
 *        n * c / 8 planes (8 channels packed per vector).
 */
void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16(
        const ResizeImpl::KernParam& kern_param) {
    auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr());
    auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr());
    nearest_upsample2_nchwxx(
            sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih,
            kern_param.iw);
}

#endif