From: @lzkcode Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tong tags/v1.1.0
| @@ -685,6 +685,8 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | int output_width, int input_stride, bool relu, bool relu6, int kernel) { | ||||
| if (kernel == 9) { | if (kernel == 9) { | ||||
| ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); | ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); | ||||
| } else if (kernel == 25) { | |||||
| ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -69,6 +69,9 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c | |||||
| #ifdef ENABLE_AVX | #ifdef ENABLE_AVX | ||||
| void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, | void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, | ||||
| size_t output_width, size_t input_stride, size_t relu, size_t relu6); | size_t output_width, size_t input_stride, size_t relu, size_t relu6); | ||||
| void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, | |||||
| size_t output_width, size_t input_stride, size_t relu, size_t relu6); | |||||
| #endif | #endif | ||||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | ||||
| @@ -0,0 +1,116 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifdef ENABLE_AVX | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | |||||
/**
 * Depthwise 5x5 fp32 convolution for one row of output pixels, using AVX/FMA
 * intrinsics and an indirection buffer of input-row pointers.
 *
 * For every output pixel the first 25 entries of `input` are taken as the
 * pointers to the 25 kernel taps. Channels are processed in groups of C8NUM
 * with one 8-float vector per tap; each group starts from the bias vector and
 * accumulates in[tap] * weight[tap] for all 25 taps in tap order.
 *
 * @param output        Destination; exactly `channels` floats are written per output pixel.
 * @param input         Indirection buffer; advanced by `input_stride` (in pointers) per pixel.
 * @param weights       Packed weights, read sequentially: 25 taps x 8 floats per channel group.
 * @param bias          Bias values, read 8 floats per channel group.
 * @param channels      Number of real channels; the tail group may be partial.
 * @param output_width  Number of output pixels to produce.
 * @param input_stride  Step between pixels in the indirection buffer, given in BYTES
 *                      (it is divided by sizeof(float *) before use).
 * @param relu          Non-zero to clamp results to be >= 0.
 * @param relu6         Non-zero to clamp results to [0, 6].
 *
 * Note: full 8-float vectors are loaded from the inputs, weights and bias even
 * for the last, partial channel group, so those buffers must be readable up to
 * the channel count rounded up to a multiple of C8NUM. Only the output store is
 * trimmed to the real channel count.
 */
void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels,
                      size_t output_width, size_t input_stride, size_t relu, size_t relu6) {
  /* Fixed kernel geometry: 5 rows of 5 taps. Compile-time constants avoid a VLA. */
  enum { kTapCount = 25, kTapsPerRow = 5 };

  input_stride /= sizeof(float *); /* bytes -> number of pointers */
  const size_t c8 = UP_DIV(channels, C8NUM) * C8NUM;
  const size_t c8_mod = channels % C8NUM;

  for (size_t i = 0; i < output_width; ++i) {
    /* Private copy of the tap pointers so they can be advanced per channel group. */
    float *in[kTapCount];
    for (int k = 0; k < kTapCount; k++) {
      in[k] = input[k];
    }
    input += input_stride;

    size_t c = c8;
    const float *w = weights;
    const float *bias1 = bias;
    for (; c >= C8NUM; c -= C8NUM) {
      __m256 out1 = _mm256_loadu_ps(bias1);
      bias1 += 8;
      /* One kernel row (5 taps) per iteration; loads are interleaved with the FMAs. */
      for (int k = 0; k < kTapCount; k += kTapsPerRow) {
        __m256 in1 = _mm256_loadu_ps(in[k]);
        __m256 w1 = _mm256_loadu_ps(w);
        __m256 in2 = _mm256_loadu_ps(in[k + 1]);
        __m256 w2 = _mm256_loadu_ps(w + 8);
        out1 = _mm256_fmadd_ps(in1, w1, out1);
        __m256 in3 = _mm256_loadu_ps(in[k + 2]);
        __m256 w3 = _mm256_loadu_ps(w + 16);
        out1 = _mm256_fmadd_ps(in2, w2, out1);
        __m256 in4 = _mm256_loadu_ps(in[k + 3]);
        __m256 w4 = _mm256_loadu_ps(w + 24);
        out1 = _mm256_fmadd_ps(in3, w3, out1);
        /* Fifth tap of this row is in[k + 4]; indexing k + 8 would reuse a tap
         * from the next row and run past the 25-entry array on the last row. */
        __m256 in5 = _mm256_loadu_ps(in[k + 4]);
        __m256 w5 = _mm256_loadu_ps(w + 32);
        out1 = _mm256_fmadd_ps(in4, w4, out1);
        w += 40;
        in[k] += C8NUM;
        in[k + 1] += C8NUM;
        in[k + 2] += C8NUM;
        in[k + 3] += C8NUM;
        in[k + 4] += C8NUM;
        out1 = _mm256_fmadd_ps(in5, w5, out1);
      }

      /* Activation: upper clamp first (relu6 only), then lower clamp at zero. */
      if (relu6 != 0) {
        __m256 relu6_data = _mm256_set1_ps(6.0);
        out1 = _mm256_min_ps(out1, relu6_data);
      }
      if (relu != 0 || relu6 != 0) {
        __m256 zero = _mm256_setzero_ps();
        out1 = _mm256_max_ps(out1, zero);
      }

      if (c == C8NUM) {
        /* Last channel group: store only the c8_mod valid lanes (all 8 when c8_mod == 0). */
        __m128 tmp;
        switch (c8_mod) {
          case 1:
            _mm_store_ss(output, _mm256_castps256_ps128(out1));
            break;
          case 2:
            _mm_storel_pi((__m64 *)output, _mm256_castps256_ps128(out1));
            break;
          case 3:
            tmp = _mm256_castps256_ps128(out1);
            _mm_storel_pi((__m64 *)output, tmp);
            tmp = _mm_unpackhi_ps(tmp, tmp); /* move lane 2 into lane 0 */
            _mm_store_ss(output + 2, tmp);
            break;
          case 4:
            _mm_storeu_ps(output, _mm256_castps256_ps128(out1));
            break;
          case 5:
            _mm_storeu_ps(output, _mm256_castps256_ps128(out1));
            _mm_store_ss(output + 4, _mm256_extractf128_ps(out1, 1));
            break;
          case 6:
            _mm_storeu_ps(output, _mm256_castps256_ps128(out1));
            _mm_storel_pi((__m64 *)(output + 4), _mm256_extractf128_ps(out1, 1));
            break;
          case 7:
            _mm_storeu_ps(output, _mm256_castps256_ps128(out1));
            tmp = _mm256_extractf128_ps(out1, 1);
            _mm_storel_pi((__m64 *)(output + 4), tmp);
            tmp = _mm_unpackhi_ps(tmp, tmp); /* move lane 6 into lane 0 */
            _mm_store_ss(output + 6, tmp);
            break;
          default:
            _mm256_storeu_ps(output, out1);
            break;
        }
        output += c8_mod == 0 ? C8NUM : c8_mod;
      } else {
        _mm256_storeu_ps(output, out1);
        output += C8NUM;
      }
    }
  }
}
| #endif | |||||
| @@ -147,12 +147,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| conv_param->input_channel_ = inputs[kInputIndex]->Channel(); | conv_param->input_channel_ = inputs[kInputIndex]->Channel(); | ||||
| conv_param->output_h_ = outputs[kOutputIndex]->Height(); | conv_param->output_h_ = outputs[kOutputIndex]->Height(); | ||||
| conv_param->output_w_ = outputs[kOutputIndex]->Width(); | conv_param->output_w_ = outputs[kOutputIndex]->Width(); | ||||
| #ifdef ENABLE_AVX | |||||
| if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) { | |||||
| kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| #elif defined(ENABLE_ARM64) | |||||
| #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) | |||||
| if (CheckConvDwUseIndirectBuffer(conv_param)) { | if (CheckConvDwUseIndirectBuffer(conv_param)) { | ||||
| kernel = | kernel = | ||||
| new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| @@ -60,7 +60,7 @@ int FullconnectionCPUKernel::ReSize() { | |||||
| #endif | #endif | ||||
| fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | ||||
| fc_param_->col_align_ = UP_ROUND(fc_param_->col_, col_tile); | fc_param_->col_align_ = UP_ROUND(fc_param_->col_, col_tile); | ||||
| fc_param_->row_6_ = UP_ROUND(fc_param_->col_, C6NUM); | |||||
| fc_param_->row_6_ = UP_ROUND(fc_param_->row_, C6NUM); | |||||
| fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | ||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_align_, col_tile)); | thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_align_, col_tile)); | ||||