GitOrigin-RevId: 2fe469bb4e
tags/v1.7.2.m1
@@ -497,7 +497,16 @@ pdef('ElemwiseMultiType').add_enum(
    Doc('QCOND_LEQ_MOV = 50', 'quantized cond_leq_mov'),
    Doc('QH_SWISH = 51', 'quantized h_swish'),
    Doc('QFUSE_ADD_H_SWISH = 52', 'quantized h_swish(x+y)'),
    Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad'),
    Doc('FUSE_MUL_ADD3_INT16xF32xF32xF32 = 54',
        'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
        '``c`` float32, and the result is float32.'),
    Doc('MUL_INT16xF32xF32 = 55',
        'compute ``a * b`` requiring that ``a`` be int16 and ``b`` float32, '
        'and the result is float32.'),
    Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56',
        'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
        '``c`` float32, and the result is float32.')
)
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
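For reference, the element-wise semantics of the three new modes can be written as scalar C++ (an illustrative sketch only; `fma3_ref`/`mul_ref` are hypothetical names, and the authoritative dtype checks are the ModeTrait entries added further below):

// Hypothetical scalar reference for the new modes (not part of the patch).
inline float fma3_ref(int16_t a, float b, float c) {
    return float(a) * b + c;  // modes 54/56 (a is uint8_t for mode 56)
}
inline float mul_ref(int16_t a, float b) {
    return float(a) * b;  // mode 55
}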
@@ -0,0 +1,707 @@
/**
 * \file dnn/src/arm_common/elemwise_multi_type/kernels.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "kernels.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| #if defined(__ARM_FEATURE_FMA) | |||
| #define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||
| #else | |||
| #define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||
| #endif | |||
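// Note: with __ARM_FEATURE_FMA, Vfmaq_f32 lowers to vfmaq_f32, a fused
// multiply-add with a single rounding step; the vmlaq_f32 fallback performs a
// separate multiply and add (two roundings), so results may differ in the
// last ulp between the two builds, but the kernels stay usable on ARMv7
// targets without FMA.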
void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
        size_t batch_size, size_t channel_stride, size_t channel_size,
        const int16_t* src0, const float* src1, const float* src2, float* dst) {
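    // Layout contract (a reading of the loops below): src0 and dst are dense
    // (batch_size, channel_stride, channel_size) tensors, while src1 and src2
    // each hold channel_size floats broadcast over the batch and stride axes,
    // hence they are re-indexed by i on every (batch, s) pass while
    // sptr0/dst_ptr advance linearly.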
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    float* __restrict dst_ptr = dst;
    for (size_t batch = 0; batch < batch_size; ++batch) {
        for (size_t s = 0; s < channel_stride; ++s) {
            size_t i = 0;
            for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec0_23 = vld1q_s16(sptr0 + 8);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec1_1 = vld1q_f32(sptr1 + i + 4);
                auto vec1_2 = vld1q_f32(sptr1 + i + 8);
                auto vec1_3 = vld1q_f32(sptr1 + i + 12);
                auto vec2_0 = vld1q_f32(sptr2 + i);
                auto vec2_1 = vld1q_f32(sptr2 + i + 4);
                auto vec2_2 = vld1q_f32(sptr2 + i + 8);
                auto vec2_3 = vld1q_f32(sptr2 + i + 12);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
                auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
                auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
                auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
                auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
                auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
                vst1q_f32(dst_ptr + 8, dst_vec_2);
                vst1q_f32(dst_ptr + 12, dst_vec_3);
            }
            for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec1_1 = vld1q_f32(sptr1 + i + 4);
                auto vec2_0 = vld1q_f32(sptr2 + i);
                auto vec2_1 = vld1q_f32(sptr2 + i + 4);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
                auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
            }
            for (; i + 3 < channel_size; i += 4, sptr0 += 4, dst_ptr += 4) {
                auto vec0_0 = vld1_s16(sptr0);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec2_0 = vld1q_f32(sptr2 + i);
                auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
                auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0_f32, vec1_0);
                vst1q_f32(dst_ptr, dst_vec_0);
            }
            for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) {
                *dst_ptr = (float)(*sptr0) * sptr1[i] + sptr2[i];
            }
        }
    }
}

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
        size_t batch_size, size_t channel_stride, size_t channel_size,
        const uint8_t* src0, const float* src1, const float* src2, float* dst) {
    const uint8_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    float* __restrict dst_ptr = dst;
    for (size_t batch = 0; batch < batch_size; ++batch) {
        for (size_t s = 0; s < channel_stride; ++s) {
            size_t i = 0;
            for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) {
                auto vec0_0123_u8 = vld1q_u8(sptr0);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec1_1 = vld1q_f32(sptr1 + i + 4);
                auto vec1_2 = vld1q_f32(sptr1 + i + 8);
                auto vec1_3 = vld1q_f32(sptr1 + i + 12);
                auto vec2_0 = vld1q_f32(sptr2 + i);
                auto vec2_1 = vld1q_f32(sptr2 + i + 4);
                auto vec2_2 = vld1q_f32(sptr2 + i + 8);
                auto vec2_3 = vld1q_f32(sptr2 + i + 12);
                auto vec0_01 =
                        vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123_u8)));
                auto vec0_23 =
                        vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123_u8)));
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
                auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
                auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
                auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
                auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
                auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
                vst1q_f32(dst_ptr + 8, dst_vec_2);
                vst1q_f32(dst_ptr + 12, dst_vec_3);
            }
            for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) {
                auto vec0_01_u8 = vld1_u8(sptr0);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec1_1 = vld1q_f32(sptr1 + i + 4);
                auto vec2_0 = vld1q_f32(sptr2 + i);
                auto vec2_1 = vld1q_f32(sptr2 + i + 4);
                auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
                auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
            }
            for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) {
                *dst_ptr = (float)(*sptr0) * sptr1[i] + sptr2[i];
            }
        }
    }
}

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
        size_t batch_size, size_t channel_size, size_t channel_stride,
        const int16_t* src0, const float* src1, const float* src2, float* dst) {
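    // Layout contract: src0 and dst are dense (batch_size, channel_size,
    // channel_stride) tensors; src1 and src2 hold one float per channel, so
    // each value is splat into a vector once per channel and reused across
    // the whole contiguous channel_stride run.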
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    float* __restrict dst_ptr = dst;
    for (size_t batch = 0; batch < batch_size; ++batch) {
        for (size_t chan = 0; chan < channel_size; ++chan) {
            auto vec1 = vdupq_n_f32(sptr1[chan]);
            auto vec2 = vdupq_n_f32(sptr2[chan]);
            size_t i = 0;
            for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec0_23 = vld1q_s16(sptr0 + 8);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
                auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
                auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
                auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
                auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
                auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
                vst1q_f32(dst_ptr + 8, dst_vec_2);
                vst1q_f32(dst_ptr + 12, dst_vec_3);
            }
            for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
                auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
            }
            for (; i + 3 < channel_stride; i += 4, sptr0 += 4, dst_ptr += 4) {
                auto vec0_0 = vld1_s16(sptr0);
                auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
                auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0_f32, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
            }
            for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) {
                *dst_ptr = (float)(*sptr0) * sptr1[chan] + sptr2[chan];
            }
        }
    }
}

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
        size_t batch_size, size_t channel_size, size_t channel_stride,
        const uint8_t* src0, const float* src1, const float* src2, float* dst) {
    const uint8_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    float* __restrict dst_ptr = dst;
    for (size_t batch = 0; batch < batch_size; ++batch) {
        for (size_t chan = 0; chan < channel_size; ++chan) {
            auto vec1 = vdupq_n_f32(sptr1[chan]);
            auto vec2 = vdupq_n_f32(sptr2[chan]);
            size_t i = 0;
            for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) {
                auto vec0_0123_u8 = vld1q_u8(sptr0);
                auto vec0_01 =
                        vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123_u8)));
                auto vec0_23 =
                        vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123_u8)));
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
                auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
                auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
                auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
                auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
                auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
                vst1q_f32(dst_ptr + 8, dst_vec_2);
                vst1q_f32(dst_ptr + 12, dst_vec_3);
            }
            for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) {
                auto vec0_01_u8 = vld1_u8(sptr0);
                auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
                auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
            }
            for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) {
                *dst_ptr = (float)(*sptr0) * sptr1[chan] + sptr2[chan];
            }
        }
    }
}

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
        size_t size, const int16_t* src0, const float* src1, const float* src2,
        float* dst) {
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    float* __restrict dst_ptr = dst;
    size_t i = 0;
    for (; i + 15 < size;
         i += 16, sptr0 += 16, sptr1 += 16, sptr2 += 16, dst_ptr += 16) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec0_23 = vld1q_s16(sptr0 + 8);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec1_1 = vld1q_f32(sptr1 + 4);
        auto vec1_2 = vld1q_f32(sptr1 + 8);
        auto vec1_3 = vld1q_f32(sptr1 + 12);
        auto vec2_0 = vld1q_f32(sptr2);
        auto vec2_1 = vld1q_f32(sptr2 + 4);
        auto vec2_2 = vld1q_f32(sptr2 + 8);
        auto vec2_3 = vld1q_f32(sptr2 + 12);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
        auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
        auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
        auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
        auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
        vst1q_f32(dst_ptr + 8, dst_vec_2);
        vst1q_f32(dst_ptr + 12, dst_vec_3);
    }
    for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, sptr2 += 8, dst_ptr += 8) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec1_1 = vld1q_f32(sptr1 + 4);
        auto vec2_0 = vld1q_f32(sptr2);
        auto vec2_1 = vld1q_f32(sptr2 + 4);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
        auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
    }
    for (; i + 3 < size; i += 4, sptr0 += 4, sptr1 += 4, sptr2 += 4, dst_ptr += 4) {
        auto vec0_0 = vld1_s16(sptr0);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec2_0 = vld1q_f32(sptr2);
        auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
        auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0_f32, vec1_0);
        vst1q_f32(dst_ptr, dst_vec_0);
    }
    for (; i < size; ++i, ++sptr0, ++sptr1, ++sptr2, ++dst_ptr) {
        *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
    }
}

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
        size_t size, const uint8_t* src0, const float* src1, const float* src2,
        float* dst) {
    const uint8_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    float* __restrict dst_ptr = dst;
    size_t i = 0;
    for (; i + 15 < size;
         i += 16, sptr0 += 16, sptr1 += 16, sptr2 += 16, dst_ptr += 16) {
        auto vec0_0123 = vld1q_u8(sptr0);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec1_1 = vld1q_f32(sptr1 + 4);
        auto vec1_2 = vld1q_f32(sptr1 + 8);
        auto vec1_3 = vld1q_f32(sptr1 + 12);
        auto vec2_0 = vld1q_f32(sptr2);
        auto vec2_1 = vld1q_f32(sptr2 + 4);
        auto vec2_2 = vld1q_f32(sptr2 + 8);
        auto vec2_3 = vld1q_f32(sptr2 + 12);
        auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123)));
        auto vec0_23 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123)));
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
        auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
        auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
        auto dst_vec_2 = Vfmaq_f32(vec2_2, vec0_2, vec1_2);
        auto dst_vec_3 = Vfmaq_f32(vec2_3, vec0_3, vec1_3);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
        vst1q_f32(dst_ptr + 8, dst_vec_2);
        vst1q_f32(dst_ptr + 12, dst_vec_3);
    }
    for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, sptr2 += 8, dst_ptr += 8) {
        auto vec0_01_u8 = vld1_u8(sptr0);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec1_1 = vld1q_f32(sptr1 + 4);
        auto vec2_0 = vld1q_f32(sptr2);
        auto vec2_1 = vld1q_f32(sptr2 + 4);
        auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto dst_vec_0 = Vfmaq_f32(vec2_0, vec0_0, vec1_0);
        auto dst_vec_1 = Vfmaq_f32(vec2_1, vec0_1, vec1_1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
    }
    for (; i < size; ++i, ++sptr0, ++sptr1, ++sptr2, ++dst_ptr) {
        *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
    }
}

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
        size_t size, const int16_t* src0, const float* src1, const float* src2,
        float* dst) {
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
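    // src1 and src2 are broadcast scalars: splat each once and reuse the
    // vectors (the scalar pointers are never advanced) for the whole tensor.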
    auto vec1 = vdupq_n_f32(sptr1[0]);
    auto vec2 = vdupq_n_f32(sptr2[0]);
    float* __restrict dst_ptr = dst;
    size_t i = 0;
    for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec0_23 = vld1q_s16(sptr0 + 8);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
        auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
        auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
        vst1q_f32(dst_ptr + 8, dst_vec_2);
        vst1q_f32(dst_ptr + 12, dst_vec_3);
    }
    for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
    }
    for (; i + 3 < size; i += 4, sptr0 += 4, dst_ptr += 4) {
        auto vec0_0 = vld1_s16(sptr0);
        auto vec0_0_f32 = vcvtq_f32_s32(vmovl_s16(vec0_0));
        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0_f32, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
    }
    for (; i < size; ++i, ++sptr0, ++dst_ptr) {
        *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
    }
}

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
        size_t size, const uint8_t* src0, const float* src1, const float* src2,
        float* dst) {
    const uint8_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    const float* __restrict sptr2 = src2;
    auto vec1 = vdupq_n_f32(sptr1[0]);
    auto vec2 = vdupq_n_f32(sptr2[0]);
    float* __restrict dst_ptr = dst;
    size_t i = 0;
    for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
        auto vec0_0123 = vld1q_u8(sptr0);
        auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(vec0_0123)));
        auto vec0_23 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(vec0_0123)));
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
        auto dst_vec_2 = Vfmaq_f32(vec2, vec0_2, vec1);
        auto dst_vec_3 = Vfmaq_f32(vec2, vec0_3, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
        vst1q_f32(dst_ptr + 8, dst_vec_2);
        vst1q_f32(dst_ptr + 12, dst_vec_3);
    }
    for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
        auto vec0_01_u8 = vld1_u8(sptr0);
        auto vec0_01 = vreinterpretq_s16_u16(vmovl_u8(vec0_01_u8));
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto dst_vec_0 = Vfmaq_f32(vec2, vec0_0, vec1);
        auto dst_vec_1 = Vfmaq_f32(vec2, vec0_1, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
    }
    for (; i < size; ++i, ++sptr0, ++dst_ptr) {
        *dst_ptr = (float)(*sptr0) * (*sptr1) + (*sptr2);
    }
}

void neon_mul_int16xf32xf32_vec_bcast111c(
        size_t batch_size, size_t channel_stride, size_t channel_size,
        const int16_t* src0, const float* src1, float* dst) {
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    float* __restrict dst_ptr = dst;
    for (size_t batch = 0; batch < batch_size; ++batch) {
        for (size_t s = 0; s < channel_stride; ++s) {
            size_t i = 0;
            for (; i + 15 < channel_size; i += 16, sptr0 += 16, dst_ptr += 16) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec0_23 = vld1q_s16(sptr0 + 8);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec1_1 = vld1q_f32(sptr1 + i + 4);
                auto vec1_2 = vld1q_f32(sptr1 + i + 8);
                auto vec1_3 = vld1q_f32(sptr1 + i + 12);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
                auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
                auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
                auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
                auto dst_vec_2 = vmulq_f32(vec0_2, vec1_2);
                auto dst_vec_3 = vmulq_f32(vec0_3, vec1_3);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
                vst1q_f32(dst_ptr + 8, dst_vec_2);
                vst1q_f32(dst_ptr + 12, dst_vec_3);
            }
            for (; i + 7 < channel_size; i += 8, sptr0 += 8, dst_ptr += 8) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec1_0 = vld1q_f32(sptr1 + i);
                auto vec1_1 = vld1q_f32(sptr1 + i + 4);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
                auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
            }
            for (; i < channel_size; ++i, ++sptr0, ++dst_ptr) {
                *dst_ptr = (float)(*sptr0) * sptr1[i];
            }
        }
    }
}

void neon_mul_int16xf32xf32_vec_bcast101(
        size_t batch_size, size_t channel_size, size_t channel_stride,
        const int16_t* src0, const float* src1, float* dst) {
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    float* __restrict dst_ptr = dst;
    for (size_t batch = 0; batch < batch_size; ++batch) {
        for (size_t chan = 0; chan < channel_size; ++chan) {
            auto vec1 = vdupq_n_f32(sptr1[chan]);
            size_t i = 0;
            for (; i + 15 < channel_stride; i += 16, sptr0 += 16, dst_ptr += 16) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec0_23 = vld1q_s16(sptr0 + 8);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
                auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
                auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
                auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
                auto dst_vec_2 = vmulq_f32(vec0_2, vec1);
                auto dst_vec_3 = vmulq_f32(vec0_3, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
                vst1q_f32(dst_ptr + 8, dst_vec_2);
                vst1q_f32(dst_ptr + 12, dst_vec_3);
            }
            for (; i + 7 < channel_stride; i += 8, sptr0 += 8, dst_ptr += 8) {
                auto vec0_01 = vld1q_s16(sptr0);
                auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
                auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
                auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
                auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
                vst1q_f32(dst_ptr, dst_vec_0);
                vst1q_f32(dst_ptr + 4, dst_vec_1);
            }
            for (; i < channel_stride; ++i, ++sptr0, ++dst_ptr) {
                *dst_ptr = (float)(*sptr0) * sptr1[chan];
            }
        }
    }
}

void neon_mul_int16xf32xf32_vec_vec(
        size_t size, const int16_t* src0, const float* src1, float* dst) {
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    float* __restrict dst_ptr = dst;
    size_t i = 0;
    for (; i + 15 < size; i += 16, sptr0 += 16, sptr1 += 16, dst_ptr += 16) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec0_23 = vld1q_s16(sptr0 + 8);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec1_1 = vld1q_f32(sptr1 + 4);
        auto vec1_2 = vld1q_f32(sptr1 + 8);
        auto vec1_3 = vld1q_f32(sptr1 + 12);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
        auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
        auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
        auto dst_vec_2 = vmulq_f32(vec0_2, vec1_2);
        auto dst_vec_3 = vmulq_f32(vec0_3, vec1_3);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
        vst1q_f32(dst_ptr + 8, dst_vec_2);
        vst1q_f32(dst_ptr + 12, dst_vec_3);
    }
    for (; i + 7 < size; i += 8, sptr0 += 8, sptr1 += 8, dst_ptr += 8) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec1_0 = vld1q_f32(sptr1);
        auto vec1_1 = vld1q_f32(sptr1 + 4);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto dst_vec_0 = vmulq_f32(vec0_0, vec1_0);
        auto dst_vec_1 = vmulq_f32(vec0_1, vec1_1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
    }
    for (; i < size; ++i, ++sptr0, ++sptr1, ++dst_ptr) {
        *dst_ptr = (float)(*sptr0) * (*sptr1);
    }
}

void neon_mul_int16xf32xf32_vec_scaler(
        size_t size, const int16_t* src0, const float* src1, float* dst) {
    const int16_t* __restrict sptr0 = src0;
    const float* __restrict sptr1 = src1;
    auto vec1 = vdupq_n_f32(sptr1[0]);
    float* __restrict dst_ptr = dst;
    size_t i = 0;
    for (; i + 15 < size; i += 16, sptr0 += 16, dst_ptr += 16) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec0_23 = vld1q_s16(sptr0 + 8);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto vec0_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_23)));
        auto vec0_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_23)));
        auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
        auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
        auto dst_vec_2 = vmulq_f32(vec0_2, vec1);
        auto dst_vec_3 = vmulq_f32(vec0_3, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
        vst1q_f32(dst_ptr + 8, dst_vec_2);
        vst1q_f32(dst_ptr + 12, dst_vec_3);
    }
    for (; i + 7 < size; i += 8, sptr0 += 8, dst_ptr += 8) {
        auto vec0_01 = vld1q_s16(sptr0);
        auto vec0_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vec0_01)));
        auto vec0_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vec0_01)));
        auto dst_vec_0 = vmulq_f32(vec0_0, vec1);
        auto dst_vec_1 = vmulq_f32(vec0_1, vec1);
        vst1q_f32(dst_ptr, dst_vec_0);
        vst1q_f32(dst_ptr + 4, dst_vec_1);
    }
    for (; i < size; ++i, ++sptr0, ++dst_ptr) {
        *dst_ptr = (float)(*sptr0) * (*sptr1);
    }
}

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,71 @@
/**
 * \file dnn/src/arm_common/elemwise_multi_type/kernels.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
#pragma once

#include "stddef.h"
#include "stdint.h"

namespace megdnn {
namespace arm_common {

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
        size_t batch_size, size_t channel_stride, size_t channel_size,
        const int16_t* src0, const float* src1, const float* src2, float* dst);

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
        size_t batch_size, size_t channel_stride, size_t channel_size,
        const uint8_t* src0, const float* src1, const float* src2, float* dst);

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
        size_t batch_size, size_t channel_size, size_t channel_stride,
        const int16_t* src0, const float* src1, const float* src2, float* dst);

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
        size_t batch_size, size_t channel_size, size_t channel_stride,
        const uint8_t* src0, const float* src1, const float* src2, float* dst);

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
        size_t size, const int16_t* src0, const float* src1, const float* src2,
        float* dst);

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
        size_t size, const uint8_t* src0, const float* src1, const float* src2,
        float* dst);

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_b1x_b1x(
        size_t size, size_t vec, const int16_t* src0, const float* src1,
        const float* src2, float* dst);

void neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
        size_t size, const int16_t* src0, const float* src1, const float* src2,
        float* dst);

void neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
        size_t size, const uint8_t* src0, const float* src1, const float* src2,
        float* dst);

void neon_mul_int16xf32xf32_vec_bcast111c(
        size_t batch_size, size_t channel_stride, size_t channel_size,
        const int16_t* src0, const float* src1, float* dst);

void neon_mul_int16xf32xf32_vec_bcast101(
        size_t batch_size, size_t channel_size, size_t channel_stride,
        const int16_t* src0, const float* src1, float* dst);

void neon_mul_int16xf32xf32_vec_vec(
        size_t size, const int16_t* src0, const float* src1, float* dst);

void neon_mul_int16xf32xf32_vec_scaler(
        size_t size, const int16_t* src0, const float* src1, float* dst);

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen
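For orientation, a minimal standalone caller of the simplest kernel above (a sketch with made-up data; in MegEngine these functions are normally reached only through the ElemwiseMultiTypeImpl dispatch shown next, and they require an ARM NEON build):

#include <cstdint>
#include <vector>
// #include "src/arm_common/elemwise_multi_type/kernels.h"

void demo_mul_vec_vec() {
    std::vector<int16_t> a = {1, -2, 3, 40};            // int16 operand
    std::vector<float> b = {0.5f, 2.0f, -1.0f, 0.25f};  // float32 operand
    std::vector<float> out(a.size());
    megdnn::arm_common::neon_mul_int16xf32xf32_vec_vec(
            a.size(), a.data(), b.data(), out.data());
    // out is now {0.5f, -4.0f, -3.0f, 10.0f}
}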
@@ -11,6 +11,7 @@
 */
#include "./opr_impl.h"
#include "kernels.h"
#include "src/common/elemwise_multi_type/kern_defs.cuh"
#include "src/naive/handle.h"

@@ -851,6 +852,154 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
#undef DISPATCH_QUANTIZED_MODE
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    auto &src0 = param[0], &src1 = param[1], &src2 = param[2];
    BroadcastChannelInfo binfo;
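    // Reading binfo off the call sites below: the operands are viewed as an
    // (x, y, z) shape -- for the NHWC case x=batch, y=channel_stride,
    // z=channel_size (broadcast over the innermost axis); for the NCHW-like
    // case x=batch, y=channel_size, z=channel_stride.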
    if (is_vector(src0.layout) &&
        is_NHWC_broadcasted_channel_like(src1.layout, binfo) &&
        src1.layout.eq_layout(src2.layout)) {
        // VEC_BCAST111C_BCAST111C
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast111c_bcast111c(
                        binfo.x, binfo.y, binfo.z,
                        static_cast<dt_int16*>(src0.raw_ptr()),
                        static_cast<dt_float32*>(src1.raw_ptr()),
                        static_cast<dt_float32*>(src2.raw_ptr()),
                        dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) &&
            is_broadcasted_channel_like(src1.layout, binfo) &&
            src1.layout.eq_layout(src2.layout)) {
        // VEC_BCAST101_BCAST101
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                neon_fuse_mul_add3_int16xf32xf32xf32_vec_bcast101_bcast101(
                        binfo.x, binfo.y, binfo.z,
                        static_cast<dt_int16*>(src0.raw_ptr()),
                        static_cast<dt_float32*>(src1.raw_ptr()),
                        static_cast<dt_float32*>(src2.raw_ptr()),
                        dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) && is_vector(src1.layout) &&
            is_vector(src2.layout)) {
        // VEC_VEC_VEC
        auto size = param.size;
        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_fuse_mul_add3_int16xf32xf32xf32_vec_vec_vec(
                size, static_cast<dt_int16*>(src0.raw_ptr()),
                static_cast<dt_float32*>(src1.raw_ptr()),
                static_cast<dt_float32*>(src2.raw_ptr()), dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
            is_broadcasted_scalar(src2.layout)) {
        // VEC_SCALAR_SCALAR
        auto size = param.size;
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                neon_fuse_mul_add3_int16xf32xf32xf32_vec_scaler_scaler(
                        size, static_cast<dt_int16*>(src0.raw_ptr()),
                        static_cast<dt_float32*>(src1.raw_ptr()),
                        static_cast<dt_float32*>(src2.raw_ptr()),
                        dst.ptr<dt_float32>()));
        return;
    }
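    // No supported broadcast pattern matched: fall through to the generic,
    // layout-agnostic naive implementation.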
    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(param, dst);
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    auto &src0 = param[0], &src1 = param[1], &src2 = param[2];
    BroadcastChannelInfo binfo;
    if (is_vector(src0.layout) &&
        is_NHWC_broadcasted_channel_like(src1.layout, binfo) &&
        src1.layout.eq_layout(src2.layout)) {
        // VEC_BCAST111C_BCAST111C
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast111c_bcast111c(
                        binfo.x, binfo.y, binfo.z,
                        static_cast<dt_uint8*>(src0.raw_ptr()),
                        static_cast<dt_float32*>(src1.raw_ptr()),
                        static_cast<dt_float32*>(src2.raw_ptr()),
                        dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) &&
            is_broadcasted_channel_like(src1.layout, binfo) &&
            src1.layout.eq_layout(src2.layout)) {
        // VEC_BCAST101_BCAST101
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                neon_fuse_mul_add3_uint8xf32xf32xf32_vec_bcast101_bcast101(
                        binfo.x, binfo.y, binfo.z,
                        static_cast<dt_uint8*>(src0.raw_ptr()),
                        static_cast<dt_float32*>(src1.raw_ptr()),
                        static_cast<dt_float32*>(src2.raw_ptr()),
                        dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) && is_vector(src1.layout) &&
            is_vector(src2.layout)) {
        // VEC_VEC_VEC
        auto size = param.size;
        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_fuse_mul_add3_uint8xf32xf32xf32_vec_vec_vec(
                size, static_cast<dt_uint8*>(src0.raw_ptr()),
                static_cast<dt_float32*>(src1.raw_ptr()),
                static_cast<dt_float32*>(src2.raw_ptr()), dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
            is_broadcasted_scalar(src2.layout)) {
        // VEC_SCALAR_SCALAR
        auto size = param.size;
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                neon_fuse_mul_add3_uint8xf32xf32xf32_vec_scaler_scaler(
                        size, static_cast<dt_uint8*>(src0.raw_ptr()),
                        static_cast<dt_float32*>(src1.raw_ptr()),
                        static_cast<dt_float32*>(src2.raw_ptr()),
                        dst.ptr<dt_float32>()));
        return;
    }
    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(param, dst);
}

void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    auto &src0 = param[0], &src1 = param[1];
    BroadcastChannelInfo binfo;
    if (is_vector(src0.layout) &&
        is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
        // VEC_BCAST111C
        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_bcast111c(
                binfo.x, binfo.y, binfo.z, static_cast<dt_int16*>(src0.raw_ptr()),
                static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
        return;
    } else if (
            is_vector(src0.layout) &&
            is_broadcasted_channel_like(src1.layout, binfo)) {
        // VEC_BCAST101
        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_bcast101(
                binfo.x, binfo.y, binfo.z, static_cast<dt_int16*>(src0.raw_ptr()),
                static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
        return;
    } else if (is_vector(src0.layout) && is_vector(src1.layout)) {
        // VEC_VEC
        auto size = param.size;
        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_vec(
                size, static_cast<dt_int16*>(src0.raw_ptr()),
                static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
        return;
    } else if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
        auto size = param.size;
        MEGDNN_DISPATCH_CPU_KERN_OPR(neon_mul_int16xf32xf32_vec_scaler(
                size, static_cast<dt_int16*>(src0.raw_ptr()),
                static_cast<dt_float32*>(src1.raw_ptr()), dst.ptr<dt_float32>()));
        return;
    }
    naive::ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(param, dst);
}

} // namespace arm_common
} // namespace megdnn

@@ -48,6 +48,15 @@ protected:
            const ElemwiseOpParamN<3>& param, const TensorND& dst,
            Elemwise::Mode mode) override;

    void on_fuse_mul_add3_int16xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
    void on_mul_int16xf32xf32(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
    void on_fuse_mul_add3_uint8xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;

public:
    using fallback::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;
};

@@ -155,6 +155,29 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
        dst.name = name;
        dst.need_specify_out_dtype = true;
    };

    auto init_fma3_int16xf32xf32xf32 = [&](ModeTrait& dst, const char* name) {
        dst.arity = 3;
        dst.check_inp[0] = make_check_dtype_func(dtype::Int16());
        dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
        dst.check_inp[2] = make_check_dtype_func(dtype::Float32());
        dst.check_out = make_out_dtype_func(dtype::Float32());
        dst.name = name;
    };
    auto init_mul_int16xf32xf32 = [&](ModeTrait& dst, const char* name) {
        dst.arity = 2;
        dst.check_inp[0] = make_check_dtype_func(dtype::Int16());
        dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
        dst.check_out = make_out_dtype_func(dtype::Float32());
        dst.name = name;
    };
    auto init_fma3_uint8xf32xf32xf32 = [&](ModeTrait& dst, const char* name) {
        dst.arity = 3;
        dst.check_inp[0] = make_check_dtype_func(dtype::Uint8());
        dst.check_inp[1] = make_check_dtype_func(dtype::Float32());
        dst.check_inp[2] = make_check_dtype_func(dtype::Float32());
        dst.check_out = make_out_dtype_func(dtype::Float32());
        dst.name = name;
    };

#define SET(f, m)                                                         \
    MIDOUT_BEGIN(megdnn_common_elemwise_multi_type, midout_iv(Mode::m)) { \

@@ -169,6 +192,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
    SET(init_fuse_add_rmulh_rshr_int32x32x32x8,
        FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8);
    SET(init_rshrs_iXxi8xi16, ROUND_SHR_SATURATE_IXxI8xI16);
    SET(init_fma3_int16xf32xf32xf32, FUSE_MUL_ADD3_INT16xF32xF32xF32);
    SET(init_mul_int16xf32xf32, MUL_INT16xF32xF32);
    SET(init_fma3_uint8xf32xf32xf32, FUSE_MUL_ADD3_UINT8xF32xF32xF32);

    //! quantized opr, with specified dtype.
    //! dispatch elemwise mode internally

@@ -43,6 +43,17 @@ void ElemwiseMultiTypeImplHelper::exec(
        case Mode::ROUND_SHR_SATURATE_IXxI8xI16:
            on_round_shr_saturate_iXxi8xi16(make_elemwise_op_param<2>(src, dst), dst);
            break;
        case Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32:
            on_fuse_mul_add3_int16xf32xf32xf32(
                    make_elemwise_op_param<3>(src, dst), dst);
            break;
        case Mode::MUL_INT16xF32xF32:
            on_mul_int16xf32xf32(make_elemwise_op_param<2>(src, dst), dst);
            break;
        case Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32:
            on_fuse_mul_add3_uint8xf32xf32xf32(
                    make_elemwise_op_param<3>(src, dst), dst);
            break;
        ON_QUANTIZED_MODE(RELU, 1);
        ON_QUANTIZED_MODE(ABS, 1);
        ON_QUANTIZED_MODE(ACOS, 1);

@@ -50,6 +50,27 @@ protected:
    virtual void on_round_shr_saturate_iXxi8xi16(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0;

    virtual void on_fuse_mul_add3_int16xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) {
        MEGDNN_MARK_USED_VAR(param);
        MEGDNN_MARK_USED_VAR(dst);
        megdnn_throw("unsupported ElemwiseMultiType fma3 int16xf32xf32xf32.");
    }

    virtual void on_mul_int16xf32xf32(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) {
        MEGDNN_MARK_USED_VAR(param);
        MEGDNN_MARK_USED_VAR(dst);
| megdnn_throw("unsupported ElemwiseMultiType fma3 int16xf32xf32."); | |||
    }

    virtual void on_fuse_mul_add3_uint8xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) {
        MEGDNN_MARK_USED_VAR(param);
        MEGDNN_MARK_USED_VAR(dst);
        megdnn_throw("unsupported ElemwiseMultiType fma3 uint8xf32xf32xf32.");
    }

    virtual void on_quantized_mode(
            const ElemwiseOpParamN<1>& param, const TensorND& dst,
            Elemwise::Mode mode) {

@@ -56,6 +56,216 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(param, dst);
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    BroadcastChannelInfo binfo0, binfo1;
    if (is_vector(param[0].layout) &&
        is_NHWC_broadcasted_channel_like(param[1].layout, binfo0) &&
        is_NHWC_broadcasted_channel_like(param[2].layout, binfo1) &&
        binfo0 == binfo1) {
        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
        auto src0 = param[0];
        auto src1 = param[1];
        auto src2 = param[2];
        auto work = [=]() {
            const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
            const dt_float32* __restrict__ b =
                    static_cast<dt_float32*>(src1.raw_ptr());
            const dt_float32* __restrict__ c =
                    static_cast<dt_float32*>(src2.raw_ptr());
            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
            for (size_t i = 0; i < x; ++i) {
                for (size_t j = 0; j < y; ++j) {
                    auto off = i * (y * z) + j * z;
                    size_t k = 0;
                    for (; k + 4 <= z; k += 4) {
                        d[off + k + 0] = a[off + k + 0] * b[k + 0] + c[k + 0];
                        d[off + k + 1] = a[off + k + 1] * b[k + 1] + c[k + 1];
                        d[off + k + 2] = a[off + k + 2] * b[k + 2] + c[k + 2];
                        d[off + k + 3] = a[off + k + 3] * b[k + 3] + c[k + 3];
                    }
                    for (; k < z; ++k) {
                        d[off + k] = a[off + k] * b[k] + c[k];
                    }
                }
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
        return;
    } else if (
            is_vector(param[0].layout) &&
            is_broadcasted_channel_like(param[1].layout, binfo0) &&
            is_broadcasted_channel_like(param[2].layout, binfo1) &&
            binfo0 == binfo1) {
        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
        auto src0 = param[0];
        auto src1 = param[1];
        auto src2 = param[2];
        auto work = [=]() {
            const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
            const dt_float32* __restrict__ b =
                    static_cast<dt_float32*>(src1.raw_ptr());
            const dt_float32* __restrict__ c =
                    static_cast<dt_float32*>(src2.raw_ptr());
            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
            for (size_t j = 0; j < y; ++j) {
                auto bv = b[j], cv = c[j];
                for (size_t i = 0; i < x; ++i) {
                    auto off = i * (y * z) + j * z, offt = off + z;
                    for (; off + 4 <= offt; off += 4) {
                        d[off + 0] = a[off + 0] * bv + cv;
                        d[off + 1] = a[off + 1] * bv + cv;
                        d[off + 2] = a[off + 2] * bv + cv;
                        d[off + 3] = a[off + 3] * bv + cv;
                    }
                    for (; off < offt; ++off) {
                        d[off] = a[off] * bv + cv;
                    }
                }
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
        return;
    }
    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(param, dst);
}

void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    BroadcastChannelInfo binfo;
    if (is_vector(param[0].layout) &&
        is_NHWC_broadcasted_channel_like(param[1].layout, binfo)) {
        auto x = binfo.x, y = binfo.y, z = binfo.z;
        auto src0 = param[0];
        auto src1 = param[1];
        auto work = [=]() {
            const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
            const dt_float32* __restrict__ b =
                    static_cast<dt_float32*>(src1.raw_ptr());
            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
            for (size_t i = 0; i < x; ++i) {
                for (size_t j = 0; j < y; ++j) {
                    auto off = i * (y * z) + j * z;
                    size_t k = 0;
                    for (; k + 4 <= z; k += 4) {
                        d[off + k + 0] = a[off + k + 0] * b[k + 0];
                        d[off + k + 1] = a[off + k + 1] * b[k + 1];
                        d[off + k + 2] = a[off + k + 2] * b[k + 2];
                        d[off + k + 3] = a[off + k + 3] * b[k + 3];
                    }
                    for (; k < z; ++k) {
                        d[off + k] = a[off + k] * b[k];
                    }
                }
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
        return;
    } else if (
            is_vector(param[0].layout) &&
            is_broadcasted_channel_like(param[1].layout, binfo)) {
        auto x = binfo.x, y = binfo.y, z = binfo.z;
        auto src0 = param[0];
        auto src1 = param[1];
        auto work = [=]() {
            const dt_int16* __restrict__ a = static_cast<dt_int16*>(src0.raw_ptr());
            const dt_float32* __restrict__ b =
                    static_cast<dt_float32*>(src1.raw_ptr());
            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
            for (size_t j = 0; j < y; ++j) {
                auto bv = b[j];
                for (size_t i = 0; i < x; ++i) {
                    auto off = i * (y * z) + j * z, offt = off + z;
                    for (; off + 4 <= offt; off += 4) {
                        d[off + 0] = a[off + 0] * bv;
                        d[off + 1] = a[off + 1] * bv;
                        d[off + 2] = a[off + 2] * bv;
                        d[off + 3] = a[off + 3] * bv;
                    }
                    for (; off < offt; ++off) {
                        d[off] = a[off] * bv;
                    }
                }
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
        return;
    }
    naive::ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(param, dst);
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    BroadcastChannelInfo binfo0, binfo1;
    if (is_vector(param[0].layout) &&
        is_NHWC_broadcasted_channel_like(param[1].layout, binfo0) &&
        is_NHWC_broadcasted_channel_like(param[2].layout, binfo1) &&
        binfo0 == binfo1) {
        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
        auto src0 = param[0];
        auto src1 = param[1];
        auto src2 = param[2];
        auto work = [=]() {
            const dt_uint8* __restrict__ a = static_cast<dt_uint8*>(src0.raw_ptr());
            const dt_float32* __restrict__ b =
                    static_cast<dt_float32*>(src1.raw_ptr());
            const dt_float32* __restrict__ c =
                    static_cast<dt_float32*>(src2.raw_ptr());
            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
            for (size_t i = 0; i < x; ++i) {
                for (size_t j = 0; j < y; ++j) {
                    auto off = i * (y * z) + j * z;
                    size_t k = 0;
                    for (; k + 4 <= z; k += 4) {
                        d[off + k + 0] = a[off + k + 0] * b[k + 0] + c[k + 0];
                        d[off + k + 1] = a[off + k + 1] * b[k + 1] + c[k + 1];
                        d[off + k + 2] = a[off + k + 2] * b[k + 2] + c[k + 2];
                        d[off + k + 3] = a[off + k + 3] * b[k + 3] + c[k + 3];
                    }
                    for (; k < z; ++k) {
                        d[off + k] = a[off + k] * b[k] + c[k];
                    }
                }
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
        return;
    } else if (
            is_vector(param[0].layout) &&
            is_broadcasted_channel_like(param[1].layout, binfo0) &&
            is_broadcasted_channel_like(param[2].layout, binfo1) &&
            binfo0 == binfo1) {
        auto x = binfo0.x, y = binfo0.y, z = binfo0.z;
        auto src0 = param[0];
        auto src1 = param[1];
        auto src2 = param[2];
        auto work = [=]() {
            const dt_uint8* __restrict__ a = static_cast<dt_uint8*>(src0.raw_ptr());
            const dt_float32* __restrict__ b =
                    static_cast<dt_float32*>(src1.raw_ptr());
            const dt_float32* __restrict__ c =
                    static_cast<dt_float32*>(src2.raw_ptr());
            dt_float32* __restrict__ d = dst.ptr<dt_float32>();
            for (size_t j = 0; j < y; ++j) {
                auto bv = b[j], cv = c[j];
                for (size_t i = 0; i < x; ++i) {
                    auto off = i * (y * z) + j * z, offt = off + z;
                    for (; off + 4 <= offt; off += 4) {
                        d[off + 0] = a[off + 0] * bv + cv;
                        d[off + 1] = a[off + 1] * bv + cv;
                        d[off + 2] = a[off + 2] * bv + cv;
                        d[off + 3] = a[off + 3] * bv + cv;
                    }
                    for (; off < offt; ++off) {
                        d[off] = a[off] * bv + cv;
                    }
                }
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
        return;
    }
    naive::ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(param, dst);
}

template <typename ctype>
void ElemwiseMultiTypeImpl::dispatch_fma3_iXxf32xf32xi8_bcast_1x(
        const ElemwiseOpParamN<3>& param, const Broadcast1xInfo& binfo,

@@ -43,6 +43,12 @@ protected:
            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
    void on_round_shr_saturate_iXxi8xi16(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;

    void on_fuse_mul_add3_int16xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
    void on_mul_int16xf32xf32(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
    void on_fuse_mul_add3_uint8xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;

public:
    using naive::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl;

@@ -39,6 +39,66 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32(
    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    auto size = param.size;
    auto src0 = param[0];
    auto src1 = param[1];
    auto src2 = param[2];
    auto work = [src0, src1, src2, size, dst]() {
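        // tensor_iter_valonly yields values while honouring arbitrary strides
        // and broadcasts in the source layouts, so this loop is correct for
        // any layout combination at the cost of per-element iterator overhead.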
        auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
        auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
        auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
        auto dst_ptr = dst.ptr<dt_float32>();
        for (size_t i = 0; i < size; ++i) {
            dst_ptr[i] = (*i0) * (*i1) + (*i2);
            ++i0;
            ++i1;
            ++i2;
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_uint8xf32xf32xf32(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    auto size = param.size;
    auto src0 = param[0];
    auto src1 = param[1];
    auto src2 = param[2];
    auto work = [src0, src1, src2, size, dst]() {
        auto i0 = tensor_iter_valonly<dt_uint8>(src0).begin();
        auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
        auto i2 = tensor_iter_valonly<dt_float32>(src2).begin();
        auto dst_ptr = dst.ptr<dt_float32>();
        for (size_t i = 0; i < size; ++i) {
            dst_ptr[i] = (*i0) * (*i1) + (*i2);
            ++i0;
            ++i1;
            ++i2;
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_mul_int16xf32xf32(
        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
    auto size = param.size;
    auto src0 = param[0];
    auto src1 = param[1];
    auto work = [src0, src1, size, dst]() {
        auto i0 = tensor_iter_valonly<dt_int16>(src0).begin();
        auto i1 = tensor_iter_valonly<dt_float32>(src1).begin();
        auto dst_ptr = dst.ptr<dt_float32>();
        for (size_t i = 0; i < size; ++i) {
            dst_ptr[i] = (*i0) * (*i1);
            ++i0;
            ++i1;
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(work());
}

void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8(
        const ElemwiseOpParamN<3>& param, const TensorND& dst) {
    switch (param[0].layout.dtype.enumv()) {

@@ -60,6 +60,12 @@ protected:
            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
    void on_round_shr_saturate_iXxi8xi16(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;

    void on_fuse_mul_add3_int16xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
    void on_mul_int16xf32xf32(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
    void on_fuse_mul_add3_uint8xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;

    void on_quantized_mode(
            const ElemwiseOpParamN<1>& param, const TensorND& dst,

@@ -456,4 +456,107 @@ TEST_F(ARM_COMMON, ELEMWISE_QUANTIZED_MODE_TERNARY_RECORD) {
    }
}

TEST_F(ARM_COMMON, ELEMWISE_FMA3_INT16xF32xF32xF32) {
    Checker<ElemwiseMultiType> checker(handle());
    checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32});
    checker.set_dtype(0, dtype::Int16());
    checker.set_dtype(1, dtype::Float32());
    checker.set_dtype(2, dtype::Float32());
    UniformIntRNG rng{-100, 100};
    checker.set_rng(0, &rng);
    checker.set_rng(1, &rng);
    checker.set_rng(2, &rng);
| checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}); | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_FMA3_INT16xF32xF32xF32_RECORD) { | |||
| TaskRecordChecker<ElemwiseMultiType> checker(0); | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32}); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| UniformIntRNG rng{-100, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 18}, {1, 1, 1, 18}, {1, 1, 1, 18}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_MUL_INT16xF32xF32) { | |||
| Checker<ElemwiseMultiType> checker(handle()); | |||
| checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32}); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| UniformIntRNG rng{-100, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.execs({{5, 7, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}}); | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_MUL_INT16xF32xF32_RECORD) { | |||
| TaskRecordChecker<ElemwiseMultiType> checker(0); | |||
| checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32}); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| UniformIntRNG rng{-100, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.execs({{5, 7, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32) { | |||
| Checker<ElemwiseMultiType> checker(handle()); | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32}); | |||
| checker.set_dtype(0, dtype::Uint8()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| UniformIntRNG rng{0, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| TEST_F(ARM_COMMON, ELEMWISE_FMA3_UINT8xF32xF32xF32_RECORD) { | |||
| TaskRecordChecker<ElemwiseMultiType> checker(0); | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32}); | |||
| checker.set_dtype(0, dtype::Uint8()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| UniformIntRNG rng{0, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
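In the `execs` calls above, the trailing `{}` leaves the output shape to be deduced. Assuming numpy-style trailing-axis broadcasting, which all of the shape pairs above are consistent with, the deduction can be sketched as:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Right-align the shorter shape; each aligned axis pair must be equal
// or contain a 1, and the output takes the larger extent.
std::vector<std::size_t> broadcast_shape(
        std::vector<std::size_t> a, std::vector<std::size_t> b) {
    if (a.size() < b.size()) std::swap(a, b);
    std::vector<std::size_t> out = a;
    const std::size_t off = a.size() - b.size();
    for (std::size_t i = 0; i < b.size(); ++i)
        out[off + i] = std::max(a[off + i], b[i]);
    return out;
}

int main() {
    auto s = broadcast_shape({16, 128, 16, 16}, {1, 128, 1, 1});
    for (auto d : s) std::printf("%zu ", d);  // 16 128 16 16
}
```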
| @@ -79,6 +79,73 @@ DEF_TEST(fuse_mul_add3_int16x32x32x32) { | |||
| .execs({{102, 67, 71}, {1, 67, 1}, {1, 67, 1}, {}}); | |||
| } | |||
| DEF_TEST(fuse_mul_add3_int16xf32xf32xf32) { | |||
| // This is not implemented on CUDA. | |||
| if (handle->type() == Handle::HandleType::CUDA) { | |||
| return; | |||
| } | |||
| Checker<ElemwiseMultiType> checker(handle); | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32}); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| UniformIntRNG rng{-100, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{5, 7, 6}, {1, 1, 6}, {1, 1, 6}, {}}) | |||
| .execs({{1, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{1, 700, 600}, {1, 700, 600}, {1, 700, 600}, {}}) | |||
| .execs({{102, 71, 67}, {1, 1, 67}, {1, 1, 67}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| DEF_TEST(fuse_mul_add3_uint8xf32xf32xf32) { | |||
| // This is not implemented on CUDA. | |||
| if (handle->type() == Handle::HandleType::CUDA) { | |||
| return; | |||
| } | |||
| Checker<ElemwiseMultiType> checker(handle); | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32}); | |||
| checker.set_dtype(0, dtype::Uint8()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| UniformIntRNG rng{0, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{5, 7, 6}, {1, 1, 6}, {1, 1, 6}, {}}) | |||
| .execs({{1, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{1, 700, 600}, {1, 700, 600}, {1, 700, 600}, {}}) | |||
| .execs({{102, 71, 67}, {1, 1, 67}, {1, 1, 67}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| DEF_TEST(mul_int16xf32xf32) { | |||
| // This is not implemented on CUDA. | |||
| if (handle->type() == Handle::HandleType::CUDA) { | |||
| return; | |||
| } | |||
| Checker<ElemwiseMultiType> checker(handle); | |||
| checker.set_param({ElemwiseMultiType::Mode::MUL_INT16xF32xF32}); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| UniformIntRNG rng{-100, 100}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.execs({{5, 7, 6}, {1, 1, 6}, {}}) | |||
| .execs({{1, 700, 600}, {1, 1, 600}, {}}) | |||
| .execs({{1, 700, 600}, {1, 700, 600}, {}}) | |||
| .execs({{102, 71, 67}, {1, 1, 67}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {}}); | |||
| } | |||
| DEF_TEST(fuse_mul_add3_iXxf32xf32xi8) { | |||
| Checker<ElemwiseMultiType> checker(handle); | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8}); | |||
| @@ -20,10 +20,13 @@ namespace test { | |||
| namespace elemwise_multi_type { | |||
| #define FIRST_ELEMWISE_MULTI_TYPE_CASE fuse_mul_add3_int16x32x32x32 | |||
| #define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb) \ | |||
| cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \ | |||
| cb(fuse_add_rmulh_round_shr_saturate_int16) \ | |||
| cb(fuse_add_rmulh_round_shr_saturate_int32) | |||
| #define FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb) \ | |||
| cb(fuse_mul_add3_iXxf32xf32xi8) cb(round_shr_saturate_iXxi8xi8) \ | |||
| cb(fuse_add_rmulh_round_shr_saturate_int16) \ | |||
| cb(fuse_add_rmulh_round_shr_saturate_int32) \ | |||
| cb(fuse_mul_add3_int16xf32xf32xf32) \ | |||
| cb(fuse_mul_add3_uint8xf32xf32xf32) \ | |||
| cb(mul_int16xf32xf32) | |||
| #define FOREACH_ELEMWISE_MULTI_TYPE_CASE(cb) \ | |||
| cb(FIRST_ELEMWISE_MULTI_TYPE_CASE) FOREACH_ELEMWISE_MULTI_TYPE_NONFIRST_CASE(cb) | |||
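The `cb` parameter makes `FOREACH_ELEMWISE_MULTI_TYPE_CASE` an X-macro: a single list of case names drives every expansion site (declaration, registration, and so on). A generic, self-contained sketch of the technique:

```cpp
#include <cstdio>

// One list of case names...
#define FOREACH_CASE(cb) cb(alpha) cb(beta) cb(gamma)

// ...drives several independent expansions.
#define DECLARE_COUNTER(name) int name##_runs = 0;
FOREACH_CASE(DECLARE_COUNTER)

#define RUN_CASE(name) \
    ++name##_runs;     \
    std::printf("%s ran %d time(s)\n", #name, name##_runs);

int main() {
    FOREACH_CASE(RUN_CASE)
}
```

Keeping the case list in one place means a new mode, such as the three added here, only has to be appended once to show up in every generated test.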
| @@ -40,6 +40,24 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16x32x32x32) { | |||
| checker.execs({{A, B, C}, {1, B, 1}, {1, B, 1}, {}}); | |||
| } | |||
| TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_RECORD_FMA3_INT16xF32xF32xF32) { | |||
| TaskRecordChecker<ElemwiseMultiType> checker{1}; | |||
| checker.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32}); | |||
| checker.set_dtype(0, dtype::Int16()); | |||
| checker.set_dtype(1, dtype::Float32()); | |||
| checker.set_dtype(2, dtype::Float32()); | |||
| UniformIntRNG rng{-10, 10}; | |||
| checker.set_rng(0, &rng); | |||
| checker.set_rng(1, &rng); | |||
| checker.set_rng(2, &rng); | |||
| checker.execs({{5, 7, 16}, {1, 1, 16}, {1, 1, 16}, {}}) | |||
| .execs({{2, 700, 600}, {1, 1, 600}, {1, 1, 600}, {}}) | |||
| .execs({{2, 700, 600}, {2, 700, 600}, {2, 700, 600}, {}}) | |||
| .execs({{16, 16, 128}, {16, 16, 128}, {16, 16, 128}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 128, 1, 1}, {1, 128, 1, 1}, {}}) | |||
| .execs({{16, 128, 16, 16}, {1, 1, 1, 1}, {1, 1, 1, 1}, {}}); | |||
| } | |||
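TaskRecordChecker presumably exercises a record-and-replay dispatch path, where the kernels issued during the first run are captured and re-executed. A generic sketch of that idea (an assumption about the intent, not MegDNN's actual API):

```cpp
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    std::vector<std::function<void()>> recorded;
    int acc = 0;
    // "dispatch" captures each task before running it
    auto dispatch = [&](std::function<void()> task) {
        recorded.push_back(task);
        recorded.back()();  // first execution
    };
    dispatch([&] { acc += 1; });
    dispatch([&] { acc *= 10; });
    const int first = acc;          // 10
    acc = 0;
    for (auto& t : recorded) t();   // replay must reproduce the result
    std::printf("first=%d replay=%d\n", first, acc);
}
```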
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32) { | |||
| Benchmarker<ElemwiseMultiType> bench{handle()}; | |||