| @@ -56,7 +56,7 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi | |||||
| void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, | void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, | ||||
| size_t plane_size, size_t plane_stride, size_t relu_type) { | size_t plane_size, size_t plane_stride, size_t relu_type) { | ||||
| #ifdef ENABLE_ARM | |||||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||||
| size_t oc4mod = output_channel % C4NUM; | size_t oc4mod = output_channel % C4NUM; | ||||
| size_t oc4div = output_channel - oc4mod; | size_t oc4div = output_channel - oc4mod; | ||||
| size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); | size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); | ||||
| @@ -50,10 +50,6 @@ void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_ | |||||
| size_t in_kh_step, size_t in_kw_step); | size_t in_kh_step, size_t in_kw_step); | ||||
| void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | ||||
| size_t plane_size, size_t stride, size_t relu_type); | size_t plane_size, size_t stride, size_t relu_type); | ||||
| #endif | |||||
| #ifdef ENABLE_ARM | |||||
| void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, | void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, | ||||
| size_t output_channel, size_t input_step); | size_t output_channel, size_t input_step); | ||||
| void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #endif | #endif | ||||
| #ifndef ENABLE_ARM | |||||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) | |||||
| void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, | void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, | ||||
| int output_channel, int input_step) { | int output_channel, int input_step) { | ||||
| for (int i = 0; i < num_pixels; i++) { | for (int i = 0; i < num_pixels; i++) { | ||||
| @@ -161,7 +161,7 @@ void DeConvWgInputPack(const float *src_ptr, float *dst_ptr, int channel, int st | |||||
| return; | return; | ||||
| } | } | ||||
| #ifndef ENABLE_ARM | |||||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) | |||||
| void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { | void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { | ||||
| int dx, sz, dz; | int dx, sz, dz; | ||||
| const int src_depth_step = 4 * DECONV_WINOGRAD_DEFAULT_TILE; | const int src_depth_step = 4 * DECONV_WINOGRAD_DEFAULT_TILE; | ||||
| @@ -0,0 +1,86 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/common_func_fp32.h" | |||||
| void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, | |||||
| size_t output_channel, size_t input_step) { | |||||
| size_t out_c16 = DOWN_DIV(output_channel, C16NUM) * C16NUM; | |||||
| size_t out_c8 = DOWN_DIV(output_channel, C8NUM) * C8NUM; | |||||
| size_t out_c4 = DOWN_DIV(output_channel, C4NUM) * C4NUM; | |||||
| for (int i = 0; i < num_pixels; i++) { | |||||
| const float *weight_tmp = weight_ptr; | |||||
| const float *input_tmp = input_ptr; | |||||
| size_t out_c = 0; | |||||
| for (; out_c < out_c16; out_c += C16NUM) { | |||||
| __m128 dst1 = _mm_loadu_ps(output_ptr); | |||||
| __m128 dst2 = _mm_loadu_ps(output_ptr + 4); | |||||
| __m128 dst3 = _mm_loadu_ps(output_ptr + 8); | |||||
| __m128 dst4 = _mm_loadu_ps(output_ptr + 12); | |||||
| __m128 w1 = _mm_loadu_ps(weight_tmp); | |||||
| __m128 w2 = _mm_loadu_ps(weight_tmp + 4); | |||||
| __m128 w3 = _mm_loadu_ps(weight_tmp + 8); | |||||
| __m128 w4 = _mm_loadu_ps(weight_tmp + 12); | |||||
| __m128 in1 = _mm_loadu_ps(input_tmp); | |||||
| __m128 in2 = _mm_loadu_ps(input_tmp + 4); | |||||
| __m128 in3 = _mm_loadu_ps(input_tmp + 8); | |||||
| __m128 in4 = _mm_loadu_ps(input_tmp + 12); | |||||
| dst1 = MS_MLAQ_F32(dst1, w1, in1); | |||||
| dst2 = MS_MLAQ_F32(dst2, w2, in2); | |||||
| dst3 = MS_MLAQ_F32(dst3, w3, in3); | |||||
| dst4 = MS_MLAQ_F32(dst4, w4, in4); | |||||
| _mm_storeu_ps(output_ptr, dst1); | |||||
| _mm_storeu_ps(output_ptr + 4, dst2); | |||||
| _mm_storeu_ps(output_ptr + 8, dst3); | |||||
| _mm_storeu_ps(output_ptr + 12, dst4); | |||||
| output_ptr += 16; | |||||
| input_tmp += 16; | |||||
| weight_tmp += 16; | |||||
| } | |||||
| for (; out_c < out_c8; out_c += C8NUM) { | |||||
| __m128 dst1 = _mm_loadu_ps(output_ptr); | |||||
| __m128 dst2 = _mm_loadu_ps(output_ptr + 4); | |||||
| __m128 w1 = _mm_loadu_ps(weight_tmp); | |||||
| __m128 w2 = _mm_loadu_ps(weight_tmp + 4); | |||||
| __m128 in1 = _mm_loadu_ps(input_tmp); | |||||
| __m128 in2 = _mm_loadu_ps(input_tmp + 4); | |||||
| dst1 = MS_MLAQ_F32(dst1, w1, in1); | |||||
| dst2 = MS_MLAQ_F32(dst2, w2, in2); | |||||
| _mm_storeu_ps(output_ptr, dst1); | |||||
| _mm_storeu_ps(output_ptr + 4, dst2); | |||||
| output_ptr += 8; | |||||
| input_tmp += 8; | |||||
| weight_tmp += 8; | |||||
| } | |||||
| for (; out_c < out_c4; out_c += C4NUM) { | |||||
| __m128 dst1 = _mm_loadu_ps(output_ptr); | |||||
| __m128 w1 = _mm_loadu_ps(weight_tmp); | |||||
| __m128 in1 = _mm_loadu_ps(input_tmp); | |||||
| dst1 = MS_MLAQ_F32(dst1, w1, in1); | |||||
| _mm_storeu_ps(output_ptr, dst1); | |||||
| output_ptr += 4; | |||||
| input_tmp += 4; | |||||
| weight_tmp += 4; | |||||
| } | |||||
| for (; out_c < output_channel; out_c++) { | |||||
| *output_ptr++ += weight_ptr[out_c] * input_ptr[out_c]; | |||||
| } | |||||
| input_ptr += input_step; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| @@ -0,0 +1,126 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/common_func_fp32.h" | |||||
| void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | |||||
| size_t plane_size, size_t plane_stride, size_t relu_type) { | |||||
| __m128 relu6 = _mm_set_ps1(6.0); | |||||
| __m128 zero = _mm_setzero_ps(); | |||||
| size_t stride = oc4div + oc4mod; | |||||
| plane_stride /= sizeof(float); | |||||
| for (size_t loop_c4 = 0; loop_c4 < oc4div; loop_c4 += C4NUM) { | |||||
| size_t plane_size_tmp = plane_size; | |||||
| float *dst_c4 = dst + loop_c4; | |||||
| __m128 bias1 = _mm_setzero_ps(); | |||||
| if (bias != NULL) { | |||||
| bias1 = _mm_loadu_ps(bias); | |||||
| bias += 4; | |||||
| } | |||||
| for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { | |||||
| __m128 src1 = _mm_loadu_ps(src); | |||||
| __m128 src2 = _mm_loadu_ps(src + 4); | |||||
| __m128 src3 = _mm_loadu_ps(src + 8); | |||||
| __m128 src4 = _mm_loadu_ps(src + 12); | |||||
| src += 16; | |||||
| src1 = _mm_add_ps(src1, bias1); | |||||
| src2 = _mm_add_ps(src2, bias1); | |||||
| src3 = _mm_add_ps(src3, bias1); | |||||
| src4 = _mm_add_ps(src4, bias1); | |||||
| switch (relu_type) { | |||||
| case 3: | |||||
| src1 = _mm_min_ps(src1, relu6); | |||||
| src2 = _mm_min_ps(src2, relu6); | |||||
| src3 = _mm_min_ps(src3, relu6); | |||||
| src4 = _mm_min_ps(src4, relu6); | |||||
| case 1: | |||||
| src1 = _mm_max_ps(src1, zero); | |||||
| src2 = _mm_max_ps(src2, zero); | |||||
| src3 = _mm_max_ps(src3, zero); | |||||
| src4 = _mm_max_ps(src4, zero); | |||||
| break; | |||||
| } | |||||
| _mm_storeu_ps(dst_c4, src1); | |||||
| dst_c4 += stride; | |||||
| _mm_storeu_ps(dst_c4, src2); | |||||
| dst_c4 += stride; | |||||
| _mm_storeu_ps(dst_c4, src3); | |||||
| dst_c4 += stride; | |||||
| _mm_storeu_ps(dst_c4, src4); | |||||
| dst_c4 += stride; | |||||
| } | |||||
| for (; plane_size_tmp > 0; plane_size_tmp -= 1) { | |||||
| __m128 src1 = _mm_loadu_ps(src); | |||||
| src1 = _mm_add_ps(src1, bias1); | |||||
| switch (relu_type) { | |||||
| case 3: | |||||
| src1 = _mm_min_ps(src1, relu6); | |||||
| case 1: | |||||
| src1 = _mm_max_ps(src1, zero); | |||||
| break; | |||||
| } | |||||
| _mm_storeu_ps(dst_c4, src1); | |||||
| dst_c4 += stride; | |||||
| src += 4; | |||||
| } | |||||
| src += plane_stride; | |||||
| } | |||||
| if (oc4mod == 0) { | |||||
| return; | |||||
| } | |||||
| __m128 bias1 = _mm_setzero_ps(); | |||||
| if (bias != NULL) { | |||||
| bias1 = _mm_loadu_ps(bias); | |||||
| bias += 4; | |||||
| } | |||||
| float *dst_c1 = dst + oc4div; | |||||
| for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) { | |||||
| __m128 src1 = _mm_loadu_ps(src); | |||||
| src += 4; | |||||
| src1 = _mm_add_ps(src1, bias1); | |||||
| switch (relu_type) { | |||||
| case 3: | |||||
| src1 = _mm_min_ps(src1, relu6); | |||||
| case 1: | |||||
| src1 = _mm_max_ps(src1, zero); | |||||
| break; | |||||
| } | |||||
| switch (oc4mod) { | |||||
| case 1: | |||||
| _mm_store_ss(dst_c1, src1); | |||||
| dst_c1 += stride; | |||||
| break; | |||||
| case 2: | |||||
| _mm_storel_pi((__m64 *)(dst_c1), src1); | |||||
| dst_c1 += stride; | |||||
| break; | |||||
| case 3: | |||||
| _mm_storel_pi((__m64 *)(dst_c1), src1); | |||||
| src1 = _mm_unpackhi_ps(src1, src1); | |||||
| _mm_store_ss(dst_c1 + 2, src1); | |||||
| dst_c1 += stride; | |||||
| break; | |||||
| case 4: | |||||
| _mm_storeu_ps(dst_c1, src1); | |||||
| dst_c1 += stride; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| #endif | |||||
| @@ -0,0 +1,175 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/common_func_fp32.h" | |||||
// Adds one input-channel block to the eight tile accumulators:
//   acc[t] += w[0] * s[4t + 0] + w[1] * s[4t + 1] + w[2] * s[4t + 2] + w[3] * s[4t + 3]
// where w[j] is the 4-float vector at weight_block + 4 * j. The four terms are added one after
// another in ascending j so the floating-point rounding sequence per lane stays fixed.
static inline void TiledC4MatmulAccumulateBlock(__m128 *acc, const float *src_block, const float *weight_block) {
  const __m128 w0 = _mm_loadu_ps(weight_block);
  const __m128 w1 = _mm_loadu_ps(weight_block + 4);
  const __m128 w2 = _mm_loadu_ps(weight_block + 8);
  const __m128 w3 = _mm_loadu_ps(weight_block + 12);
  for (int tile = 0; tile < 8; ++tile) {
    const float *s = src_block + 4 * tile;
    acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w0, _mm_set_ps1(s[0])));
    acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w1, _mm_set_ps1(s[1])));
    acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w2, _mm_set_ps1(s[2])));
    acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w3, _mm_set_ps1(s[3])));
  }
}

// SSE tiled matrix multiply over 4-channel-packed data, eight tiles at a time.
//
// For every output group z in [0, oc4), tile t in [0, 8) and lane j in [0, 4):
//   dst[z * cal_num + 4t + j] = sum over b in [0, ic4), i in [0, 4) of
//                               src[32b + 4t + i] * weight[16 * (z * ic4 + b) + 4i + j]
//
// dst     output; 32 floats are written per output group, groups are cal_num floats apart.
// src     ic4 blocks of 32 floats (8 tiles x 4 channels); the same data is reused for every group.
// weight  oc4 * ic4 blocks of 16 floats, read sequentially.
// cal_num distance in floats between consecutive output groups in dst.
// ic4     number of 4-channel input blocks.
// oc4     number of 4-channel output groups.
//
// Products are accumulated block by block, and within a block in ascending input-channel order,
// which is the same per-lane order as the previous hand-pipelined version. With ic4 == 0 the sum
// is empty: the output is zero-filled and neither src nor weight is read (the old code
// decremented ic4 past zero and ran off the end of both buffers).
void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) {
  for (size_t oc = 0; oc < oc4; ++oc) {
    __m128 acc[8];
    if (ic4 == 0) {
      for (int tile = 0; tile < 8; ++tile) {
        acc[tile] = _mm_setzero_ps();
      }
    } else {
      // The first block initialises the accumulators with a plain product (no "0 + x" add),
      // so the sign of a zero result is unchanged from the original implementation.
      const __m128 w0 = _mm_loadu_ps(weight);
      const __m128 w1 = _mm_loadu_ps(weight + 4);
      const __m128 w2 = _mm_loadu_ps(weight + 8);
      const __m128 w3 = _mm_loadu_ps(weight + 12);
      for (int tile = 0; tile < 8; ++tile) {
        const float *s = src + 4 * tile;
        acc[tile] = _mm_mul_ps(w0, _mm_set_ps1(s[0]));
        acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w1, _mm_set_ps1(s[1])));
        acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w2, _mm_set_ps1(s[2])));
        acc[tile] = _mm_add_ps(acc[tile], _mm_mul_ps(w3, _mm_set_ps1(s[3])));
      }
      for (size_t block = 1; block < ic4; ++block) {
        TiledC4MatmulAccumulateBlock(acc, src + 32 * block, weight + 16 * block);
      }
    }
    // Each output group consumes ic4 weight blocks of 16 floats.
    weight += 16 * ic4;
    for (int tile = 0; tile < 8; ++tile) {
      _mm_storeu_ps(dst + 4 * tile, acc[tile]);
    }
    dst += cal_num;
  }
}
| #endif | |||||
| @@ -17,7 +17,7 @@ | |||||
| #include "nnacl/tensorlist_parameter.h" | #include "nnacl/tensorlist_parameter.h" | ||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "src/ops/tensorlistfromtensor.h" | |||||
| #include "src/ops/tensorlist_fromtensor.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/tensorlistgetitem.h" | |||||
| #include "src/ops/tensorlist_getitem.h" | |||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/tensorlist_parameter.h" | #include "nnacl/tensorlist_parameter.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/tensorlistreserve.h" | |||||
| #include "src/ops/tensorlist_reserve.h" | |||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/tensorlist_parameter.h" | #include "nnacl/tensorlist_parameter.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/tensorlistsetitem.h" | |||||
| #include "src/ops/tensorlist_setitem.h" | |||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/tensorlist_parameter.h" | #include "nnacl/tensorlist_parameter.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/tensorliststack.h" | |||||
| #include "src/ops/tensorlist_stack.h" | |||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/tensorlist_parameter.h" | #include "nnacl/tensorlist_parameter.h" | ||||
| @@ -150,11 +150,11 @@ | |||||
| #include "src/ops/unsorted_segment_sum.h" | #include "src/ops/unsorted_segment_sum.h" | ||||
| #include "src/ops/reciprocal.h" | #include "src/ops/reciprocal.h" | ||||
| #include "src/ops/constant.h" | #include "src/ops/constant.h" | ||||
| #include "src/ops/tensorlistfromtensor.h" | |||||
| #include "src/ops/tensorlistgetitem.h" | |||||
| #include "src/ops/tensorlistsetitem.h" | |||||
| #include "src/ops/tensorlistreserve.h" | |||||
| #include "src/ops/tensorliststack.h" | |||||
| #include "src/ops/tensorlist_fromtensor.h" | |||||
| #include "src/ops/tensorlist_getitem.h" | |||||
| #include "src/ops/tensorlist_setitem.h" | |||||
| #include "src/ops/tensorlist_reserve.h" | |||||
| #include "src/ops/tensorlist_stack.h" | |||||
| #include "src/ops/merge.h" | #include "src/ops/merge.h" | ||||
| #include "src/ops/switch.h" | #include "src/ops/switch.h" | ||||
| #include "src/ops/partial.h" | #include "src/ops/partial.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/tensorlistfromtensor.h" | |||||
| #include "src/ops/tensorlist_fromtensor.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -133,7 +133,6 @@ int TensorListFromTensor::InferShape(std::vector<lite::Tensor *> inputs_, std::v | |||||
| auto ele_shape_ptr = reinterpret_cast<int *>(input1->data_c()); | auto ele_shape_ptr = reinterpret_cast<int *>(input1->data_c()); | ||||
| auto output = reinterpret_cast<TensorList *>(outputs_[0]); | auto output = reinterpret_cast<TensorList *>(outputs_[0]); | ||||
| MS_ASSERT(output != nullptr); | MS_ASSERT(output != nullptr); | ||||
| // output->set_tensors_data_type(input0->data_type()); | |||||
| std::vector<std::vector<int> > tensor_shape(dim0, std::vector<int>(input0_shape.begin() + 1, input0_shape.end())); | std::vector<std::vector<int> > tensor_shape(dim0, std::vector<int>(input0_shape.begin() + 1, input0_shape.end())); | ||||
| output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input1->ElementsNum())); | output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input1->ElementsNum())); | ||||
| output->set_shape(std::vector<int>(1, dim0)); | output->set_shape(std::vector<int>(1, dim0)); | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/tensorlistgetitem.h" | |||||
| #include "src/ops/tensorlist_getitem.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/tensorlistreserve.h" | |||||
| #include "src/ops/tensorlist_reserve.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/tensorlistsetitem.h" | |||||
| #include "src/ops/tensorlist_setitem.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/ops/tensorliststack.h" | |||||
| #include "src/ops/tensorlist_stack.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32/TensorListFromTensor.h" | |||||
| #include "src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.h" | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "include/ms_tensor.h" | #include "include/ms_tensor.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32/TensorListGetItem.h" | |||||
| #include "src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.h" | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32/TensorListReserve.h" | |||||
| #include "src/runtime/kernel/arm/fp32/tensorlist_reserve_fp32.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -16,7 +16,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "include/ms_tensor.h" | #include "include/ms_tensor.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32/TensorListSetItem.h" | |||||
| #include "src/runtime/kernel/arm/fp32/tensorlist_setitem_fp32.h" | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/fp32/TensorListStack.h" | |||||
| #include "src/runtime/kernel/arm/fp32/tensorlist_stack_fp32.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||