From: @fuzhiye
Reviewed-by: @zhang_xue_tong, @hangangqiang
Signed-off-by: @zhang_xue_tong
Ref: pull/13585/MERGE
@@ -159,8 +159,6 @@ std::unique_ptr<OperatorCoder> CPUConvolutionFP32CoderCreator(const std::vector<
     return nullptr;
   }
   auto conv_param = reinterpret_cast<ConvParameter *>(paramGen(node->primitive_));
-  bool use_winograd = false;
-  int out_unit = 0;
   int kernel_h = conv_param->kernel_h_;
   int kernel_w = conv_param->kernel_w_;
   conv_param->input_h_ = inputs.at(kInputIndex)->Height();
@@ -170,7 +168,8 @@ std::unique_ptr<OperatorCoder> CPUConvolutionFP32CoderCreator(const std::vector<
   conv_param->output_w_ = outputs.at(kOutputIndex)->Width();
   conv_param->output_channel_ = outputs.at(kOutputIndex)->Channel();
   conv_param->op_parameter_.thread_num_ = 1;
-  CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
+  int out_unit = 0;
+  bool use_winograd = CheckIfUseWinograd(&out_unit, conv_param);
   free(conv_param);
   // weight de quant
   std::unique_ptr<OperatorCoder> coder;
@@ -3882,14 +3882,13 @@ int SelectOutputUnit(ConvParameter *conv_param) {
   return unit;
 }
-void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param) {
+bool CheckIfUseWinograd(int *output_unit, ConvParameter *conv_param) {
   if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
       conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
     *output_unit = SelectOutputUnit(conv_param);
     if (*output_unit > 1) {
-      *use_winograd = true;
+      return true;
     }
-  } else {
-    *use_winograd = false;
   }
+  return false;
 }
@@ -308,7 +308,7 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f
 int SelectOutputUnit(ConvParameter *conv_param);
-void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param);
+bool CheckIfUseWinograd(int *output_unit, ConvParameter *conv_param);
 #ifdef __cplusplus
 }
 #endif
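`CheckIfUseWinograd` now reports the decision through its return value instead of a `bool` out-parameter, so call sites no longer pre-declare a flag. A standalone sketch of the migrated call pattern, with stubbed `ConvParameter` and `SelectOutputUnit` (the real ones live in nnacl and are not reproduced here):

```cpp
#include <cstdio>

struct ConvParameter {  // stub with only the fields the check reads
  int kernel_h_ = 3, kernel_w_ = 3;
  int stride_h_ = 1, stride_w_ = 1;
  int dilation_h_ = 1, dilation_w_ = 1;
};

static int SelectOutputUnit(const ConvParameter *) { return 4; }  // stub

// Same shape as the patched function: the flag is the return value,
// only the output unit stays an out-parameter.
bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param) {
  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 &&
      conv_param->dilation_w_ == 1 && conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
    *output_unit = SelectOutputUnit(conv_param);
    if (*output_unit > 1) {
      return true;
    }
  }
  return false;
}

int main() {
  ConvParameter param;
  int out_unit = 0;
  bool use_winograd = CheckIfUseWinograd(&out_unit, &param);
  printf("use_winograd=%d out_unit=%d\n", use_winograd, out_unit);
  return 0;
}
```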
@@ -79,10 +79,23 @@ int ConvolutionDelegateFP16CPUKernel::Init() {
   return ReSize();
 }
+static void SetInputOutputShapeInfo(ConvParameter *conv_param, lite::Tensor *input, lite::Tensor *output,
+                                    const InnerContext *ctx) {
+  conv_param->input_batch_ = input->Batch();
+  conv_param->input_h_ = input->Height();
+  conv_param->input_w_ = input->Width();
+  conv_param->input_channel_ = input->Channel();
+  conv_param->output_batch_ = output->Batch();
+  conv_param->output_h_ = output->Height();
+  conv_param->output_w_ = output->Width();
+  conv_param->output_channel_ = output->Channel();
+  conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
+}
 int ConvolutionDelegateFP16CPUKernel::ReSize() {
   // Update shape info of input and output
-  SetInputOutputShapeInfo(reinterpret_cast<ConvParameter *>(op_parameter_), in_tensors_.front(), out_tensors_.front(),
-                          context_);
+  kernel::SetInputOutputShapeInfo(reinterpret_cast<ConvParameter *>(op_parameter_), in_tensors_.front(),
+                                  out_tensors_.front(), context_);
   if (fp16_conv_kernel_ == nullptr) {
     fp16_conv_kernel_ = CpuConvFp16KernelSelect(in_tensors_, out_tensors_, op_parameter_, context_, origin_weight_,
                                                 origin_bias_, origin_weight_data_type_, origin_bias_data_type_);
@@ -0,0 +1,212 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <vector>
+#include "src/runtime/kernel/arm/fp32/convolution_creator_manager.h"
+#include "src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h"
+#include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
+#include "src/runtime/kernel/arm/int8/convolution_int8.h"
+#include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
+#include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h"
+#include "nnacl/conv_parameter.h"
+namespace mindspore::lite {
+using mindspore::lite::Format::Format_NHWC;
+static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) {
+  if (tensor->MallocData() != RET_OK) {
+    delete tensor;
+    MS_LOG(ERROR) << "malloc tensor data failed.";
+    return nullptr;
+  }
+  return tensor;
+}
+lite::Tensor *CreateConstTensor(lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
+  auto new_tensor =
+    new (std::nothrow) lite::Tensor(tensor->data_type(), shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
+  if (new_tensor == nullptr) {
+    MS_LOG(ERROR) << "Create new_tensor failed.";
+    return nullptr;
+  }
+  auto ret = new_tensor->MallocData();
+  if (ret != RET_OK) {
+    delete new_tensor;
+    MS_LOG(ERROR) << "Malloc new_tensor failed.";
+    return nullptr;
+  }
+  memcpy(new_tensor->data_c(), reinterpret_cast<char *>(tensor->data_c()) + index * new_tensor->Size(),
+         new_tensor->Size());
+  return new_tensor;
+}
+lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
+  auto tensor = new (std::nothrow) lite::Tensor();
+  if (!tensor) {
+    MS_LOG(ERROR) << "new tensor failed.";
+    return nullptr;
+  }
+  tensor->set_data_type(tensor_info.data_type_);
+  tensor->set_format(tensor_info.format_);
+  tensor->set_category(tensor_info.tensor_type_);
+  if (tensor_info.is_in_) {
+    tensor->set_shape(tensor_info.shape_);
+  }
+  if (inferred) {
+    // set shape of out tensor
+    if (!tensor_info.is_in_) {
+      tensor->set_shape(tensor_info.shape_);
+    }
+    return TensorMalloc(tensor);
+  }
+  return tensor;
+}
+/* Kernel creator func part */
+kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &inputs,
+                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                            const InnerContext *ctx) {
+  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
+  kernel::LiteKernel *kernel = nullptr;
+  if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
+      conv_param->stride_w_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1) {
+#ifdef ENABLE_ARM64
+    if (mindspore::lite::IsSupportSDot()) {
+      kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx);
+    } else {
+      kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx);
+    }
+#else
+    kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx);
+#endif
+  } else if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
+    kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(op_parameter, inputs, outputs, ctx);
+  } else {
+    kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx);
+  }
+  return kernel;
+}
+kernel::LiteKernel *DispatchConvDw(const std::vector<lite::Tensor *> &inputs,
+                                   const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
+                                   const InnerContext *ctx) {
+  auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
+  kernel::LiteKernel *kernel = nullptr;
+  if (opParameter != nullptr && opParameter->infer_flag_) {
+#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+    if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) {
+      kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
+    }
+#endif
+#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
+    if (kernel == nullptr && CheckConvDwUseIndirectBuffer(conv_param)) {
+      kernel = new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx);
+    }
+#endif
+    if (kernel == nullptr && conv_param->input_channel_ < 32) {
+      kernel = new (std::nothrow) kernel::ConvolutionDepthwiseSWCPUKernel(opParameter, inputs, outputs, ctx);
+    }
+  }
+  if (kernel == nullptr) {
+    kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
+  }
+  return kernel;
+}
+kernel::LiteKernel *DispatchGroupConv(const std::vector<lite::Tensor *> &inputs,
+                                      const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                      const InnerContext *ctx) {
+  GroupConvCreator group_conv_creator(inputs, outputs, op_parameter, ctx, false);
+  group_conv_creator.SetShapeOfTensors();
+  if (group_conv_creator.CreatGroupConv() != RET_OK) {
+    MS_LOG(ERROR) << "Create group conv failed.";
+    return nullptr;
+  }
+  return new (std::nothrow)
+    kernel::GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, group_conv_creator.get_group_conv(),
+                                      reinterpret_cast<ConvParameter *>(op_parameter)->group_);
+}
+/* Class GroupConv Creator Implement Part */
+void GroupConvCreator::SetShapeOfTensors() {
+  int new_in_channel = origin_inputs_.at(kWeightIndex)->Channel();
+  int new_out_channel;
+  if (conv_param_->group_ == 0) {
+    MS_LOG(ERROR) << "Divisor 'group' cannot be 0.";
+    return;
+  } else {
+    new_out_channel = origin_inputs_.at(kWeightIndex)->Batch() / conv_param_->group_;
+  }
+  /* set shape */
+  set_filter_shape({new_out_channel, conv_param_->kernel_h_, conv_param_->kernel_w_, new_in_channel});
+  set_bias_shape({new_out_channel});
+  if (infered_) {
+    conv_param_->input_channel_ = new_in_channel;
+    conv_param_->output_channel_ = new_out_channel;
+    set_input_shape({origin_inputs_.front()->Batch(), origin_inputs_.front()->Height(),
+                     origin_inputs_.front()->Width(), new_in_channel});
+    set_output_shape({origin_inputs_.front()->Batch(), origin_outputs_.front()->Height(),
+                      origin_outputs_.front()->Width(), new_out_channel});
+  }
+}
+int GroupConvCreator::CreatGroupConv() {
+  for (int i = 0; i < conv_param_->group_; ++i) {
+    auto new_conv_parameter = CreateNewConvParameter(conv_param_);
+    if (!CheckIfValidPoint(new_conv_parameter)) {
+      return RET_ERROR;
+    }
+    // create new input for each group
+    std::vector<lite::Tensor *> new_inputs;
+    if (NewInputTensor(&new_inputs) != RET_OK) {
+      MS_LOG(ERROR) << "new input tensor failed.";
+      FreeMemory(new_conv_parameter, new_inputs, {});
+      return RET_ERROR;
+    }
+    // const tensor
+    if (NewConstTensor(&new_inputs, i) != RET_OK) {
+      MS_LOG(ERROR) << "new const tensor failed.";
+      FreeMemory(new_conv_parameter, new_inputs, {});
+      return RET_ERROR;
+    }
+    // create new output tensor
+    std::vector<lite::Tensor *> new_outputs;
+    for (auto &output : origin_outputs_) {
+      if (NewOutputTensor(&new_outputs, output) != RET_OK) {
+        MS_LOG(ERROR) << "new output tensor failed.";
+        FreeMemory(new_conv_parameter, new_inputs, new_outputs);
+        return RET_ERROR;
+      }
+    }
+    if (is_quant_) {
+      CopyQuantParam(&new_inputs);
+      group_convs_.emplace_back(CpuConvInt8KernelSelect(new_inputs, new_outputs,
+                                                        reinterpret_cast<OpParameter *>(new_conv_parameter), context_));
+    } else {
+      group_convs_.emplace_back(new (std::nothrow) kernel::ConvolutionDelegateCPUKernel(
+        reinterpret_cast<OpParameter *>(new_conv_parameter), new_inputs, new_outputs, context_));
+    }
+  }
+  return RET_OK;
+}
+}  // namespace mindspore::lite
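`DispatchConvDw` above selects a depthwise kernel through a fallback chain: each architecture-gated candidate is tried only while `kernel` is still null, and the generic depthwise kernel is the unconditional last resort. A standalone sketch of that pattern; `FitsWinograd3x3` and `FitsIndirectBuffer` are illustrative stand-ins for `CheckConvDw1DWinograd` and `CheckConvDwUseIndirectBuffer`, which in the real code inspect `ConvParameter` and the thread count:

```cpp
#include <cstdio>

struct Kernel {
  const char *name;
};

// Stubbed predicates; the real checks are shape- and arch-dependent.
static bool FitsWinograd3x3() { return false; }
static bool FitsIndirectBuffer() { return true; }

static Kernel *DispatchConvDwSketch(bool shapes_inferred) {
  Kernel *kernel = nullptr;
  if (shapes_inferred) {
    // Each candidate is tried only while no earlier candidate matched.
    if (FitsWinograd3x3()) {
      kernel = new Kernel{"ConvolutionDepthwise3x3"};
    }
    if (kernel == nullptr && FitsIndirectBuffer()) {
      kernel = new Kernel{"ConvolutionDepthwiseIndirect"};
    }
  }
  // The generic depthwise kernel is the unconditional fallback.
  if (kernel == nullptr) {
    kernel = new Kernel{"ConvolutionDepthwise"};
  }
  return kernel;
}

int main() {
  Kernel *k = DispatchConvDwSketch(true);
  printf("selected kernel: %s\n", k->name);
  delete k;
  return 0;
}
```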
@@ -0,0 +1,180 @@
+/**
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_CREATOR_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_CREATOR_H_
+#include <utility>
+#include <vector>
+#include "src/lite_kernel.h"
+#include "nnacl/conv_parameter.h"
+namespace mindspore::lite {
+using Category = lite::Tensor::Category;
+using Format = mindspore::schema::Format;
+struct TensorInfo {
+  std::vector<int> shape_;
+  Format format_;
+  TypeId data_type_;
+  Category tensor_type_;
+  bool is_in_;
+};
+inline void CopyTensorQuantParam(lite::Tensor *dst, lite::Tensor *src) {
+  for (size_t i = 0; i < src->quant_params().size(); i++) {
+    dst->AddQuantParam(src->quant_params().at(i));
+  }
+}
+inline ConvParameter *CreateNewConvParameter(ConvParameter *parameter) {
+  auto conv_parameter = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
+  if (conv_parameter == nullptr) {
+    MS_LOG(ERROR) << "Malloc new conv parameter failed.";
+    return nullptr;
+  }
+  memcpy(conv_parameter, parameter, sizeof(ConvParameter));
+  return conv_parameter;
+}
+inline void FreeMemory(ConvParameter *conv_param, const std::vector<lite::Tensor *> &new_inputs,
+                       const std::vector<lite::Tensor *> &new_outputs) {
+  if (conv_param) {
+    free(conv_param);
+  }
+  for (auto &in_tensor : new_inputs) {
+    delete in_tensor;
+  }
+  for (auto &out_tensor : new_outputs) {
+    delete out_tensor;
+  }
+}
+lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred);
+lite::Tensor *CreateConstTensor(lite::Tensor *tensor, const std::vector<int> &shape, int index);
+kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &inputs,
+                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                            const InnerContext *ctx);
+kernel::LiteKernel *DispatchConvDw(const std::vector<lite::Tensor *> &inputs,
+                                   const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
+                                   const InnerContext *ctx);
+kernel::LiteKernel *DispatchGroupConv(const std::vector<lite::Tensor *> &inputs,
+                                      const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
+                                      const InnerContext *ctx);
+class GroupConvCreator {
+ public:
+  GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter,
+                   const InnerContext *ctx, bool is_quant)
+      : origin_inputs_(std::move(inputs)),
+        origin_outputs_(std::move(outputs)),
+        context_(ctx),
+        infered_(op_parameter->infer_flag_),
+        is_quant_(is_quant) {
+    conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter);
+  }
+  ~GroupConvCreator() = default;
+ public:
+  void SetShapeOfTensors();
+  void set_input_shape(const std::vector<int> &shape) { input_shape_ = shape; }
+  void set_output_shape(const std::vector<int> &shape) { output_shape_ = shape; }
+  void set_filter_shape(const std::vector<int> &shape) { filter_shape_ = shape; }
+  void set_bias_shape(const std::vector<int> &shape) { bias_shape_ = shape; }
+  std::vector<kernel::LiteKernel *> get_group_conv() { return group_convs_; }
+  int CreatGroupConv();
+ protected:
+  void FreeSubConv() {
+    for (auto &sub_conv : group_convs_) {
+      delete sub_conv;
+    }
+  }
+  bool CheckIfValidPoint(void *ptr) {
+    if (ptr == nullptr) {
+      MS_LOG(ERROR) << "pointer is nullptr.";
+      FreeSubConv();
+      return false;
+    }
+    return true;
+  }
+  int NewInputTensor(std::vector<lite::Tensor *> *tensors) {
+    auto in_tensor = CreateVarTensor(
+      {input_shape_, Format::Format_NHWC, origin_inputs_.at(0)->data_type(), Category::VAR, true}, infered_);
+    if (!CheckIfValidPoint(in_tensor)) {
+      return RET_ERROR;
+    }
+    tensors->emplace_back(in_tensor);
+    return RET_OK;
+  }
+  int NewConstTensor(std::vector<lite::Tensor *> *tensors, int group_id) {
+    std::vector<std::pair<int, std::vector<int>>> const_tensor_list{std::make_pair(kWeightIndex, filter_shape_)};
+    if (origin_inputs_.size() == 3) {
+      const_tensor_list.emplace_back(std::make_pair(kBiasIndex, bias_shape_));
+    }
+    for (auto &info : const_tensor_list) {
+      auto const_tensor = CreateConstTensor(origin_inputs_.at(info.first), info.second, group_id);
+      if (!CheckIfValidPoint(const_tensor)) {
+        return RET_ERROR;
+      }
+      tensors->emplace_back(const_tensor);
+    }
+    return RET_OK;
+  }
+  int NewOutputTensor(std::vector<lite::Tensor *> *tensors, lite::Tensor *output) {
+    auto out_tensor =
+      CreateVarTensor({output_shape_, output->format(), output->data_type(), output->category(), false}, infered_);
+    if (!CheckIfValidPoint(out_tensor)) {
+      return RET_ERROR;
+    }
+    if (is_quant_) {
+      CopyTensorQuantParam(out_tensor, output);
+    }
+    tensors->emplace_back(out_tensor);
+    return RET_OK;
+  }
+  void CopyQuantParam(std::vector<lite::Tensor *> *tensors) {
+    for (size_t j = 0; j < origin_inputs_.size(); ++j) {
+      CopyTensorQuantParam(tensors->at(j), origin_inputs_.at(j));
+    }
+  }
+ private:
+  std::vector<lite::Tensor *> origin_inputs_;
+  std::vector<lite::Tensor *> origin_outputs_;
+  std::vector<kernel::LiteKernel *> group_convs_;
+  std::vector<int> input_shape_;
+  std::vector<int> output_shape_;
+  std::vector<int> filter_shape_;
+  std::vector<int> bias_shape_;
+  const InnerContext *context_;
+  ConvParameter *conv_param_;
+  bool infered_;
+  bool is_quant_;
+};
+}  // namespace mindspore::lite
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_CREATOR_H_
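Typical use of this class mirrors `DispatchGroupConv` in the accompanying .cc: construct, derive the per-group shapes, build the sub-kernels, then hand them to the group kernel. A fragment, not standalone; it assumes the surrounding creator-function context (`inputs`, `outputs`, `op_parameter`, `ctx`) and the fp32 path with `is_quant=false`:

```cpp
GroupConvCreator creator(inputs, outputs, op_parameter, ctx, /*is_quant=*/false);
creator.SetShapeOfTensors();               // split channels across groups, set shapes
if (creator.CreatGroupConv() != RET_OK) {  // one delegate sub-kernel per group
  MS_LOG(ERROR) << "Create group conv failed.";
  return nullptr;
}
return new (std::nothrow)
  kernel::GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, creator.get_group_conv(),
                                    reinterpret_cast<ConvParameter *>(op_parameter)->group_);
```

The int8 path is identical except that `is_quant` is true, which makes `CreatGroupConv` route each group through `CpuConvInt8KernelSelect` and copy quantization parameters onto the new tensors.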
@@ -13,27 +13,21 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "src/kernel_registry.h"
 #include "src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_creator_manager.h"
 #include "src/runtime/kernel/arm/fp32/convolution_fp32.h"
 #include "src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h"
 #include "src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h"
-#include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h"
-#include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
-#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
-#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
-#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
 #include "schema/model_generated.h"
+#include "src/kernel_registry.h"
 #include "include/errorcode.h"
-#include "src/runtime/runtime_api.h"
-using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_INFER_INVALID;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Conv2DFusion;
-using mindspore::schema::Format::Format_NHWC;
 namespace mindspore::kernel {
 float *ConvolutionDelegateCPUKernel::CopyData(lite::Tensor *tensor) {
@@ -46,17 +40,6 @@ float *ConvolutionDelegateCPUKernel::CopyData(lite::Tensor *tensor) {
   return data;
 }
-void ConvolutionDelegateCPUKernel::FreeCopiedData() {
-  if (origin_weight_ != nullptr && need_free_weight_) {
-    free(origin_weight_);
-    origin_weight_ = nullptr;
-  }
-  if (origin_bias_ != nullptr && need_free_bias_) {
-    free(origin_bias_);
-    origin_bias_ = nullptr;
-  }
-}
 int ConvolutionDelegateCPUKernel::GetWeightAndBias() {
   auto ret = GetWeightData();
   if (ret != RET_OK) {
@@ -75,15 +58,13 @@ int ConvolutionDelegateCPUKernel::GetWeightData() {
   if (InferShapeDone()) {
     origin_weight_ = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->data_c());
     return RET_OK;
-  } else {
-    origin_weight_ = CopyData(in_tensors_.at(kWeightIndex));
-    if (origin_weight_ == nullptr) {
-      MS_LOG(ERROR) << "Copy weight data failed.";
-      return RET_ERROR;
-    }
-    need_free_weight_ = true;
-    return RET_OK;
   }
+  origin_weight_ = CopyData(in_tensors_.at(kWeightIndex));
+  if (origin_weight_ == nullptr) {
+    MS_LOG(ERROR) << "Copy weight data failed.";
+    return RET_ERROR;
+  }
+  need_free_weight_ = true;
   return RET_OK;
 }
@@ -119,23 +100,24 @@ int ConvolutionDelegateCPUKernel::Init() {
 int ConvolutionDelegateCPUKernel::ReSize() {
   // Update shape info of input and output
-  SetInputOutputShapeInfo(reinterpret_cast<ConvParameter *>(op_parameter_), in_tensors_.front(), out_tensors_.front(),
-                          context_);
+  SetInputOutputShapeInfo();
   if (conv_kernel_ == nullptr) {
     // need to select actual execute kernel here
-    conv_kernel_ =
-      CpuConvFp32KernelSelect(in_tensors_, out_tensors_, op_parameter_, context_, origin_weight_, origin_bias_);
-    if (conv_kernel_ == nullptr) {
+    conv_kernel_ = CpuConvFp32KernelSelect();
+    if (!conv_kernel_) {
       MS_LOG(ERROR) << "Selecting execute kernel failed for conv_kernel, got a nullptr.";
       return RET_ERROR;
     }
+    conv_kernel_->set_name(this->name_);
   }
   FreeCopiedData();
   return conv_kernel_->ReSize();
 }
-void SetInputOutputShapeInfo(ConvParameter *conv_param, const lite::Tensor *input, const lite::Tensor *output,
-                             const InnerContext *ctx) {
+void ConvolutionDelegateCPUKernel::SetInputOutputShapeInfo() {
+  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
+  auto input = in_tensors_.at(0);
+  auto output = out_tensors_.at(0);
   conv_param->input_batch_ = input->Batch();
   conv_param->input_h_ = input->Height();
   conv_param->input_w_ = input->Width();
@@ -144,109 +126,27 @@ void SetInputOutputShapeInfo(ConvParameter *conv_param, const lite::Tensor *inpu
   conv_param->output_h_ = output->Height();
   conv_param->output_w_ = output->Width();
   conv_param->output_channel_ = output->Channel();
-  conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
-}
-ConvParameter *CreateNewConvParameter(ConvParameter *parameter) {
-  auto conv_parameter = new (std::nothrow) ConvParameter;
-  if (conv_parameter == nullptr) {
-    MS_LOG(ERROR) << "Malloc new conv parameter failed.";
-    return nullptr;
-  }
-  memcpy(conv_parameter, parameter, sizeof(ConvParameter));
-  return conv_parameter;
-}
-void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
-                const std::vector<lite::Tensor *> &new_outputs) {
-  for (auto sub_conv : group_convs) {
-    delete sub_conv;
-  }
-  for (auto in_tensor : new_inputs) {
-    delete in_tensor;
-  }
-  for (auto out_tensor : new_outputs) {
-    delete out_tensor;
-  }
-}
-lite::Tensor *CreateInputTensor(TypeId data_type, const std::vector<int> &in_shape, bool infered_flag) {
-  auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR);
-  if (in_tensor == nullptr) {
-    MS_LOG(ERROR) << "new in_tensor failed.";
-    return nullptr;
-  }
-  if (infered_flag) {
-    auto ret = in_tensor->MallocData();
-    if (ret != RET_OK) {
-      delete in_tensor;
-      MS_LOG(ERROR) << "in tensor malloc failed.";
-      return nullptr;
-    }
-  }
-  return in_tensor;
-}
-// weight and bias are const
-static lite::Tensor *CreateConstTensorFp32(lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
-  auto new_tensor =
-    new (std::nothrow) lite::Tensor(tensor->data_type(), shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
-  if (new_tensor == nullptr) {
-    MS_LOG(ERROR) << "Create new_tensor failed.";
-    return nullptr;
-  }
-  auto ret = new_tensor->MallocData();
-  if (ret != RET_OK) {
-    delete new_tensor;
-    MS_LOG(ERROR) << "Malloc new_tensor failed.";
-    return nullptr;
-  }
-  MS_ASSERT(tensor->data_type() == kNumberTypeFloat32);
-  memcpy(new_tensor->data_c(), reinterpret_cast<char *>(tensor->data_c()) + index * new_tensor->Size(),
-         new_tensor->Size());
-  return new_tensor;
-}
-lite::Tensor *CreateOutputTensor(const std::vector<int> &out_shape, const std::vector<lite::Tensor *> &outputs,
-                                 bool infered_flag, int index) {
-  auto out_tensor = new (std::nothrow) lite::Tensor();
-  if (out_tensor == nullptr) {
-    MS_LOG(ERROR) << "new tmp_out_tensor failed.";
-    return nullptr;
-  }
-  out_tensor->set_data_type(outputs.at(index)->data_type());
-  out_tensor->set_format(outputs.at(index)->format());
-  if (infered_flag) {
-    out_tensor->set_shape(out_shape);
-    auto ret = out_tensor->MallocData();
-    if (ret != RET_OK) {
-      delete out_tensor;
-      MS_LOG(ERROR) << "out_tensor malloc data failed.";
-      return nullptr;
-    }
-  }
-  return out_tensor;
+  conv_param->op_parameter_.thread_num_ = context_->thread_num_;
 }
-kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs,
-                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
-                                            const InnerContext *ctx, float *origin_weight, float *origin_bias) {
-  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
-  bool use_winograd = false;
-  int out_unit;
-  CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
+kernel::LiteKernel *ConvolutionDelegateCPUKernel::CpuConvFp32KernelSelect() {
   kernel::LiteKernel *kernel = nullptr;
+  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter_);
   if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
     kernel = new (std::nothrow)
-      kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, origin_weight, origin_bias);
-  } else if (use_winograd) {
-    kernel = new (std::nothrow)
-      kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, out_unit, origin_weight, origin_bias);
+      kernel::Convolution1x1CPUKernel(op_parameter_, in_tensors_, out_tensors_, context_, origin_weight_, origin_bias_);
   } else {
-    kernel =
-      new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, origin_weight, origin_bias);
+    int out_unit;
+    if (CheckIfUseWinograd(&out_unit, conv_param)) {
+      kernel = new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(
+        op_parameter_, in_tensors_, out_tensors_, context_, out_unit, origin_weight_, origin_bias_);
+    } else {
+      kernel = new (std::nothrow)
+        kernel::ConvolutionCPUKernel(op_parameter_, in_tensors_, out_tensors_, context_, origin_weight_, origin_bias_);
+    }
   }
-  if (kernel != nullptr) {
+  if (kernel) {
     auto ret = kernel->Init();
     if (ret != RET_OK) {
      MS_LOG(ERROR) << "conv kernel init failed.";
@@ -257,124 +157,7 @@ kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &i
   return kernel;
 }
-static kernel::LiteKernel *CreateDelegateConv(const std::vector<lite::Tensor *> &inputs,
-                                              const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
-                                              const InnerContext *ctx) {
-  return new (std::nothrow) kernel::ConvolutionDelegateCPUKernel(op_parameter, inputs, outputs, ctx);
-}
-kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                  const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
-                                                  const InnerContext *ctx) {
-  bool infer_flag = op_parameter->infer_flag_;
-  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
-  int new_in_channel = inputs.at(kWeightIndex)->Channel();
-  int new_out_channel;
-  if (conv_param->group_ == 0) {
-    MS_LOG(ERROR) << "Divisor 'group' cannot be 0.";
-    return nullptr;
-  } else {
-    new_out_channel = inputs.at(kWeightIndex)->Batch() / conv_param->group_;
-  }
-  std::vector<int> in_shape;
-  std::vector<int> out_shape;
-  if (infer_flag) {
-    conv_param->input_channel_ = new_in_channel;
-    conv_param->output_channel_ = new_out_channel;
-    in_shape = {inputs.front()->Batch(), inputs.front()->Height(), inputs.front()->Width(), new_in_channel};
-    out_shape = {inputs.front()->Batch(), outputs.front()->Height(), outputs.front()->Width(), new_out_channel};
-  }
-  std::vector<int> filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel};
-  std::vector<int> bias_shape = {new_out_channel};
-  // create sub kernels
-  std::vector<kernel::LiteKernel *> group_convs;
-  for (int i = 0; i < conv_param->group_; ++i) {
-    std::vector<lite::Tensor *> new_inputs;
-    std::vector<lite::Tensor *> new_outputs;
-    auto new_conv_parameter = CreateNewConvParameter(conv_param);
-    if (new_conv_parameter == nullptr) {
-      FreeMemory(group_convs, new_inputs, new_outputs);
-      MS_LOG(ERROR) << "Get new conv parameter failed.";
-      return nullptr;
-    }
-    // create new input for each group
-    auto in_tensor = CreateInputTensor(inputs.front()->data_type(), in_shape, infer_flag);
-    if (in_tensor == nullptr) {
-      delete new_conv_parameter;
-      FreeMemory(group_convs, new_inputs, new_outputs);
-      MS_LOG(ERROR) << "create input tensor failed.";
-      return nullptr;
-    }
-    new_inputs.emplace_back(in_tensor);
-    // create new weight
-    auto filter_tensor = CreateConstTensorFp32(inputs.at(kWeightIndex), filter_shape, i);
-    if (filter_tensor == nullptr) {
-      delete new_conv_parameter;
-      FreeMemory(group_convs, new_inputs, new_outputs);
-      MS_LOG(ERROR) << "create filter tensor failed.";
-      return nullptr;
-    }
-    new_inputs.emplace_back(filter_tensor);
-    // if has bias, create new bias
-    if (inputs.size() == 3) {
-      auto bias_tensor = CreateConstTensorFp32(inputs.at(kBiasIndex), bias_shape, i);
-      if (bias_tensor == nullptr) {
-        delete new_conv_parameter;
-        FreeMemory(group_convs, new_inputs, new_outputs);
-        MS_LOG(ERROR) << "create bias_tensor failed.";
-        return nullptr;
-      }
-      new_inputs.emplace_back(bias_tensor);
-    }
-    // create new output tensor
-    for (size_t j = 0; j < outputs.size(); ++j) {
-      auto out_tensor = CreateOutputTensor(out_shape, outputs, infer_flag, j);
-      if (out_tensor == nullptr) {
-        delete new_conv_parameter;
-        FreeMemory(group_convs, new_inputs, new_outputs);
-        MS_LOG(ERROR) << "new out_tensor failed.";
-        return nullptr;
-      }
-      new_outputs.emplace_back(out_tensor);
-    }
-    group_convs.emplace_back(
-      CreateDelegateConv(new_inputs, new_outputs, reinterpret_cast<OpParameter *>(new_conv_parameter), ctx));
-  }
-  return new (std::nothrow)
-    GroupConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, group_convs, conv_param->group_);
-}
-kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                               const InnerContext *ctx, const kernel::KernelKey &desc) {
-  auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
-  kernel::LiteKernel *kernel = nullptr;
-  if (opParameter != nullptr && opParameter->infer_flag_) {
-#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
-    if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) {
-      kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
-    }
-#endif
-#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
-    if (kernel == nullptr && CheckConvDwUseIndirectBuffer(conv_param)) {
-      kernel = new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx);
-    }
-#endif
-    if (kernel == nullptr && conv_param->input_channel_ < 32) {
-      kernel = new (std::nothrow) kernel::ConvolutionDepthwiseSWCPUKernel(opParameter, inputs, outputs, ctx);
-    }
-  }
-  if (kernel == nullptr) {
-    kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
-  }
-  return kernel;
-}  // namespace mindspore::kernel
+/* creator func */
 kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                              const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                              const InnerContext *ctx, const kernel::KernelKey &desc) {
@@ -385,11 +168,11 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
   auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
   kernel::LiteKernel *kernel = nullptr;
   if (conv_param->group_ == 1) {
-    kernel = CreateDelegateConv(inputs, outputs, op_parameter, ctx);
+    kernel = new (std::nothrow) kernel::ConvolutionDelegateCPUKernel(op_parameter, inputs, outputs, ctx);
   } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) {
-    kernel = CpuConvDwFp32KernelCreator(inputs, outputs, op_parameter, ctx, desc);
+    kernel = DispatchConvDw(inputs, outputs, op_parameter, ctx);
   } else {
-    kernel = CpuGroupConvFp32KernelCreator(inputs, outputs, op_parameter, ctx);
+    kernel = DispatchGroupConv(inputs, outputs, op_parameter, ctx);
   }
   if (kernel == nullptr) {
@@ -18,6 +18,7 @@
 #include <vector>
 #include "src/lite_kernel.h"
+#include "src/runtime/kernel/arm/fp32/convolution_creator_manager.h"
 #include "nnacl/conv_parameter.h"
 #include "nnacl/op_base.h"
@@ -39,12 +40,31 @@ class ConvolutionDelegateCPUKernel : public LiteKernel {
   int Init() override;
   int ReSize() override;
   int Run() override { return conv_kernel_->Run(); }
+
+ protected:
   int GetWeightAndBias();
   int GetWeightData();
   int GetBiasData();
+  void SetInputOutputShapeInfo();
+  kernel::LiteKernel *CpuConvFp32KernelSelect();
+  // If the inferShape process cannot complete during Init, weight and bias are initialized at runtime
+  // via the Resize() API. However, the data of const tensors (weight and bias) no longer exists at the
+  // runtime stage, so copying the const tensor data is necessary. Otherwise, just pass the original raw
+  // data pointer through.
   static float *CopyData(lite::Tensor *tensor);
-  void FreeCopiedData();
+  void FreeCopiedData() {
+    if (origin_weight_ != nullptr && need_free_weight_) {
+      free(origin_weight_);
+      origin_weight_ = nullptr;
+    }
+    if (origin_bias_ != nullptr && need_free_bias_) {
+      free(origin_bias_);
+      origin_bias_ = nullptr;
+    }
+  }
+  // Train API
   int Eval() override {
     LiteKernel::Eval();
     return conv_kernel_->Eval();
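The comment above describes a borrow-or-copy policy: when shape inference finished at Init time the kernel can keep the const tensor's raw pointer, otherwise it must copy the data and free the copy later. A standalone sketch of that policy with a stub tensor type (not the lite::Tensor API):

```cpp
#include <cstdlib>
#include <cstring>

struct Tensor {  // stub: a raw buffer plus its byte size
  void *data;
  size_t size;
};

struct WeightHolder {
  float *origin_weight = nullptr;
  bool need_free_weight = false;

  // Borrow when the tensor data outlives Resize (shapes inferred at Init);
  // otherwise take a private copy and remember to free it.
  int Acquire(Tensor *weight, bool infer_shape_done) {
    if (infer_shape_done) {
      origin_weight = static_cast<float *>(weight->data);
      return 0;
    }
    origin_weight = static_cast<float *>(malloc(weight->size));
    if (origin_weight == nullptr) {
      return -1;
    }
    memcpy(origin_weight, weight->data, weight->size);
    need_free_weight = true;
    return 0;
  }

  // Mirrors the inlined FreeCopiedData above: only copies are freed.
  void FreeCopiedData() {
    if (origin_weight != nullptr && need_free_weight) {
      free(origin_weight);
      origin_weight = nullptr;
    }
  }
};

int main() {
  float buf[4] = {1.f, 2.f, 3.f, 4.f};
  Tensor weight{buf, sizeof(buf)};
  WeightHolder holder;
  holder.Acquire(&weight, /*infer_shape_done=*/false);  // takes a copy
  holder.FreeCopiedData();                              // frees only the copy
  return 0;
}
```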
@@ -59,29 +79,12 @@ class ConvolutionDelegateCPUKernel : public LiteKernel {
   }
 protected:
-  bool need_free_weight_ = false;
-  bool need_free_bias_ = false;
-  kernel::LiteKernel *conv_kernel_ = nullptr;
-  float *origin_weight_ = nullptr;
-  float *origin_bias_ = nullptr;
+  kernel::LiteKernel *conv_kernel_{nullptr};
+  float *origin_weight_{nullptr};
+  float *origin_bias_{nullptr};
+  bool need_free_weight_{false};
+  bool need_free_bias_{false};
 };
-void SetInputOutputShapeInfo(ConvParameter *conv_param, const lite::Tensor *input, const lite::Tensor *output,
-                             const InnerContext *ctx);
-void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs,
-                const std::vector<lite::Tensor *> &new_outputs);
-ConvParameter *CreateNewConvParameter(ConvParameter *parameter);
-lite::Tensor *CreateInputTensor(TypeId data_type, const std::vector<int> &in_shape, bool infered_flag);
-lite::Tensor *CreateOutputTensor(const std::vector<int> &out_shape, const std::vector<lite::Tensor *> &outputs,
-                                 bool infered_flag, int index);
-kernel::LiteKernel *CpuConvFp32KernelSelect(const std::vector<lite::Tensor *> &inputs,
-                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
-                                            const InnerContext *ctx, float *origin_weight, float *origin_bias);
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DELEGATE_FP32_H_
@@ -28,6 +28,14 @@ using mindspore::lite::RET_INFER_INVALID;
 using mindspore::lite::RET_OK;
 namespace mindspore::kernel {
+#ifdef ENABLE_AVX
+#define OC_BLOCK C16NUM
+#elif ENABLE_ARM32
+#define OC_BLOCK C4NUM
+#else
+#define OC_BLOCK C8NUM
+#endif
 int ConvolutionCPUKernel::InitWeightBias() {
   auto filter_tensor = in_tensors_.at(kWeightIndex);
   int in_channel = filter_tensor->Channel();
@@ -35,14 +43,7 @@ int ConvolutionCPUKernel::InitWeightBias() {
   conv_param_->input_channel_ = in_channel;
   conv_param_->output_channel_ = out_channel;
   int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
-#ifdef ENABLE_AVX
-  const int oc_block = C16NUM;
-#elif ENABLE_ARM32
-  const int oc_block = C4NUM;
-#else
-  const int oc_block = C8NUM;
-#endif
-  int oc_block_num = UP_ROUND(out_channel, oc_block);
+  int oc_block_num = UP_ROUND(out_channel, OC_BLOCK);
   int pack_weight_size = oc_block_num * in_channel * kernel_plane;
   packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
@@ -164,14 +165,7 @@ void ConvolutionCPUKernel::PackWeight() {
   int in_channel = filter_tensor->Channel();
   int out_channel = filter_tensor->Batch();
   int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
-#ifdef ENABLE_AVX
-  const int oc_block = C16NUM;
-#elif ENABLE_ARM32
-  const int oc_block = C4NUM;
-#else
-  const int oc_block = C8NUM;
-#endif
-  int oc_block_num = UP_ROUND(out_channel, oc_block);
+  int oc_block_num = UP_ROUND(out_channel, OC_BLOCK);
   int pack_weight_size = oc_block_num * in_channel * kernel_plane;
   auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c());
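The new `OC_BLOCK` macro deduplicates the per-architecture output-channel block size that was previously repeated in `InitWeightBias` and `PackWeight`. A standalone sketch of how the packed weight size is derived; `C4NUM`/`C8NUM`/`C16NUM` and `UP_ROUND` are reproduced locally as assumptions matching their nnacl definitions:

```cpp
#include <cstdio>

// Local stand-ins for nnacl's block-size constants and UP_ROUND macro.
#define C4NUM 4
#define C8NUM 8
#define C16NUM 16
#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))

#ifdef ENABLE_AVX
#define OC_BLOCK C16NUM  // AVX packs 16 output channels per block
#elif ENABLE_ARM32
#define OC_BLOCK C4NUM   // ARM32 packs 4
#else
#define OC_BLOCK C8NUM   // default: 8
#endif

int main() {
  int out_channel = 21, in_channel = 3, kernel_plane = 3 * 3;
  // Output channels are padded up to a multiple of the block size,
  // e.g. 21 -> 24 with the default 8-wide block.
  int oc_block_num = UP_ROUND(out_channel, OC_BLOCK);
  int pack_weight_size = oc_block_num * in_channel * kernel_plane;
  printf("padded oc=%d, packed weight elements=%d\n", oc_block_num, pack_weight_size);
  return 0;
}
```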
@@ -19,7 +19,7 @@
 #include "nnacl/int8/conv_int8.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
-#include "src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h"
+#include "src/runtime/kernel/arm/fp32/convolution_creator_manager.h"
 #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
 #include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h"
 #include "src/runtime/kernel/arm/int8/group_convolution_int8.h"
@@ -244,174 +244,17 @@ int ConvolutionInt8CPUKernel::Run() {
   return RET_OK;
 }
-lite::Tensor *CreateFilterTensorInt8(TypeId data_type, std::vector<int> filter_shape,
-                                     const std::vector<lite::Tensor *> &inputs, int copy_length, int index) {
-  MS_ASSERT(data_type == kNumberTypeInt8);
-  auto filter_tensor =
-    new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
-  if (filter_tensor == nullptr) {
-    MS_LOG(ERROR) << "new filter_tensor failed.";
-    return nullptr;
-  }
-  auto ret = filter_tensor->MallocData();
-  if (ret != RET_OK) {
-    delete filter_tensor;
-    MS_LOG(ERROR) << "filter_tensor malloc failed.";
-    return nullptr;
-  }
-  auto *origin_weight = reinterpret_cast<int8_t *>(inputs.at(kWeightIndex)->data_c());
-  memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(int8_t));
-  return filter_tensor;
-}
-lite::Tensor *CreateBiasTensorInt8(TypeId data_type, std::vector<int> bias_shape,
-                                   const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) {
-  MS_ASSERT(data_type == kNumberTypeInt32);
-  auto *origin_bias = inputs.at(kBiasIndex)->data_c();
-  auto bias_tensor =
-    new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
-  if (bias_tensor == nullptr) {
-    MS_LOG(ERROR) << "new bias_tensor failed.";
-    return nullptr;
-  }
-  auto ret = bias_tensor->MallocData();
-  if (ret != RET_OK) {
-    delete bias_tensor;
-    MS_LOG(ERROR) << "bias_tensor malloc failed.";
-    return nullptr;
-  }
-  auto bias_data = reinterpret_cast<int32_t *>(origin_bias);
-  memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(int32_t));
-  return bias_tensor;
-}
-kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &inputs,
-                                            const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
-                                            const InnerContext *ctx) {
-  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
-  kernel::LiteKernel *kernel = nullptr;
-  if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
-      conv_param->stride_w_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1) {
-#ifdef ENABLE_ARM64
-    if (mindspore::lite::IsSupportSDot()) {
-      kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx);
-    } else {
-      kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx);
-    }
-#else
-    kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx);
-#endif
-  } else if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
-    kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(op_parameter, inputs, outputs, ctx);
-  } else {
-    kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx);
-  }
-  return kernel;
-}
-void CopyTensorQuantParam(lite::Tensor *dst, lite::Tensor *src) {
-  for (size_t i = 0; i < src->quant_params().size(); i++) {
-    dst->AddQuantParam(src->quant_params().at(i));
-  }
-}
 kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                                   const InnerContext *ctx, int group) {
-  auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
-  std::vector<int> in_shape;
-  std::vector<int> out_shape;
-  int new_in_channel = inputs.at(kWeightIndex)->Channel();
-  int new_out_channel = 0;
-  if (group == 0) {
-    MS_LOG(ERROR) << "Divisor 'group' cannot be 0.";
+  lite::GroupConvCreator group_conv_creator(inputs, outputs, op_parameter, ctx, true);
+  group_conv_creator.SetShapeOfTensors();
+  if (group_conv_creator.CreatGroupConv() != RET_OK) {
+    MS_LOG(ERROR) << "Create group conv failed.";
     return nullptr;
-  } else {
-    new_out_channel = inputs.at(kWeightIndex)->Batch() / group;
   }
-  int batch = inputs.front()->Batch();
-  conv_param->input_batch_ = batch;
-  conv_param->output_batch_ = batch;
-  bool infered_flag = op_parameter != nullptr && op_parameter->infer_flag_;
-  if (infered_flag) {
-    int in_h = inputs.front()->Height();
-    int in_w = inputs.front()->Width();
-    conv_param->input_channel_ = new_in_channel;
-    conv_param->output_channel_ = new_out_channel;
-    in_shape = {batch, in_h, in_w, new_in_channel};
-    out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel};
-  }
-  std::vector<int> filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel};
-  std::vector<int> bias_shape = {new_out_channel};
-  // create sub kernels
-  std::vector<kernel::LiteKernel *> group_convs;
-  for (int i = 0; i < group; ++i) {
-    std::vector<lite::Tensor *> new_inputs;
-    std::vector<lite::Tensor *> new_outputs;
-    auto new_conv_parameter = CreateNewConvParameter(conv_param);
-    if (new_conv_parameter == nullptr) {
-      FreeMemory(group_convs, new_inputs, new_outputs);
-      MS_LOG(ERROR) << "Get new conv parameter failed.";
-      return nullptr;
-    }
-    // create new input for each group
-    auto input_data_type = inputs.front()->data_type();
-    MS_ASSERT(input_data_type == kNumberTypeInt8);
-    auto in_tensor = CreateInputTensor(input_data_type, in_shape, infered_flag);
-    if (in_tensor == nullptr) {
-      delete new_conv_parameter;
-      FreeMemory(group_convs, new_inputs, new_outputs);
-      MS_LOG(ERROR) << "create input tensor failed.";
-      return nullptr;
-    }
-    CopyTensorQuantParam(in_tensor, inputs[kInputIndex]);
-    new_inputs.emplace_back(in_tensor);
-    // create new weight
-    int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel;
-    auto filter_tensor =
-      CreateFilterTensorInt8(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i);
-    if (filter_tensor == nullptr) {
-      delete new_conv_parameter;
-      FreeMemory(group_convs, new_inputs, new_outputs);
-      MS_LOG(ERROR) << "create filter tensor failed.";
-      return nullptr;
-    }
-    CopyTensorQuantParam(filter_tensor, inputs[kWeightIndex]);
-    new_inputs.emplace_back(filter_tensor);
-    // if has bias, create new bias
-    if (inputs.size() == 3) {
-      auto bias_tensor =
-        CreateBiasTensorInt8(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i);
-      if (bias_tensor == nullptr) {
-        delete new_conv_parameter;
-        FreeMemory(group_convs, new_inputs, new_outputs);
-        MS_LOG(ERROR) << "create bias_tensor failed.";
-        return nullptr;
-      }
-      CopyTensorQuantParam(bias_tensor, inputs[kBiasIndex]);
-      new_inputs.emplace_back(bias_tensor);
-    }
-    // create new output tensor
-    for (size_t j = 0; j < outputs.size(); ++j) {
-      auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j);
-      if (out_tensor == nullptr) {
-        delete new_conv_parameter;
-        FreeMemory(group_convs, new_inputs, new_outputs);
-        MS_LOG(ERROR) << "new out_tensor failed.";
-        return nullptr;
-      }
-      CopyTensorQuantParam(out_tensor, outputs[j]);
-      new_outputs.emplace_back(out_tensor);
-    }
-    group_convs.emplace_back(
-      CpuConvInt8KernelSelect(new_inputs, new_outputs, reinterpret_cast<OpParameter *>(new_conv_parameter), ctx));
-  }
-  return new (std::nothrow) GroupConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, group_convs, group);
+  return new (std::nothrow)
+    GroupConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, group_conv_creator.get_group_conv(), group);
 }
 kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
@@ -29,7 +29,7 @@ class GroupConvolutionInt8CPUKernel : public GroupConvolutionCPUKernel {
   GroupConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                 std::vector<kernel::LiteKernel *> group_convs, const int group_num)
-      : GroupConvolutionCPUKernel(parameter, inputs, outputs, ctx, group_convs, group_num) {
+      : GroupConvolutionCPUKernel(parameter, inputs, outputs, ctx, std::move(group_convs), group_num) {
   }  // opParameter's in/out channel in this kernel has been split across groups; to recover the real
      // params, multiply in channel / out channel by the group num
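The added `std::move` hands the by-value sub-kernel vector to the base class instead of copying it. A minimal standalone illustration of the pattern (stub types, not the kernel classes):

```cpp
#include <utility>
#include <vector>

struct Base {
  explicit Base(std::vector<int> kernels) : kernels_(std::move(kernels)) {}
  std::vector<int> kernels_;
};

struct Derived : Base {
  // Without std::move the by-value parameter would be copied again here;
  // moving forwards its buffer, as in GroupConvolutionInt8CPUKernel.
  explicit Derived(std::vector<int> kernels) : Base(std::move(kernels)) {}
};

int main() {
  Derived d({1, 2, 3});
  return static_cast<int>(d.kernels_.size()) - 3;  // exits 0
}
```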