| @@ -1,3 +1,5 @@ | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}/) | |||
| file(GLOB KERNEL_SRC | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc | |||
| nnacl/*.cc | |||
| @@ -23,6 +23,7 @@ | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "nnacl/winograd_utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| @@ -242,7 +243,7 @@ int ConvolutionFP16CPUKernel::Run() { | |||
| auto out_tensor = outputs_.at(kOutputIndex); | |||
| auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); | |||
| for (int j = 0; j < out_tensor->ElementsNum(); ++j) { | |||
| output_addr[j] = static_cast<float >(fp16_out_[j]); | |||
| output_addr[j] = static_cast<float>(fp16_out_[j]); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -264,20 +265,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten | |||
| conv_param->input_w_ = inputs.front()->Width(); | |||
| conv_param->output_h_ = outputs.front()->Height(); | |||
| conv_param->output_w_ = outputs.front()->Width(); | |||
| kernel::LiteKernel *kernel; | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { | |||
| kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); | |||
| } else { | |||
| kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); | |||
| bool use_winograd = false; | |||
| int out_unit; | |||
| InputTransformUnitFunc input_trans_func = nullptr; | |||
| OutputTransformUnitFunc output_trans_func = nullptr; | |||
| CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); | |||
| if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { | |||
| kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); | |||
| } | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "Create conv fp16 kernel failed."; | |||
| MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| return nullptr; | |||
| } | |||
| @@ -220,32 +220,6 @@ int ConvolutionCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
// Decides whether the winograd fast-convolution path should be used for this
// convolution, reporting the result through *use_winograd and *output_unit.
// Winograd is only considered for square kernels with stride 1 and dilation 1.
// NOTE(review): input_trans_func and output_trans_func are function pointers
// passed BY VALUE, so the GetInputTransFunc/GetOutputTransFunc results assigned
// below are lost to the caller — the caller still holds whatever it passed in
// (nullptr at the visible call sites). If callers are meant to receive the
// transform functions, these parameters should be pointers/references — verify.
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
                        InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) {
  // Precondition for winograd: square kernel, unit stride, no dilation.
  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
      conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
    // Choose the output tile size; a unit of 1 means winograd brings no benefit.
    *output_unit = SelectOutputUnit(conv_param);
    if (*output_unit > 1) {
      *use_winograd = true;
      // Input tile size = kernel size + output tile size - 1.
      int input_unit = conv_param->kernel_h_ + *output_unit - 1;
      input_trans_func = GetInputTransFunc(input_unit);
      if (input_trans_func == nullptr) {
        // No input transform available for this tile size: fall back to common conv.
        MS_LOG(INFO) << "No matching input trans func. Turn back to common conv.";
        *use_winograd = false;
      }
      output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
      if (output_trans_func == nullptr) {
        // No output transform available for this tile pair: fall back to common conv.
        MS_LOG(INFO) << "No matching output trans func. Turn back to common conv.";
        *use_winograd = false;
      }
    } else {
      *use_winograd = false;
    }
  } else {
    *use_winograd = false;
  }
}
| kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const Context *ctx, | |||
| @@ -270,7 +244,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten | |||
| CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); | |||
| kernel::LiteKernel *kernel; | |||
| if (kernel_h == 1 && kernel_w == 1) { | |||
| kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx); | |||
| // kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx); | |||
| kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx); | |||
| } else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { | |||
| kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx); | |||
| } else if (use_winograd) { | |||
| @@ -4708,3 +4708,28 @@ OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit) { | |||
| return nullptr; | |||
| } | |||
| } | |||
// Decides whether the winograd fast-convolution path should be used for this
// convolution, reporting the result through *use_winograd and *output_unit.
// Winograd is only considered for square kernels with stride 1 and dilation 1;
// in every other case (or when no matching transform function exists) the
// function silently reports *use_winograd = false.
// NOTE(review): input_trans_func and output_trans_func are function pointers
// passed BY VALUE, so the GetInputTransFunc/GetOutputTransFunc results assigned
// below are lost to the caller — the caller still holds whatever it passed in
// (nullptr at the visible call sites). If callers are meant to receive the
// transform functions, these parameters should be pointers/references — verify.
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
                        InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) {
  // Precondition for winograd: square kernel, unit stride, no dilation.
  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
      conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
    // Choose the output tile size; a unit of 1 means winograd brings no benefit.
    *output_unit = SelectOutputUnit(conv_param);
    if (*output_unit > 1) {
      *use_winograd = true;
      // Input tile size = kernel size + output tile size - 1.
      int input_unit = conv_param->kernel_h_ + *output_unit - 1;
      input_trans_func = GetInputTransFunc(input_unit);
      if (input_trans_func == nullptr) {
        // No input transform available for this tile size: fall back to common conv.
        *use_winograd = false;
      }
      output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
      if (output_trans_func == nullptr) {
        // No output transform available for this tile pair: fall back to common conv.
        *use_winograd = false;
      }
    } else {
      *use_winograd = false;
    }
  } else {
    *use_winograd = false;
  }
}
| @@ -54,5 +54,7 @@ InputTransformUnitFunc GetInputTransFunc(int input_unit); | |||
| OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit); | |||
| void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param, | |||
| InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_UTILS_H_ | |||