|
|
|
@@ -23,6 +23,7 @@ |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "src/runtime/runtime_api.h" |
|
|
|
#include "nnacl/winograd_utils.h" |
|
|
|
|
|
|
|
using mindspore::kernel::KERNEL_ARCH::kCPU; |
|
|
|
using mindspore::lite::KernelRegistrar; |
|
|
|
@@ -242,7 +243,7 @@ int ConvolutionFP16CPUKernel::Run() { |
|
|
|
auto out_tensor = outputs_.at(kOutputIndex); |
|
|
|
auto output_addr = reinterpret_cast<float *>(out_tensor->Data()); |
|
|
|
for (int j = 0; j < out_tensor->ElementsNum(); ++j) { |
|
|
|
output_addr[j] = static_cast<float >(fp16_out_[j]); |
|
|
|
output_addr[j] = static_cast<float>(fp16_out_[j]); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -264,20 +265,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::tensor::Ten |
|
|
|
conv_param->input_w_ = inputs.front()->Width(); |
|
|
|
conv_param->output_h_ = outputs.front()->Height(); |
|
|
|
conv_param->output_w_ = outputs.front()->Width(); |
|
|
|
kernel::LiteKernel *kernel; |
|
|
|
kernel::LiteKernel *kernel = nullptr; |
|
|
|
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { |
|
|
|
kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); |
|
|
|
} else { |
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); |
|
|
|
bool use_winograd = false; |
|
|
|
int out_unit; |
|
|
|
InputTransformUnitFunc input_trans_func = nullptr; |
|
|
|
OutputTransformUnitFunc output_trans_func = nullptr; |
|
|
|
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); |
|
|
|
if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { |
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); |
|
|
|
} |
|
|
|
} |
|
|
|
if (kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Create conv fp16 kernel failed."; |
|
|
|
MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto ret = kernel->Init(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
delete kernel; |
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " |
|
|
|
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ << ", type: " |
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|