From: @fuzhiye
Reviewed-by: @zhang_xue_tong
Signed-off-by:
Tag: tags/v1.2.0-rc1
@@ -43,12 +43,9 @@ void StridedSliceCPUKernel::InitFastRunParam() {
   for (int i = 0; i < split_axis_; ++i) {
     outer_ *= in_shape[i];
   }
-  int inner = 1;
   for (size_t i = split_axis_ + 1; i < in_shape.size(); i++) {
-    inner *= in_shape[i];
+    inner_ *= in_shape[i];
   }
-  inner_size_ = in_tensors_.front()->Size() / in_tensors_.front()->ElementsNum() * inner;
   // decide multi-thread launch strategy
   if (outer_ == 1) {
     parallel_on_split_axis_ = true;
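This hunk promotes the former local `inner` to the `inner_` member so that `FastRun()` can rescale it by element size later (see the next hunk). For reference, a standalone sketch of the same outer/inner decomposition around the split axis — the name and signature here are illustrative, not the kernel's API:

```cpp
#include <cstddef>
#include <vector>

// Product of the dimensions before the split axis (outer) and after it (inner).
// The split axis itself is handled separately by the kernel.
void SplitShape(const std::vector<int> &shape, int split_axis, int *outer, int *inner) {
  *outer = 1;
  *inner = 1;
  for (int i = 0; i < split_axis; ++i) {
    *outer *= shape[i];
  }
  for (size_t i = split_axis + 1; i < shape.size(); ++i) {
    *inner *= shape[i];
  }
}
```

Keeping `inner_` as an element count, instead of folding the byte size in here, is what lets the byte size be chosen at run time.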
@@ -142,6 +139,26 @@ int StrideRun(void *cdata, int task_id) {
 }
 
 int StridedSliceCPUKernel::FastRun() {
+  // Update length of inner size, because data type of tensor may be changed
+  // from float32 to float16 during fp16 sub-graph partition process.
+  auto input = in_tensors_.front();
+  switch (input->data_type()) {
+    case kNumberTypeInt8:
+      inner_size_ = inner_ * sizeof(int8_t);
+      break;
+    case kNumberTypeFloat32:
+      inner_size_ = inner_ * sizeof(float);
+      break;
+    case kNumberTypeFloat16:
+      inner_size_ = inner_ * sizeof(int16_t);
+      break;
+    case kNumberTypeInt32:
+      inner_size_ = inner_ * sizeof(int32_t);
+      break;
+    default:
+      MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
+      return RET_ERROR;
+  }
   input_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.front()->data_c());
   output_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.front()->data_c());
   auto ret = ParallelLaunch(this->context_->thread_pool_, StrideRun, this, context_->thread_num_);
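The switch rebuilds `inner_size_` on every run because the tensor's data type can change after `InitFastRunParam()`, for example when float32 tensors become float16 during fp16 sub-graph partition. A minimal, self-contained sketch of the same dispatch, using a hypothetical `DataTypeSize` helper and simplified type ids rather than the real MindSpore enums:

```cpp
#include <cstddef>
#include <cstdint>

// Simplified stand-ins for the real TypeId enum values.
enum class TypeId { kInt8, kFloat32, kFloat16, kInt32, kUnknown };

// Hypothetical helper: byte width of one element, 0 for unsupported types.
size_t DataTypeSize(TypeId type) {
  switch (type) {
    case TypeId::kInt8:
      return sizeof(int8_t);
    case TypeId::kFloat32:
      return sizeof(float);
    case TypeId::kFloat16:
      return sizeof(int16_t);  // fp16 elements occupy two bytes
    case TypeId::kInt32:
      return sizeof(int32_t);
    default:
      return 0;
  }
}

// inner_size_ in bytes = inner element count * element width, recomputed per run.
size_t ComputeInnerSize(size_t inner, TypeId type) { return inner * DataTypeSize(type); }
```

Returning 0 for an unsupported type plays the same role as the error branch in the patch.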
@@ -46,9 +46,10 @@ class StridedSliceCPUKernel : public LiteKernel {
   uint8_t *input_ptr_ = nullptr;
   uint8_t *output_ptr_ = nullptr;
   int split_axis_{-1};
+  int inner_{1};
   int outer_{1};
   int cal_num_per_thread_{1};
-  size_t inner_size_{0};
+  size_t inner_size_{1};
   bool fast_run_{false};
   bool parallel_on_split_axis_{false};
   bool parallel_on_outer_{false};
@@ -97,10 +97,10 @@ int ConcatFp16CPUKernel::Run() {
   for (size_t i = 0; i < input_num; ++i) {
     const auto in_tensor = in_tensors_.at(i);
     if (in_tensor->data_type() == kNumberTypeFloat || in_tensor->data_type() == kNumberTypeFloat32) {
-      auto in_tensor_data = reinterpret_cast<float *>(in_tensor->MutableData());
+      auto in_tensor_data = reinterpret_cast<float *>(in_tensor->data_c());
      Float32ToFloat16(in_tensor_data, fp16_inputs_[i], in_tensor->ElementsNum());
    } else {
-      fp16_inputs_[i] = reinterpret_cast<float16_t *>(in_tensor->MutableData());
+      fp16_inputs_[i] = reinterpret_cast<float16_t *>(in_tensor->data_c());
    }
    shapes.push_back(in_tensors_[i]->shape());
@@ -110,7 +110,7 @@ int ConcatFp16CPUKernel::Run() {
   inputs_output_shape[input_num] = output_shape.data();
   auto output_addr = out_tensors_.at(0)->MutableData();
   if (out_tensors_.at(0)->data_type() == kNumberTypeFloat16) {
-    fp16_output_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
+    fp16_output_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
   }
   int dtype_len = in_tensors_.at(0)->data_type() == kNumberTypeInt32 ? sizeof(int32_t) : sizeof(float16_t);
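This hunk, and several below, replaces `MutableData()` with `data_c()` for buffers that are expected to exist by the time `Run()` executes. My understanding — an assumption, not something stated in the patch — is that `data_c()` simply returns the stored pointer while `MutableData()` may allocate on demand; the toy class below sketches that assumed distinction and is not the real `lite::Tensor`:

```cpp
#include <cstddef>
#include <cstdlib>

// Toy illustration of the assumed accessor semantics.
class ToyTensor {
 public:
  explicit ToyTensor(size_t size) : size_(size) {}
  ~ToyTensor() { free(data_); }
  // data_c(): return the current buffer as-is, possibly nullptr.
  void *data_c() const { return data_; }
  // MutableData(): allocate the buffer first if it does not exist yet.
  void *MutableData() {
    if (data_ == nullptr) {
      data_ = malloc(size_);
    }
    return data_;
  }

 private:
  void *data_ = nullptr;
  size_t size_ = 0;
};
```

Under that assumption the change avoids a hidden allocation path inside `Run()`, at the cost of requiring the buffers to have been prepared beforehand.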
@@ -93,7 +93,13 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
       MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
       return RET_ERROR;
     }
-    memcpy(bias_data_, fp16_bias_, output_channel * sizeof(float16_t));
+    auto bias_tensor = in_tensors_.at(kBiasIndex);
+    if (bias_tensor->data_type() == kNumberTypeFloat16) {
+      memcpy(bias_data_, origin_bias_, output_channel * sizeof(float16_t));
+    } else {
+      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_),
+                       output_channel);
+    }
     memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
   }
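The bias can now arrive as either fp16 or fp32, so the copy is guarded by its data type; the same pattern recurs in ConvolutionFP16CPUKernel and ConvolutionWinogradFP16CPUKernel below. A self-contained sketch of that pattern — `Float32ToFloat16Scalar` is a plain loop written here for illustration, the project has its own optimized Float32ToFloat16 — assuming a toolchain with native `_Float16` support:

```cpp
#include <cstddef>
#include <cstring>

using float16 = _Float16;  // assumption: the compiler provides _Float16 (e.g. clang/gcc on ARM64)

// Illustration-only scalar fallback of a Float32ToFloat16-style routine.
void Float32ToFloat16Scalar(const float *src, float16 *dst, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<float16>(src[i]);
  }
}

// Fill an fp16 bias buffer regardless of the source data type.
void PackBiasFp16(const void *origin_bias, bool bias_is_fp16, float16 *bias_fp16, size_t channels) {
  if (bias_is_fp16) {
    memcpy(bias_fp16, origin_bias, channels * sizeof(float16));
  } else {
    Float32ToFloat16Scalar(static_cast<const float *>(origin_bias), bias_fp16, channels);
  }
}
```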
@@ -105,7 +111,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
     return RET_ERROR;
   }
   memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
-  ColMajor2Row8MajorFp16(fp16_weight_, weight_ptr_, input_channel, output_channel, true);
+  ColMajor2Row8MajorFp16(origin_weight_, weight_ptr_, input_channel, output_channel,
+                         weight_tensor->data_type() == kNumberTypeFloat16);
   return RET_OK;
 }
@@ -30,11 +30,10 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  public:
   Convolution1x1FP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                               const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
-                              const mindspore::lite::PrimitiveC *primitive, float16_t *fp16_weight,
-                              float16_t *fp16_bias)
+                              const mindspore::lite::PrimitiveC *primitive, void *origin_weight, void *origin_bias)
       : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive),
-        fp16_weight_(fp16_weight),
-        fp16_bias_(fp16_bias) {}
+        origin_weight_(origin_weight),
+        origin_bias_(origin_bias) {}
   ~Convolution1x1FP16CPUKernel() override;
 
   int Init() override;
@@ -56,8 +55,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
   bool multi_thread_by_hw_ = false;
   int thread_count_ = 1;
   int thread_stride_ = 0;
-  float16_t *fp16_weight_;  // do not free
-  float16_t *fp16_bias_;    // do not free
+  void *origin_weight_;  // do not free
+  void *origin_bias_;    // do not free
   float16_t *weight_ptr_ = nullptr;
   float16_t *input_ptr_ = nullptr;
   float16_t *pack_input_ = nullptr;
@@ -40,19 +40,13 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteTensor() {
   return RET_OK;
 }
 
-int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
-  auto weight_tensor = in_tensors_.at(kWeightIndex);
+int ConvolutionBaseFP16CPUKernel::GetExecuteFilter(lite::Tensor *weight_tensor, void *origin_data) {
   auto weight_data_type = weight_tensor->data_type();
-  auto input_channel = weight_tensor->Channel();
-  auto output_channel = weight_tensor->Batch();
-  auto kernel_h = weight_tensor->Height();
-  auto kernel_w = weight_tensor->Width();
   MS_ASSERT(weight_data_type == kNumberTypeFloat32 || weight_data_type == kNumberTypeFloat16);
   if (weight_data_type == kNumberTypeFloat32) {
-    float *origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->MutableData());
-    size_t fp16_weight_size = input_channel * output_channel * kernel_h * kernel_w * sizeof(float16_t);
+    float *origin_weight = reinterpret_cast<float *>(origin_data);
+    size_t fp16_weight_size = weight_tensor->Channel() * weight_tensor->Batch() * weight_tensor->Height() *
+                              weight_tensor->Width() * sizeof(float16_t);
     fp16_weight_ = reinterpret_cast<float16_t *>(malloc(fp16_weight_size));
     if (fp16_weight_ == nullptr) {
       MS_LOG(ERROR) << "malloc fp16_weight_ failed.";
@@ -63,8 +57,7 @@ int ConvolutionBaseFP16CPUKernel::GetExecuteFilter() {
     }
     execute_weight_ = fp16_weight_;
   } else {
-    auto *origin_weight = reinterpret_cast<float16_t *>(in_tensors_.at(kWeightIndex)->MutableData());
-    execute_weight_ = origin_weight;
+    execute_weight_ = reinterpret_cast<float16_t *>(origin_data);
     fp16_weight_ = nullptr;
   }
   return RET_OK;
@@ -37,7 +37,10 @@ class ConvolutionBaseFP16CPUKernel : public ConvolutionBaseCPUKernel {
   int Run() override { return mindspore::lite::RET_OK; }
   int RunImpl(int task_id) { return mindspore::lite::RET_OK; }
   virtual int GetExecuteTensor();
-  virtual int GetExecuteFilter();
+  // origin_data may not be the same as the data in the weight tensor,
+  // because weight tensor has released data already. In this situation,
+  // origin_data is the pointer of another memory block.
+  virtual int GetExecuteFilter(lite::Tensor *weight_tensor, void *origin_data);
 
  protected:
   float16_t *fp16_weight_ = nullptr;
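GetExecuteFilter now receives the weight tensor and a separate origin_data pointer, since the original buffer may live outside the tensor (the delegate's copy) once the tensor has released its data. Conceptually the routine either widens an fp32 filter into a temporary fp16 buffer or aliases an already-fp16 buffer; a standalone sketch of that widen-or-alias step, with illustrative names and assuming `_Float16` compiler support (the real code uses float16_t and its own conversion kernels):

```cpp
#include <cstddef>
#include <memory>

using float16 = _Float16;  // assumption: toolchain with native _Float16

// Return a pointer to fp16 filter data. If the source is fp32, a widened copy is
// created and ownership handed to *owned_fp16; if it is already fp16, the source
// is aliased and *owned_fp16 stays empty.
const float16 *GetExecuteWeight(const void *origin_data, size_t elem_count, bool is_fp32,
                                std::unique_ptr<float16[]> *owned_fp16) {
  if (!is_fp32) {
    return static_cast<const float16 *>(origin_data);  // alias, nothing to free
  }
  owned_fp16->reset(new float16[elem_count]);
  const float *src = static_cast<const float *>(origin_data);
  for (size_t i = 0; i < elem_count; ++i) {
    (*owned_fp16)[i] = static_cast<float16>(src[i]);
  }
  return owned_fp16->get();
}
```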
@@ -35,99 +35,45 @@ using mindspore::schema::Format::Format_NHWC;
 namespace mindspore::kernel {
 void ConvolutionDelegateFP16CPUKernel::FreeCopiedData() {
-  if ((fp16_weight_ != nullptr) && (need_free_ & WEIGHT_NEED_FREE)) {
-    free(fp16_weight_);
-    fp16_weight_ = nullptr;
+  if ((origin_weight_ != nullptr) && (need_free_ & WEIGHT_NEED_FREE)) {
+    free(origin_weight_);
+    origin_weight_ = nullptr;
   }
-  if ((fp16_bias_ != nullptr) && (need_free_ & BIAS_NEED_FREE)) {
-    free(fp16_bias_);
-    fp16_bias_ = nullptr;
+  if ((origin_bias_ != nullptr) && (need_free_ & BIAS_NEED_FREE)) {
+    free(origin_bias_);
+    origin_bias_ = nullptr;
   }
 }
 
-int ConvolutionDelegateFP16CPUKernel::GetFp16WeightAndBias() {
-  auto ret = GetFp16Weight();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Get Fp16 Weight failed.";
-    return RET_ERROR;
-  }
-  ret = GetFp16Bias();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Get Fp16 Bias failed.";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
-int ConvolutionDelegateFP16CPUKernel::GetFp16Weight() {
-  auto weight_tensor = in_tensors_.at(kWeightIndex);
-  if (weight_tensor->data_type() == kNumberTypeFloat16 && InferShapeDone()) {
-    // do not need malloc new memory to store origin data
-    fp16_weight_ = reinterpret_cast<float16_t *>(weight_tensor->data_c());
-    return RET_OK;
-  } else {
-    fp16_weight_ = CopyData(weight_tensor);
-    if (fp16_weight_ == nullptr) {
-      MS_LOG(ERROR) << "Generate fp16_weight failed.";
-      return RET_ERROR;
-    }
-    need_free_ = need_free_ | WEIGHT_NEED_FREE;
-    return RET_OK;
-  }
-  return RET_OK;
-}
-
-int ConvolutionDelegateFP16CPUKernel::GetFp16Bias() {
-  if (in_tensors_.size() == 3) {
-    // has bias situation
-    auto bias_tensor = in_tensors_.at(kBiasIndex);
-    if (bias_tensor->data_type() == kNumberTypeFloat16 && InferShapeDone()) {
-      // do not need malloc new memory to store origin data
-      fp16_bias_ = reinterpret_cast<float16_t *>(bias_tensor->data_c());
-      return RET_OK;
-    } else {
-      fp16_bias_ = CopyData(bias_tensor);
-      if (fp16_bias_ == nullptr) {
-        MS_LOG(ERROR) << "Generate fp16_bias failed.";
-        return RET_ERROR;
-      }
-      need_free_ = need_free_ | BIAS_NEED_FREE;
-      return RET_OK;
-    }
-  }
-  return RET_OK;
-}
-
-float16_t *ConvolutionDelegateFP16CPUKernel::CopyData(lite::Tensor *tensor) {
+void *ConvolutionDelegateFP16CPUKernel::CopyData(lite::Tensor *tensor) {
   auto data_type = tensor->data_type();
-  MS_ASSERT(data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16);
-  auto fp16_data = reinterpret_cast<float16_t *>(malloc(tensor->ElementsNum() * sizeof(float16_t)));
-  if (fp16_data == nullptr) {
-    MS_LOG(ERROR) << "Malloc fp16_data failed.";
+  if (data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16) {
+    MS_LOG(ERROR) << "Not supported data type: " << data_type;
     return nullptr;
   }
-  if (data_type == kNumberTypeFloat32) {
-    float *origin_data = reinterpret_cast<float *>(tensor->data_c());
-    for (size_t i = 0; i < tensor->ElementsNum(); ++i) {
-      fp16_data[i] = (float16_t)origin_data[i];
-    }
-  } else {
-    auto *origin_data = reinterpret_cast<float16_t *>(tensor->data_c());
-    memcpy(fp16_data, origin_data, tensor->Size());
+  auto copied_data = malloc(tensor->Size());
+  if (copied_data == nullptr) {
+    MS_LOG(ERROR) << "Malloc copied_data failed.";
+    return nullptr;
   }
-  return fp16_data;
+  memcpy(copied_data, tensor->data_c(), tensor->Size());
+  return copied_data;
 }
 
 int ConvolutionDelegateFP16CPUKernel::Init() {
-  auto ret = GetFp16WeightAndBias();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Get fp16 weight and bias failed.";
-    return ret;
-  }
   if (!InferShapeDone()) {
+    origin_weight_ = CopyData(in_tensors_.at(kWeightIndex));
+    need_free_ = need_free_ | WEIGHT_NEED_FREE;
+    if (in_tensors_.size() == 3) {
+      origin_bias_ = CopyData(in_tensors_.at(kBiasIndex));
+      need_free_ = need_free_ | BIAS_NEED_FREE;
+    }
     return RET_OK;
   }
+  origin_weight_ = in_tensors_.at(kWeightIndex)->data_c();
+  if (in_tensors_.size() == 3) {
+    origin_bias_ = in_tensors_.at(kBiasIndex)->data_c();
+  }
   return ReSize();
 }
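With the fp16 helper functions gone, Init() makes the ownership decision directly: if shapes are not inferred yet, the constant tensors may be released before the concrete kernel is built, so their raw bytes are copied and the fact recorded in the need_free_ bitmask; otherwise the tensor pointers are simply borrowed. A minimal sketch of that bookkeeping — the flag values are assumed here for illustration, the real WEIGHT_NEED_FREE / BIAS_NEED_FREE constants are defined elsewhere in the source:

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>

// Assumed flag values, for illustration only.
constexpr uint8_t kWeightNeedFree = 0b01;
constexpr uint8_t kBiasNeedFree = 0b10;

struct OriginData {
  void *weight = nullptr;
  void *bias = nullptr;
  uint8_t need_free = 0b00;

  // Take a private copy of a raw buffer and remember that we own it.
  bool CopyWeight(const void *src, size_t bytes) {
    weight = malloc(bytes);
    if (weight == nullptr) {
      return false;
    }
    memcpy(weight, src, bytes);
    need_free |= kWeightNeedFree;
    return true;
  }

  // Free only what we allocated ourselves; borrowed pointers are left untouched.
  void FreeCopied() {
    if ((need_free & kWeightNeedFree) && weight != nullptr) {
      free(weight);
      weight = nullptr;
    }
    if ((need_free & kBiasNeedFree) && bias != nullptr) {
      free(bias);
      bias = nullptr;
    }
  }
};
```

Copying raw bytes, rather than eagerly converting to fp16 as the removed GetFp16Weight/GetFp16Bias did, defers the data-type decision to the concrete kernel's InitWeightBias().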
@@ -136,8 +82,8 @@ int ConvolutionDelegateFP16CPUKernel::ReSize() {
   SetInputOutputShapeInfo(reinterpret_cast<ConvParameter *>(op_parameter_), in_tensors_.front(), out_tensors_.front(),
                           context_);
   if (fp16_conv_kernel_ == nullptr) {
-    fp16_conv_kernel_ =
-      CpuConvFp16KernelSelect(in_tensors_, out_tensors_, op_parameter_, context_, primitive_, fp16_weight_, fp16_bias_);
+    fp16_conv_kernel_ = CpuConvFp16KernelSelect(in_tensors_, out_tensors_, op_parameter_, context_, primitive_,
+                                                origin_weight_, origin_bias_);
     if (fp16_conv_kernel_ == nullptr) {
       MS_LOG(ERROR) << "Selecting execute kernel failed for conv_kernel, got a nullptr.";
       return RET_ERROR;
@@ -161,7 +107,7 @@ ConvParameter *CreateNewConvParameterFp16(ConvParameter *parameter) {
 kernel::LiteKernel *CpuConvFp16KernelSelect(const std::vector<lite::Tensor *> &inputs,
                                             const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                             const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
-                                            float16_t *fp16_weight, float16_t *fp16_bias) {
+                                            void *origin_weight, void *origin_bias) {
   auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
   bool use_winograd = false;
   int out_unit;
@@ -169,13 +115,13 @@ kernel::LiteKernel *CpuConvFp16KernelSelect(const std::vector<lite::Tensor *> &i
   kernel::LiteKernel *kernel = nullptr;
   if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
     kernel = new (std::nothrow)
-      kernel::Convolution1x1FP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, fp16_weight, fp16_bias);
+      kernel::Convolution1x1FP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, origin_weight, origin_bias);
   } else if (use_winograd) {
     kernel = new (std::nothrow) kernel::ConvolutionWinogradFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive,
-                                                                         out_unit, fp16_weight, fp16_bias);
+                                                                         out_unit, origin_weight, origin_bias);
   } else {
     kernel = new (std::nothrow)
-      kernel::ConvolutionFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, fp16_weight, fp16_bias);
+      kernel::ConvolutionFP16CPUKernel(op_parameter, inputs, outputs, ctx, primitive, origin_weight, origin_bias);
   }
   // Once kernel is selected, init func will invoke InitWeightAndBias
   auto ret = kernel->Init();
@@ -40,10 +40,7 @@ class ConvolutionDelegateFP16CPUKernel : public LiteKernel {
       fp16_conv_kernel_ = nullptr;
     }
   }
-  int GetFp16WeightAndBias();
-  int GetFp16Weight();
-  int GetFp16Bias();
-  float16_t *CopyData(lite::Tensor *tensor);
+  void *CopyData(lite::Tensor *tensor);
   void FreeCopiedData();
   int Init() override;
   int ReSize() override;
@@ -51,15 +48,15 @@ class ConvolutionDelegateFP16CPUKernel : public LiteKernel {
  private:
   uint8_t need_free_ = 0b00;
+  void *origin_weight_ = nullptr;
+  void *origin_bias_ = nullptr;
   kernel::LiteKernel *fp16_conv_kernel_ = nullptr;
-  float16_t *fp16_weight_ = nullptr;
-  float16_t *fp16_bias_ = nullptr;
 };
 
 kernel::LiteKernel *CpuConvFp16KernelSelect(const std::vector<lite::Tensor *> &inputs,
                                             const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                             const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive,
-                                            float16_t *fp16_weight, float16_t *fp16_bias);
+                                            void *origin_weight, void *origin_bias);
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_DELEGATE_FP16_H_
@@ -39,8 +39,6 @@ ConvolutionDepthwiseFp16CPUKernel::~ConvolutionDepthwiseFp16CPUKernel() {
 int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
   // init weight: o, h, w, i; o == group, i == 1
-  ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
-
   auto weight_tensor = in_tensors_.at(kWeightIndex);
   int channel = weight_tensor->Batch();
   int pack_weight_size = channel * weight_tensor->Height() * weight_tensor->Width();
@@ -50,6 +48,11 @@ int ConvolutionDepthwiseFp16CPUKernel::InitWeightBias() {
     MS_LOG(ERROR) << "Malloc buffer failed.";
     return RET_ERROR;
   }
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(weight_tensor, weight_tensor->data_c());
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "get execute filter data failed.";
+    return ret;
+  }
   PackNCHWToNHWCFp16(fp16_weight_, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(),
                      weight_tensor->Batch());
   if (fp16_weight_ != nullptr) {
@@ -51,7 +51,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
     return RET_ERROR;
   }
   memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t));
-  RowMajor2Col8MajorFp16(fp16_weight_, packed_weight_, out_channel, in_channel * kernel_plane, false);
+  RowMajor2Col8MajorFp16(origin_weight_, packed_weight_, out_channel, in_channel * kernel_plane,
+                         filter_tensor->data_type() == kNumberTypeFloat32);
 
   // init bias
   bias_data_ = malloc(oc8 * sizeof(float16_t));
@@ -61,8 +62,12 @@
   }
   memset(bias_data_, 0, oc8 * sizeof(float16_t));
   if (in_tensors_.size() == kInputSize2) {
-    auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
-    memcpy(fp16_bias_data, fp16_bias_, out_channel * sizeof(float16_t));
+    auto bias_tensor = in_tensors_.at(kBiasIndex);
+    if (bias_tensor->data_type() == kNumberTypeFloat16) {
+      memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
+    } else {
+      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+    }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
   }
@@ -27,10 +27,10 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  public:
   ConvolutionFP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                            const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
-                           const mindspore::lite::PrimitiveC *primitive, float16_t *fp16_weight, float16_t *fp16_bias)
+                           const mindspore::lite::PrimitiveC *primitive, void *origin_weight, void *origin_bias)
       : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive),
-        fp16_weight_(fp16_weight),
-        fp16_bias_(fp16_bias) {}
+        origin_weight_(origin_weight),
+        origin_bias_(origin_bias) {}
   ~ConvolutionFP16CPUKernel() override {
     if (packed_weight_ != nullptr) {
       free(packed_weight_);
@@ -56,8 +56,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
       col_major_input_ = nullptr;
     }
  }
-  float16_t *fp16_weight_;  // do not free
-  float16_t *fp16_bias_;    // do not free
+  void *origin_weight_;  // do not free
+  void *origin_bias_;    // do not free
   float16_t *packed_input_ = nullptr;
   float16_t *packed_weight_ = nullptr;
   float16_t *col_major_input_ = nullptr;
@@ -68,11 +68,21 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
     MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
     return ret;
   }
-  ret = WinogradFilterTransformFp16(fp16_origin_weight_, matrix_g, matrix_gt, oc_block);
+  ret = GetExecuteFilter(filter_tensor, origin_weight_);
   if (ret != RET_OK) {
-    MS_LOG(ERROR) << "winograd filter transfrom failed.";
+    MS_LOG(ERROR) << "get execute filter failed.";
     return ret;
   }
+  ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "winograd filter transform failed.";
+    return ret;
+  }
+  // if fp16_weight is malloced, free it. It will not be used in runtime anymore.
+  if (fp16_weight_ != nullptr) {
+    free(fp16_weight_);
+    fp16_weight_ = nullptr;
+  }
 
   // init bias
   bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t));
@@ -81,9 +91,14 @@
     return RET_ERROR;
   }
   memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t));
-  auto fp16_bias_data = reinterpret_cast<float16_t *>(bias_data_);
   if (in_tensors_.size() == kInputSize2) {
-    memcpy(fp16_bias_data, fp16_bias_, out_channel * sizeof(float16_t));
+    auto bias_tensor = in_tensors_.at(kBiasIndex);
+    if (bias_tensor->data_type() == kNumberTypeFloat16) {
+      memcpy(bias_data_, origin_bias_, out_channel * sizeof(float16_t));
+    } else {
+      Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), out_channel);
+    }
   } else {
     MS_ASSERT(in_tensors_.size() == kInputSize1);
   }
@@ -31,12 +31,12 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  public:
   ConvolutionWinogradFP16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                    const std::vector<lite::Tensor *> &outputs, const InnerContext *ctx,
-                                   const mindspore::lite::PrimitiveC *primitive, int out_unit, float16_t *fp16_weight,
-                                   float16_t *fp16_bias)
+                                   const mindspore::lite::PrimitiveC *primitive, int out_unit, void *origin_weight,
+                                   void *origin_bias)
       : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive),
         output_unit_(out_unit),
-        fp16_origin_weight_(fp16_weight),
-        fp16_bias_(fp16_bias) {}
+        origin_weight_(origin_weight),
+        origin_bias_(origin_bias) {}
   ~ConvolutionWinogradFP16CPUKernel() override {
     if (trans_weight_ != nullptr) {
       free(trans_weight_);
@@ -75,8 +75,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
   int kernel_unit_;
   int input_unit_;
   int output_unit_;
-  float16_t *fp16_origin_weight_;  // do not free
-  float16_t *fp16_bias_;           // do not free
+  void *origin_weight_;  // do not free
+  void *origin_bias_;    // do not free
   float16_t *tmp_data_ = nullptr;
   float16_t *trans_input_ = nullptr;
   float16_t *gemm_out_ = nullptr;
@@ -319,11 +319,13 @@ int DeConvWinogradFp16CPUKernel::InitComputeParam() {
 int DeConvWinogradFp16CPUKernel::InitDataParam() {
   /* unit data : weight & winograd data*/
-  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter();
+  auto weight_tensor = in_tensors_.at(kWeightIndex);
+  auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(weight_tensor, weight_tensor->data_c());
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Get Execute filter failed.";
     return ret;
   }
 
   for (int i = 0; i < deconv_param_->compute_size_; i++) {
     DeConvComputeUnit *unit = &deconv_param_->compute_units_[i];
     ret = PackDeConvWgDataFp16(execute_weight_, unit, conv_param_, deconv_param_);
@@ -45,31 +45,31 @@ int ConcatCPUKernel::DoConcat(int task_id) {
   std::vector<std::vector<int>> shapes;
   for (size_t i = 0; i < input_num; ++i) {
-    inputs_addr[i] = in_tensors_[i]->MutableData();
+    inputs_addr[i] = in_tensors_[i]->data_c();
     shapes.push_back(in_tensors_[i]->shape());
     inputs_output_shape[i] = shapes[i].data();
   }
   auto output_shape = out_tensors_.at(0)->shape();
   inputs_output_shape[input_num] = output_shape.data();
-  auto output_addr = out_tensors_.at(0)->MutableData();
+  auto output_addr = out_tensors_.at(0)->data_c();
   Concat(inputs_addr.data(), input_num, concat_param_->axis_, inputs_output_shape.data(), output_shape.size(),
          output_addr, task_id, op_parameter_->thread_num_, sizeof(float));
   return RET_OK;
 }
 
-int ConcatsRun(void *cdata, int task_id) {
+int ConcatRun(void *cdata, int task_id) {
   auto concat_kernel = reinterpret_cast<ConcatCPUKernel *>(cdata);
   auto error_code = concat_kernel->DoConcat(task_id);
   if (error_code != RET_OK) {
-    MS_LOG(ERROR) << "ConcatsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
+    MS_LOG(ERROR) << "ConcatRun error task_id[" << task_id << "] error_code[" << error_code << "]";
     return RET_ERROR;
   }
   return RET_OK;
 }
 
 int ConcatCPUKernel::Run() {
-  int error_code = ParallelLaunch(this->context_->thread_pool_, ConcatsRun, this, op_parameter_->thread_num_);
+  int error_code = ParallelLaunch(this->context_->thread_pool_, ConcatRun, this, op_parameter_->thread_num_);
   return error_code;
 }
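The renamed ConcatRun follows the usual ParallelLaunch contract: a free function receives the kernel as an opaque cdata pointer plus a task_id, and each task processes its own slice of the output. A toy launcher showing just that contract — the real thread pool reuses worker threads, whereas this sketch spawns one std::thread per task:

```cpp
#include <functional>
#include <thread>
#include <vector>

// Toy stand-in for ParallelLaunch: invoke the callback once per task id.
int ToyParallelLaunch(const std::function<int(void *, int)> &fn, void *cdata, int thread_num) {
  std::vector<std::thread> workers;
  std::vector<int> codes(thread_num, 0);
  for (int t = 0; t < thread_num; ++t) {
    workers.emplace_back([&fn, &codes, cdata, t]() { codes[t] = fn(cdata, t); });
  }
  for (auto &w : workers) {
    w.join();
  }
  for (int code : codes) {
    if (code != 0) {
      return code;  // surface the first failing task, mirroring the error check in ConcatRun
    }
  }
  return 0;
}
```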
@@ -67,8 +67,8 @@ int TileCPUKernel::ReSize() {
 }
 
 int TileCPUKernel::Run() {
-  auto input_addr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
-  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
+  auto input_addr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
+  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
   MS_ASSERT(input_addr);
   MS_ASSERT(output_addr);
   Tile(input_addr, output_addr, reinterpret_cast<TileParameter *>(op_parameter_));
@@ -23,11 +23,11 @@
 namespace mindspore::kernel {
 class TileCPUKernel : public LiteKernel {
  public:
-  explicit TileCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                         const mindspore::lite::PrimitiveC *primitive)
+  TileCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~TileCPUKernel() override {}
+  ~TileCPUKernel() override = default;
   int Init() override;
   int ReSize() override;