Merge pull request !4531 from ling/deconv
@@ -0,0 +1,210 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
#include "src/runtime/runtime_api.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DeConv2D;

namespace mindspore::kernel {
DeConvolutionFp16CPUKernel::~DeConvolutionFp16CPUKernel() {
  FreeParam();
  return;
}

int DeConvolutionFp16CPUKernel::ReSize() {
  FreeParam();
  InitParam();
  return RET_OK;
}

void DeConvolutionFp16CPUKernel::FreeParam() {
  if (tmp_buffer_ != nullptr) {
    free(tmp_buffer_);
    tmp_buffer_ = nullptr;
  }
  if (pack_input_ != nullptr) {
    free(pack_input_);
    pack_input_ = nullptr;
  }
  if (pack_output_ != nullptr) {
    free(pack_output_);
    pack_output_ = nullptr;
  }
  return;
}

int DeConvolutionFp16CPUKernel::InitWeightBias() {
  bias_data_ = malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float16_t));
  if (bias_data_ == nullptr) {
    MS_LOG(ERROR) << "deconv malloc bias_data_ error!";
    return RET_ERROR;
  }
  memset(bias_data_, 0, UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float16_t));
  if (in_tensors_.size() == 3) {
    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[2]->Data()), reinterpret_cast<float16_t *>(bias_data_),
                     conv_param_->output_channel_);
  }

  size_t weight_pack_size = conv_param_->input_channel_ * conv_param_->kernel_w_ * conv_param_->kernel_h_ *
                            UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(float16_t);
  execute_weight_ = reinterpret_cast<float16_t *>(malloc(weight_pack_size));
  if (execute_weight_ == nullptr) {
    MS_LOG(ERROR) << "deconv malloc execute_weight_ error!";
    return RET_ERROR;
  }
  memset(execute_weight_, 0, weight_pack_size);
  PackNHWCFp32ToC8HWN8Fp16(reinterpret_cast<float *>(in_tensors_[1]->Data()), execute_weight_,
                           conv_param_->input_channel_, kernel_plane_, conv_param_->output_channel_);
  return RET_OK;
}
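/* Illustrative only (hypothetical sizes, not part of the patch): the packed-weight size
   computed in InitWeightBias() for a 3x3 deconv with input_channel_ = 16 and
   output_channel_ = 12 works out as
     UP_ROUND(12, C8NUM) = 16                                  // pad oc to a multiple of 8
     weight_pack_size    = 16 * 3 * 3 * 16 * sizeof(float16_t) = 4608 bytes
   and the memset keeps the padded tail channels at zero so the C8 kernels can always
   operate on full 8-channel blocks. */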
int DeConvolutionFp16CPUKernel::InitParam() {
  input_plane_ = conv_param_->input_h_ * conv_param_->input_w_;
  kernel_plane_ = conv_param_->kernel_w_ * conv_param_->kernel_h_;
  output_plane_ = conv_param_->output_h_ * conv_param_->output_w_;

  matmul_param_->row_ = input_plane_;
  matmul_param_->deep_ = conv_param_->input_channel_;
  matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
  row16_ = UP_ROUND(matmul_param_->row_, 16);
  col8_ = UP_ROUND(matmul_param_->col_, 8);

  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
  thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);

  pack_input_ = reinterpret_cast<float16_t *>(malloc(row16_ * matmul_param_->deep_ * sizeof(float16_t)));
  if (pack_input_ == nullptr) {
    MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
    return RET_ERROR;
  }
  pack_output_ = reinterpret_cast<float16_t *>(
    malloc(UP_ROUND(conv_param_->output_channel_, C8NUM) * output_plane_ * sizeof(float16_t)));
  if (pack_output_ == nullptr) {
    MS_LOG(ERROR) << "deconv Malloc pack_output_ error!";
    return RET_NULL_PTR;
  }
  tmp_buffer_ = reinterpret_cast<float16_t *>(malloc(row16_ * col8_ * sizeof(float16_t)));
  if (tmp_buffer_ == nullptr) {
    MS_LOG(ERROR) << "deconv Malloc tmp_buffer_ error!";
    return RET_ERROR;
  }
  return RET_OK;
}

int DeConvFp16Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto deconv = reinterpret_cast<DeConvolutionFp16CPUKernel *>(cdata);
  auto error_code = deconv->DoDeconv(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "DeConvFp16Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int DeConvolutionFp16CPUKernel::DoDeconv(int task_id) {
  int oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_);
  int oc_res = MSMIN(thread_stride_ * C8NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C8NUM);
  if (oc <= 0) {
    return RET_OK;
  }

  auto tmp_buf = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * row16_;
  MatMulFp16(pack_input_, execute_weight_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
             tmp_buf, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_, oc * C8NUM * kernel_plane_, 0,
             false);

  DeConvPostFp16(tmp_buf, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
                 reinterpret_cast<float16_t *>(bias_data_) + task_id * thread_stride_ * C8NUM,
                 execute_output_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
  return RET_OK;
}
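/* Illustrative only (hypothetical sizes, not part of the patch): how DoDeconv() slices
   output channels across threads when output_channel_ = 20 and thread_num_ = 2:
     UP_DIV(20, C8NUM) = 3 C8 blocks
     thread_count_     = MSMIN(2, 3) = 2
     thread_stride_    = UP_DIV(3, 2) = 2 blocks per task
   task_id 0 handles oc = 2 blocks (oc_res = 16 channels), task_id 1 handles the remaining
   oc = 1 block (oc_res = 4 channels); any further task_id sees oc <= 0 and returns early. */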
int DeConvolutionFp16CPUKernel::Init() {
  if (context_->infer_shape_interrupt_ && !context_->running_) {
    set_need_reinit();
    return RET_OK;
  }
  ConvolutionBaseCPUKernel::Init();

  int error_code = InitParam();
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "deconv InitParam error!";
    return error_code;
  }

  error_code = InitWeightBias();
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "deconv InitWeightBias error!";
    return error_code;
  }
  return RET_OK;
}

int DeConvolutionFp16CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
    return prepare_ret;
  }

  ConvolutionBaseFP16CPUKernel::GetExecuteTensor();

  for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
    RowMajor2Col8MajorFp16(execute_input_, pack_input_, input_plane_, conv_param_->input_channel_);

    int error_code = LiteBackendParallelLaunch(DeConvFp16Run, this, thread_count_);
    if (error_code != RET_OK) {
      MS_LOG(ERROR) << "deconv fp16 run error! error_code[" << error_code << "]";
      return RET_ERROR;
    }
  }

  ConvolutionBaseFP16CPUKernel::IfCastOutput();
  ConvolutionBaseFP16CPUKernel::FreeTmpBuffer();
  return RET_OK;
}

kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
                                               const kernel::KernelKey &desc, const lite::Primitive *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
  auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    delete kernel;
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator)
}  // namespace mindspore::kernel
@@ -0,0 +1,68 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_

#include <float.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/deconv_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/matmul_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/pack_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"

namespace mindspore::kernel {
class DeConvolutionFp16CPUKernel : public ConvolutionBaseFP16CPUKernel {
 public:
  DeConvolutionFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                             const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                             const lite::Primitive *primitive)
      : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {
    matmul_param_ = new MatMulParameter();
  }
  ~DeConvolutionFp16CPUKernel() override;

  int Init() override;
  int Run() override;
  int ReSize() override;

 public:
  int DoDeconv(int task_id);

 private:
  void FreeParam();
  int InitParam();
  int InitWeightBias();

 private:
  MatMulParameter *matmul_param_;
  int row16_;
  int col8_;
  int input_plane_;
  int kernel_plane_;
  int output_plane_;
  int thread_count_;
  int thread_stride_;
  float16_t *pack_input_ = nullptr;
  float16_t *pack_output_ = nullptr;
  float16_t *tmp_buffer_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DECONVOLUTION_H_
@@ -24,10 +24,10 @@ MatmulFp16Neon64:
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64
mov w18, #32 // sizeof(float) * 8
mov w18, #16 // sizeof(float16_t) * 8
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x11, x3 // bias flag
mov x18, #4
mov x18, #2
ldr x17, [sp]
mul x17, x17, x18
@@ -0,0 +1,459 @@
#ifdef __aarch64__
.text
.align 5
//.p2align 5,,15
.global PostFuncBiasReluC8Fp16
#ifndef __APPLE__
.type PostFuncBiasReluC8Fp16, %function
#endif
// void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
//                             size_t plane_size, size_t stride, size_t relu_type);
// x0 dst           x1 src          x2 bias
// x3 oc8div        x4 oc8mod       x5 plane_size
// x6 stride        x7 relu_type
// v0 ~ v7 value
// v16 bias data
// x24 x25 write loop tmp buf
// x26 relu6 #6;    x27 relu #0
// w10 oc8 loop control
// w13 hw loop control
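// Control-flow sketch (descriptive comment added for readability): Loop_C8 walks the
// oc8div full blocks of 8 output channels; inside a block, Loop8x8 / Loop_4x8 / Loop_1x8
// consume 8, 4 and 1 output pixels per iteration. Loop_C1 then handles the oc8mod tail
// block, and Loop_C1_N stores exactly N halfwords per pixel so the NHWC dst is not overrun.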
PostFuncBiasReluC8Fp16:
movi v26.8h, #6
scvtf v26.8h, v26.8h
dup v27.8h, wzr
mov w10, #0
Loop_C8:
cmp w10, w3
beq Loop_C1
mov x25, #2
mul x24, x10, x25
add x25, x0, x24
add w10, w10, #8
mov w13, w5
ld1 {v16.8h}, [x2], #16
Loop8x8:
cmp w13, #8
blt Loop_4x8
sub w13, w13, #8
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64
fadd v0.8h, v0.8h, v16.8h
fadd v1.8h, v1.8h, v16.8h
fadd v2.8h, v2.8h, v16.8h
fadd v3.8h, v3.8h, v16.8h
fadd v4.8h, v4.8h, v16.8h
fadd v5.8h, v5.8h, v16.8h
fadd v6.8h, v6.8h, v16.8h
fadd v7.8h, v7.8h, v16.8h
cmp w7, #2
beq Relu6_8x8
cmp w7, #1
beq Relu_8x8
b Write_8x8
Relu6_8x8:
fmin v0.8h, v0.8h, v26.8h
fmin v1.8h, v1.8h, v26.8h
fmin v2.8h, v2.8h, v26.8h
fmin v3.8h, v3.8h, v26.8h
fmin v4.8h, v4.8h, v26.8h
fmin v5.8h, v5.8h, v26.8h
fmin v6.8h, v6.8h, v26.8h
fmin v7.8h, v7.8h, v26.8h
Relu_8x8:
fmax v0.8h, v0.8h, v27.8h
fmax v1.8h, v1.8h, v27.8h
fmax v2.8h, v2.8h, v27.8h
fmax v3.8h, v3.8h, v27.8h
fmax v4.8h, v4.8h, v27.8h
fmax v5.8h, v5.8h, v27.8h
fmax v6.8h, v6.8h, v27.8h
fmax v7.8h, v7.8h, v27.8h
Write_8x8:
st1 {v0.8h}, [x25], x6
st1 {v1.8h}, [x25], x6
st1 {v2.8h}, [x25], x6
st1 {v3.8h}, [x25], x6
st1 {v4.8h}, [x25], x6
st1 {v5.8h}, [x25], x6
st1 {v6.8h}, [x25], x6
st1 {v7.8h}, [x25], x6
b Loop8x8
Loop_4x8:
cmp w13, #4
blt Loop_1x8
sub w13, w13, #4
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64
fadd v0.8h, v0.8h, v16.8h
fadd v1.8h, v1.8h, v16.8h
fadd v2.8h, v2.8h, v16.8h
fadd v3.8h, v3.8h, v16.8h
cmp w7, #2
beq Relu6_4x8
cmp w7, #1
beq Relu_4x8
b Write_4x8
Relu6_4x8:
fmin v0.8h, v0.8h, v26.8h
fmin v1.8h, v1.8h, v26.8h
fmin v2.8h, v2.8h, v26.8h
fmin v3.8h, v3.8h, v26.8h
Relu_4x8:
fmax v0.8h, v0.8h, v27.8h
fmax v1.8h, v1.8h, v27.8h
fmax v2.8h, v2.8h, v27.8h
fmax v3.8h, v3.8h, v27.8h
Write_4x8:
st1 {v0.8h}, [x25], x6
st1 {v1.8h}, [x25], x6
st1 {v2.8h}, [x25], x6
st1 {v3.8h}, [x25], x6
Loop_1x8:
cmp w7, #2
beq Relu6_1x8
cmp w7, #1
beq Relu_1x8
b Write_1x8
Relu6_1x8:
cmp w13, #0
beq Loop_C8
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.8h}, [x25], x6
b Relu6_1x8
Relu_1x8:
cmp w13, #0
beq Loop_C8
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.8h}, [x25], x6
b Relu_1x8
Write_1x8:
cmp w13, #0
beq Loop_C8
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.8h}, [x25], x6
b Write_1x8
Loop_C1:
cmp x4, #0
beq End
mov w13, w5
ld1 {v16.8h}, [x2], #16
cmp x4, #1
beq Loop_C1_1
cmp x4, #2
beq Loop_C1_2
cmp x4, #3
beq Loop_C1_3
cmp x4, #4
beq Loop_C1_4
cmp x4, #5
beq Loop_C1_5
cmp x4, #6
beq Loop_C1_6
cmp x4, #7
beq Loop_C1_7
Loop_C1_1:
cmp w7, #2
beq Loop_C1_1_Relu6
cmp w7, #1
beq Loop_C1_1_Relu
b Loop_C1_1_Write
Loop_C1_1_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.h}[0], [x0], x6
b Loop_C1_1_Relu6
Loop_C1_1_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.h}[0], [x0], x6
b Loop_C1_1_Relu
Loop_C1_1_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.h}[0], [x0], x6
b Loop_C1_1_Write
Loop_C1_2:
add x24, x0, #2
cmp w7, #2
beq Loop_C1_2_Relu6
cmp w7, #1
beq Loop_C1_2_Relu
b Loop_C1_2_Write
Loop_C1_2_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.h}[0], [x0], x6
st1 {v0.h}[1], [x24], x6
b Loop_C1_2_Relu6
Loop_C1_2_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.h}[0], [x0], x6
st1 {v0.h}[1], [x24], x6
b Loop_C1_2_Relu
Loop_C1_2_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.h}[0], [x0], x6
st1 {v0.h}[1], [x24], x6
b Loop_C1_2_Write
Loop_C1_3:
add x24, x0, #2
add x25, x0, #4
cmp w7, #2
beq Loop_C1_3_Relu6
cmp w7, #1
beq Loop_C1_3_Relu
b Loop_C1_3_Write
Loop_C1_3_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.h}[0], [x0], x6
st1 {v0.h}[1], [x24], x6
st1 {v0.h}[2], [x25], x6
b Loop_C1_3_Relu6
Loop_C1_3_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.h}[0], [x0], x6
st1 {v0.h}[1], [x24], x6
st1 {v0.h}[2], [x25], x6
b Loop_C1_3_Relu
Loop_C1_3_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.h}[0], [x0], x6
st1 {v0.h}[1], [x24], x6
st1 {v0.h}[2], [x25], x6
b Loop_C1_3_Write
Loop_C1_4:
cmp w7, #2
beq Loop_C1_4_Relu6
cmp w7, #1
beq Loop_C1_4_Relu
b Loop_C1_4_Write
Loop_C1_4_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
b Loop_C1_4_Relu6
Loop_C1_4_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
b Loop_C1_4_Relu
Loop_C1_4_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.4h}, [x0], x6
b Loop_C1_4_Write
Loop_C1_5:
add x25, x0, #8
cmp w7, #2
beq Loop_C1_5_Relu6
cmp w7, #1
beq Loop_C1_5_Relu
b Loop_C1_5_Write
Loop_C1_5_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x25], x6
b Loop_C1_5_Relu6
Loop_C1_5_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x25], x6
b Loop_C1_5_Relu
Loop_C1_5_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x25], x6
b Loop_C1_5_Write
Loop_C1_6:
add x23, x0, #8
add x24, x0, #10
cmp w7, #2
beq Loop_C1_6_Relu6
cmp w7, #1
beq Loop_C1_6_Relu
b Loop_C1_6_Write
Loop_C1_6_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x23], x6
st1 {v0.h}[5], [x24], x6
b Loop_C1_6_Relu6
Loop_C1_6_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x23], x6
st1 {v0.h}[5], [x24], x6
b Loop_C1_6_Relu
Loop_C1_6_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x23], x6
st1 {v0.h}[5], [x24], x6
b Loop_C1_6_Write
Loop_C1_7:
add x23, x0, #8
add x24, x0, #10
add x25, x0, #12
cmp w7, #2
beq Loop_C1_7_Relu6
cmp w7, #1
beq Loop_C1_7_Relu
b Loop_C1_7_Write
Loop_C1_7_Relu6:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmin v0.8h, v0.8h, v26.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x23], x6
st1 {v0.h}[5], [x24], x6
st1 {v0.h}[6], [x25], x6
b Loop_C1_7_Relu6
Loop_C1_7_Relu:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
fmax v0.8h, v0.8h, v27.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x23], x6
st1 {v0.h}[5], [x24], x6
st1 {v0.h}[6], [x25], x6
b Loop_C1_7_Relu
Loop_C1_7_Write:
cmp w13, #0
beq End
sub w13, w13, #1
ld1 {v0.8h}, [x1], #16
fadd v0.8h, v0.8h, v16.8h
st1 {v0.4h}, [x0], x6
st1 {v0.h}[4], [x23], x6
st1 {v0.h}[5], [x24], x6
st1 {v0.h}[6], [x25], x6
b Loop_C1_7_Write
End:
ret
#endif
@@ -0,0 +1,119 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/deconv_fp16.h"
void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr,
                          size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6,
                          int size) {
  for (int oc = 0; oc < output_channel; oc++) {
    int oc_div = oc / size, oc_mod = oc % size;
    for (int hw = 0; hw < plane_size; hw++) {
      int src_index = oc_div * size * plane_size + hw * size + oc_mod;
      int dst_index = hw * stride + oc;
      float16_t value = src_ptr_[src_index];
      if (bias_ptr != NULL) {
        value = value + bias_ptr[oc];
      }
      value = (is_relu || is_relu6) ? (MSMAX(0.f, value)) : (value);
      value = (is_relu6) ? (MSMIN(6.f, value)) : (value);
      out_ptr[dst_index] = value;
    }
  }
  return;
}
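/* Illustrative only (hypothetical sizes, not part of the patch): with size = C8NUM = 8
   and plane_size = 4, channel oc = 11 sits in block oc_div = 1 at lane oc_mod = 3, so
   for pixel hw = 2
     src_index = 1 * 8 * 4 + 2 * 8 + 3 = 51
   while the NHWC destination index is dst_index = 2 * stride + 11. */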
void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr,
                        size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
#ifdef DEBUG_CODE
  PostConvFuncCommFp16(out_ptr, c8_out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM);
#else
  size_t oc8mod = output_channel % C8NUM;
  size_t oc8div = output_channel - oc8mod;
  size_t stride_size = stride * sizeof(float16_t);
  size_t relu_type = is_relu ? 1 : 0;
  relu_type = is_relu6 ? 2 : relu_type;
  PostFuncBiasReluC8Fp16(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type);
#endif
  return;
}
int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
                   ConvParameter *conv_param) {
  /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
  size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
  size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
  size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
  int oc8 = UP_ROUND(output_channel, C8NUM);
  int in_plane16 = UP_ROUND(input_plane, 16);
  int src_iw_stride = C8NUM;
  int src_ih_stride = conv_param->input_w_ * C8NUM;
  int src_kw_stride = in_plane16 * C8NUM;
  int src_kh_stride = in_plane16 * conv_param->kernel_w_ * C8NUM;
  int dst_oh_stride = conv_param->output_w_ * C8NUM;
  int dst_ow_stride = C8NUM;
  int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
  int dst_kw_stride = conv_param->dilation_w_ * C8NUM;

  for (int c = 0; c < oc8; c += 8) {
    float16_t *dst_ptr = tmp + c * output_plane;
    const float16_t *src_ptr = src + c * in_plane16 * kernel_plane;
    memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float16_t));

    for (int ih = 0; ih < conv_param->input_h_; ih++) {
      for (int iw = 0; iw < conv_param->input_w_; iw++) {
        int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
        int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;

        int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
        int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
        int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
        int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
        for (int kh = kh_start; kh < kh_end; kh++) {
          for (int kw = kw_start; kw < kw_end; kw++) {
            int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride;
            int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride;
            float16_t *tmp_dst = dst_ptr + dst_index;
            const float16_t *tmp_src = src_ptr + src_index;
#ifdef DEBUG_CODE
            for (int i = 0; i < C8NUM; i++) {
              tmp_dst[i] += tmp_src[i];
            }
#else
            asm volatile(
              "mov x0, %[tmp_src] \n"
              "mov x1, %[tmp_dst] \n"
              "ld1 {v0.8h}, [x0] \n"
              "ld1 {v1.8h}, [x1] \n"
              "fadd v0.8h, v0.8h, v1.8h \n"
              "st1 {v0.8h}, [x1] \n"
              :
              : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst)
              : "x0", "x1", "v0", "v1", "memory");
#endif
          } /*kw*/
        }   /*kh*/
      }     /*iw*/
    }       /*ih*/
  }         /*oc8*/

  PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
                     conv_param->is_relu6_);
  return NNACL_OK;
}
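/* Illustrative only (hypothetical parameters, not part of the patch): the input->output
   mapping above with stride_h_ = 2, pad_h_ = 1, kernel_h_ = 3, dilation_h_ = 1 and
   output_h_ = 8 gives, for input row ih = 0, oh = 0 * 2 - 1 = -1 and therefore
     kh_start = MSMAX(0, UP_DIV(1, 1)) = 1
     kh_end   = MSMIN(3, UP_DIV(8 - (-1), 1)) = 3
   so only kernel rows kh = 1 and kh = 2 contribute, writing output rows 0 and 1;
   this is the usual transposed-convolution boundary clipping. */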
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_DECONV_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_DECONV_FP16_H_

#include <string.h>
#include <arm_neon.h>
#include "nnacl/conv_parameter.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/fp16/matmul_fp16.h"

#ifdef __cplusplus
extern "C" {
#endif

int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel,
                   ConvParameter *conv_param);

void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr,
                        size_t output_channel, size_t plane_size, size_t stride, bool is_relu, bool is_relu6);

void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
                            size_t plane_size, size_t stride, size_t relu_type);

#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_DECONV_FP16_H_
@@ -15,10 +15,37 @@
 */
#include "nnacl/fp16/matmul_fp16.h"

void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                int deep, int row, int col, int stride, bool write_nhwc) {
  int row_16 = UP_ROUND(row, C16NUM);
  int col_8 = UP_ROUND(col, C8NUM);
  /* col16-major * row8-major => row16x8-major */
  if (write_nhwc) return;
  for (int r = 0; r < row_16; r++) {
    for (int c = 0; c < col_8; c++) {
      int r16div = r / C16NUM, r16mod = r % C16NUM;
      int c8div = c / C8NUM, c8mod = c % C8NUM;
      size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
      float16_t value = 0;
      for (int d = 0; d < deep; d++) {
        size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
        size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
        value = value + a[ai] * b[bi];
      }
      if (bias != NULL) value += bias[c];
      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
      if (act_type != ActType_No) value = MSMAX(0.0f, value);
      dst[ci] = value;
    }
  }
  return;
}
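/* Illustrative index check (hypothetical sizes, not part of the patch) for the packed
   layouts used by MatMul16x8 with deep = 4: lhs element (r = 17, d = 2) lives in 16-row
   block r16div = 1 at lane r16mod = 1, so ai = 1 * 4 * 16 + 2 * 16 + 1 = 97; rhs element
   (c = 9, d = 2) lives in 8-column block c8div = 1, so bi = 1 * 4 * 8 + 2 * 8 + 1 = 49;
   the product is accumulated into ci = 1 * row_16 * 8 + 17 * 8 + 1 in the row16x8-major
   output. */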
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                int depth, int row, int col, int stride, bool write_nhwc) {
  MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
  // MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
  MatMul16x8(a, b, c, bias, (int)act_type, depth, row, col, stride, write_nhwc);
  return;
}

void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
@@ -33,10 +33,10 @@ void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const floa
                int depth, int row, int col, int stride, bool write_nhwc);

void RowMajor2Col8MajorFp16(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);

#ifdef __aarch64__
void MatmulFp16Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                      int col, int stride, bool write_nhwc);
#endif
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                      size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);

#ifdef __cplusplus
}
#endif
@@ -369,6 +369,21 @@ void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, i
  }
}

void PackNHWCFp32ToC8HWN8Fp16(float *src, float16_t *dst, int batch, int plane, int channel) {
  for (int n = 0; n < batch; n++) {
    for (int hw = 0; hw < plane; hw++) {
      for (int c = 0; c < channel; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int src_index = n * plane * channel + hw * channel + c;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        dst[dst_index] = (float16_t)(src[src_index]);
      }
    }
  }
  return;
}
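/* Illustrative only (hypothetical sizes, not part of the patch): for batch = 1,
   plane = 9 (a 3x3 kernel) and channel = 12, weight element (n = 0, hw = 4, c = 10)
   has c8div = 1, c8mod = 2, so dst_index = 1 * 1 * 9 * 8 + 4 * 1 * 8 + 0 * 8 + 2 = 106;
   channels are grouped into blocks of 8, then ordered by hw, then by batch (the
   C8HWN8 layout expected by the deconv matmul). */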
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel) {
  int c8_channel = UP_DIV(channel, C8NUM) * C8NUM;
  for (int b = 0; b < batch; b++) {
@@ -59,6 +59,8 @@ void PackNCHWFp32ToNC8HW8Fp16(float *src, float16_t *dst, int batch, int plane,
void PackNHWCFp32ToNHWC8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);

void PackNHWCFp32ToC8HWN8Fp16(float *src, float16_t *dst, int batch, int plane, int channel);

void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Fp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
@@ -120,7 +120,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
}

void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride,
                      size_t data_lenth) {
  size_t copy_size = col * data_lenth;
  size_t src_size = src_stride * data_lenth;
  size_t dst_size = dst_stride * data_lenth;