From: @fuzhiye Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tongtags/v1.1.0
| @@ -221,7 +221,7 @@ void FreeMemoryFp16(const std::vector<kernel::LiteKernel *> &group_convs, const | |||||
| } | } | ||||
| } | } | ||||
| lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag) { | |||||
| lite::Tensor *CreateInputTensorFp16(TypeId data_type, std::vector<int> in_shape, bool infered_flag) { | |||||
| auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); | auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); | ||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "new in_tensor failed."; | MS_LOG(ERROR) << "new in_tensor failed."; | ||||
| @@ -238,8 +238,8 @@ lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, boo | |||||
| return in_tensor; | return in_tensor; | ||||
| } | } | ||||
| lite::Tensor *CreateFilterTensor(TypeId data_type, std::vector<int> filter_shape, | |||||
| const std::vector<lite::Tensor *> &inputs, int copy_length, int index) { | |||||
| lite::Tensor *CreateFilterTensorFp16(TypeId data_type, std::vector<int> filter_shape, | |||||
| const std::vector<lite::Tensor *> &inputs, int copy_length, int index) { | |||||
| auto filter_tensor = | auto filter_tensor = | ||||
| new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | ||||
| if (filter_tensor == nullptr) { | if (filter_tensor == nullptr) { | ||||
| @@ -263,8 +263,8 @@ lite::Tensor *CreateFilterTensor(TypeId data_type, std::vector<int> filter_shape | |||||
| return filter_tensor; | return filter_tensor; | ||||
| } | } | ||||
| lite::Tensor *CreateBiasTensor(TypeId data_type, std::vector<int> bias_shape, const std::vector<lite::Tensor *> &inputs, | |||||
| int new_out_channel, int index) { | |||||
| lite::Tensor *CreateBiasTensorFp16(TypeId data_type, std::vector<int> bias_shape, | |||||
| const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) { | |||||
| auto *origin_bias = inputs.at(kBiasIndex)->data_c(); | auto *origin_bias = inputs.at(kBiasIndex)->data_c(); | ||||
| auto bias_tensor = | auto bias_tensor = | ||||
| new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | ||||
| @@ -289,8 +289,8 @@ lite::Tensor *CreateBiasTensor(TypeId data_type, std::vector<int> bias_shape, co | |||||
| return bias_tensor; | return bias_tensor; | ||||
| } | } | ||||
| lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs, | |||||
| bool infered_flag, int index) { | |||||
| lite::Tensor *CreateOutputTensorFp16(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs, | |||||
| bool infered_flag, int index) { | |||||
| auto out_tensor = new (std::nothrow) lite::Tensor(); | auto out_tensor = new (std::nothrow) lite::Tensor(); | ||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "new tmp_out_tensor failed."; | MS_LOG(ERROR) << "new tmp_out_tensor failed."; | ||||
| @@ -356,7 +356,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // create new input for each group | // create new input for each group | ||||
| auto in_tensor = CreateInputTensor(mindspore::kNumberTypeFloat16, in_shape, infered_flag); | |||||
| auto in_tensor = CreateInputTensorFp16(mindspore::kNumberTypeFloat16, in_shape, infered_flag); | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp16(group_convs, new_inputs, new_outputs); | FreeMemoryFp16(group_convs, new_inputs, new_outputs); | ||||
| @@ -367,7 +367,8 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor | |||||
| // create new weight | // create new weight | ||||
| int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel; | int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel; | ||||
| auto filter_tensor = CreateFilterTensor(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); | |||||
| auto filter_tensor = | |||||
| CreateFilterTensorFp16(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); | |||||
| if (filter_tensor == nullptr) { | if (filter_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp16(group_convs, new_inputs, new_outputs); | FreeMemoryFp16(group_convs, new_inputs, new_outputs); | ||||
| @@ -378,7 +379,8 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor | |||||
| // if has bias, create new bias | // if has bias, create new bias | ||||
| if (has_bias) { | if (has_bias) { | ||||
| auto bias_tensor = CreateBiasTensor(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); | |||||
| auto bias_tensor = | |||||
| CreateBiasTensorFp16(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); | |||||
| if (bias_tensor == nullptr) { | if (bias_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp16(group_convs, new_inputs, new_outputs); | FreeMemoryFp16(group_convs, new_inputs, new_outputs); | ||||
| @@ -390,7 +392,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor | |||||
| // create new output tensors | // create new output tensors | ||||
| for (size_t j = 0; j < outputs.size(); ++j) { | for (size_t j = 0; j < outputs.size(); ++j) { | ||||
| auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j); | |||||
| auto out_tensor = CreateOutputTensorFp16(out_shape, outputs, infered_flag, j); | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp16(group_convs, new_inputs, new_outputs); | FreeMemoryFp16(group_convs, new_inputs, new_outputs); | ||||
| @@ -168,8 +168,8 @@ ConvParameter *CreateNewConvParameter(ConvParameter *parameter) { | |||||
| return conv_parameter; | return conv_parameter; | ||||
| } | } | ||||
| void FreeMemoryFp32(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs, | |||||
| const std::vector<lite::Tensor *> &new_outputs) { | |||||
| void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs, | |||||
| const std::vector<lite::Tensor *> &new_outputs) { | |||||
| for (auto sub_conv : group_convs) { | for (auto sub_conv : group_convs) { | ||||
| if (sub_conv != nullptr) { | if (sub_conv != nullptr) { | ||||
| delete sub_conv; | delete sub_conv; | ||||
| @@ -187,7 +187,7 @@ void FreeMemoryFp32(const std::vector<kernel::LiteKernel *> &group_convs, const | |||||
| } | } | ||||
| } | } | ||||
| lite::Tensor *CreateInputTensorFp32(TypeId data_type, std::vector<int> in_shape, bool infered_flag) { | |||||
| lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag) { | |||||
| auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); | auto in_tensor = new (std::nothrow) lite::Tensor(data_type, in_shape, Format_NHWC, lite::Tensor::Category::VAR); | ||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "new in_tensor failed."; | MS_LOG(ERROR) << "new in_tensor failed."; | ||||
| @@ -247,8 +247,8 @@ lite::Tensor *CreateBiasTensorFp32(TypeId data_type, std::vector<int> bias_shape | |||||
| return bias_tensor; | return bias_tensor; | ||||
| } | } | ||||
| lite::Tensor *CreateOutputTensorFp32(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs, | |||||
| bool infered_flag, int index) { | |||||
| lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs, | |||||
| bool infered_flag, int index) { | |||||
| auto out_tensor = new (std::nothrow) lite::Tensor(); | auto out_tensor = new (std::nothrow) lite::Tensor(); | ||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "new tmp_out_tensor failed."; | MS_LOG(ERROR) << "new tmp_out_tensor failed."; | ||||
| @@ -324,16 +324,16 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||||
| std::vector<lite::Tensor *> new_outputs; | std::vector<lite::Tensor *> new_outputs; | ||||
| auto new_conv_parameter = CreateNewConvParameter(conv_param); | auto new_conv_parameter = CreateNewConvParameter(conv_param); | ||||
| if (new_conv_parameter == nullptr) { | if (new_conv_parameter == nullptr) { | ||||
| FreeMemoryFp32(group_convs, new_inputs, new_outputs); | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "Get new conv parameter failed."; | MS_LOG(ERROR) << "Get new conv parameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // create new input for each group | // create new input for each group | ||||
| auto in_tensor = CreateInputTensorFp32(inputs.front()->data_type(), in_shape, infered_flag); | |||||
| auto in_tensor = CreateInputTensor(inputs.front()->data_type(), in_shape, infered_flag); | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp32(group_convs, new_inputs, new_outputs); | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "create input tensor failed."; | MS_LOG(ERROR) << "create input tensor failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -345,7 +345,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||||
| CreateFilterTensorFp32(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); | CreateFilterTensorFp32(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); | ||||
| if (filter_tensor == nullptr) { | if (filter_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp32(group_convs, new_inputs, new_outputs); | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "create filter tensor failed."; | MS_LOG(ERROR) << "create filter tensor failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -357,7 +357,7 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||||
| CreateBiasTensorFp32(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); | CreateBiasTensorFp32(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); | ||||
| if (bias_tensor == nullptr) { | if (bias_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp32(group_convs, new_inputs, new_outputs); | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "create bias_tensor failed."; | MS_LOG(ERROR) << "create bias_tensor failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -366,10 +366,10 @@ kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor | |||||
| // create new output tensor | // create new output tensor | ||||
| for (size_t j = 0; j < outputs.size(); ++j) { | for (size_t j = 0; j < outputs.size(); ++j) { | ||||
| auto out_tensor = CreateOutputTensorFp32(out_shape, outputs, infered_flag, j); | |||||
| auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j); | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| delete new_conv_parameter; | delete new_conv_parameter; | ||||
| FreeMemoryFp32(group_convs, new_inputs, new_outputs); | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "new out_tensor failed."; | MS_LOG(ERROR) << "new out_tensor failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -61,6 +61,16 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| float *packed_input_ = nullptr; | float *packed_input_ = nullptr; | ||||
| float *col_major_input_ = nullptr; | float *col_major_input_ = nullptr; | ||||
| }; | }; | ||||
| void FreeMemory(const std::vector<kernel::LiteKernel *> &group_convs, const std::vector<lite::Tensor *> &new_inputs, | |||||
| const std::vector<lite::Tensor *> &new_outputs); | |||||
| ConvParameter *CreateNewConvParameter(ConvParameter *parameter); | |||||
| lite::Tensor *CreateInputTensor(TypeId data_type, std::vector<int> in_shape, bool infered_flag); | |||||
| lite::Tensor *CreateOutputTensor(std::vector<int> out_shape, const std::vector<lite::Tensor *> &outputs, | |||||
| bool infered_flag, int index); | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_H_ | ||||
| @@ -28,6 +28,11 @@ using mindspore::schema::PrimitiveType_Conv2D; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int GroupConvolutionCPUKernel::Init() { | int GroupConvolutionCPUKernel::Init() { | ||||
| for (int i = 0; i < group_num_; ++i) { | for (int i = 0; i < group_num_; ++i) { | ||||
| auto sub_conv = group_convs_.at(i); | |||||
| if (sub_conv == nullptr) { | |||||
| MS_LOG(ERROR) << "sub con " << i << " is null."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto ret = group_convs_.at(i)->Init(); | auto ret = group_convs_.at(i)->Init(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Sub kernel init failed."; | MS_LOG(ERROR) << "Sub kernel init failed."; | ||||
| @@ -127,7 +132,7 @@ int GroupConvolutionCPUKernel::PreProcess() { | |||||
| auto ret = output->MallocData(); | auto ret = output->MallocData(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| FreeSubKernel(); | FreeSubKernel(); | ||||
| MS_LOG(ERROR) << "fp32 group conv out tensor malloc data failed."; | |||||
| MS_LOG(ERROR) << "group conv out tensor malloc data failed."; | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } | } | ||||
| @@ -41,15 +41,17 @@ class GroupConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int PreProcess() override; | int PreProcess() override; | ||||
| void SeparateInput(int group_id); | |||||
| void PostConcat(int group_id); | |||||
| virtual void SeparateInput(int group_id); | |||||
| virtual void PostConcat(int group_id); | |||||
| void FreeSubKernel(); | void FreeSubKernel(); | ||||
| private: | |||||
| protected: | |||||
| std::vector<kernel::LiteKernel *> group_convs_; | std::vector<kernel::LiteKernel *> group_convs_; | ||||
| const int group_num_; | |||||
| private: | |||||
| float *ori_in_data_ = nullptr; // do not free | float *ori_in_data_ = nullptr; // do not free | ||||
| float *ori_out_data_ = nullptr; // do not free | float *ori_out_data_ = nullptr; // do not free | ||||
| const int group_num_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -20,8 +20,10 @@ | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/kernel/arm/base/layout_transform.h" | #include "src/runtime/kernel/arm/base/layout_transform.h" | ||||
| #include "src/runtime/kernel/arm/fp32/convolution_fp32.h" | |||||
| #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h" | #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h" | ||||
| #include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" | #include "src/runtime/kernel/arm/int8/convolution_3x3_int8.h" | ||||
| #include "src/runtime/kernel/arm/int8/group_convolution_int8.h" | |||||
| #include "src/runtime/runtime_api.h" | #include "src/runtime/runtime_api.h" | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| #include "src/runtime/kernel/arm/int8/opt_op_handler.h" | #include "src/runtime/kernel/arm/int8/opt_op_handler.h" | ||||
| @@ -32,6 +34,7 @@ using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| using mindspore::lite::RET_OK; | using mindspore::lite::RET_OK; | ||||
| using mindspore::schema::PrimitiveType_Conv2D; | using mindspore::schema::PrimitiveType_Conv2D; | ||||
| using mindspore::schema::Format::Format_NHWC; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| void ConvolutionInt8CPUKernel::CheckSupportOptimize() { | void ConvolutionInt8CPUKernel::CheckSupportOptimize() { | ||||
| @@ -242,6 +245,166 @@ int ConvolutionInt8CPUKernel::Run() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| lite::Tensor *CreateFilterTensorInt8(TypeId data_type, std::vector<int> filter_shape, | |||||
| const std::vector<lite::Tensor *> &inputs, int copy_length, int index) { | |||||
| MS_ASSERT(data_type == kNumberTypeInt8); | |||||
| auto filter_tensor = | |||||
| new (std::nothrow) lite::Tensor(data_type, filter_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | |||||
| if (filter_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new filter_tensor failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = filter_tensor->MallocData(); | |||||
| if (ret != RET_OK) { | |||||
| delete filter_tensor; | |||||
| MS_LOG(ERROR) << "filter_tensor malloc failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *origin_weight = reinterpret_cast<int8_t *>(inputs.at(kWeightIndex)->data_c()); | |||||
| memcpy(filter_tensor->data_c(), origin_weight + index * copy_length, copy_length * sizeof(int8_t)); | |||||
| return filter_tensor; | |||||
| } | |||||
| lite::Tensor *CreateBiasTensorInt8(TypeId data_type, std::vector<int> bias_shape, | |||||
| const std::vector<lite::Tensor *> &inputs, int new_out_channel, int index) { | |||||
| MS_ASSERT(data_type == kNumberTypeInt32); | |||||
| auto *origin_bias = inputs.at(kBiasIndex)->data_c(); | |||||
| auto bias_tensor = | |||||
| new (std::nothrow) lite::Tensor(data_type, bias_shape, Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | |||||
| if (bias_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new bias_tensor failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = bias_tensor->MallocData(); | |||||
| if (ret != RET_OK) { | |||||
| delete bias_tensor; | |||||
| MS_LOG(ERROR) << "bias_tensor malloc failed."; | |||||
| return nullptr; | |||||
| } | |||||
| auto bias_data = reinterpret_cast<int32_t *>(origin_bias); | |||||
| memcpy(bias_tensor->data_c(), bias_data + index * new_out_channel, new_out_channel * sizeof(int32_t)); | |||||
| return bias_tensor; | |||||
| } | |||||
| kernel::LiteKernel *CpuConvInt8KernelSelect(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | |||||
| const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | |||||
| kernel::LiteKernel *kernel = nullptr; | |||||
| if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 && | |||||
| conv_param->stride_w_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1) { | |||||
| #ifdef ENABLE_ARM64 | |||||
| if (mindspore::lite::IsSupportSDot()) { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } else { | |||||
| kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| #else | |||||
| kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| #endif | |||||
| } else if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { | |||||
| kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } else { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | |||||
| const InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive, | |||||
| int group) { | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | |||||
| std::vector<int> in_shape; | |||||
| std::vector<int> out_shape; | |||||
| int new_in_channel = inputs.at(kWeightIndex)->Channel(); | |||||
| int new_out_channel = 0; | |||||
| if (group == 0) { | |||||
| MS_LOG(ERROR) << "Divisor 'group' cannot be 0."; | |||||
| return nullptr; | |||||
| } else { | |||||
| new_out_channel = inputs.at(kWeightIndex)->Batch() / group; | |||||
| } | |||||
| bool infered_flag = primitive != nullptr && primitive->infer_flag(); | |||||
| if (infered_flag) { | |||||
| int batch = inputs.front()->Batch(); | |||||
| int in_h = inputs.front()->Height(); | |||||
| int in_w = inputs.front()->Width(); | |||||
| conv_param->input_channel_ = new_in_channel; | |||||
| conv_param->output_channel_ = new_out_channel; | |||||
| in_shape = {batch, in_h, in_w, new_in_channel}; | |||||
| out_shape = {batch, conv_param->output_h_, conv_param->output_w_, new_out_channel}; | |||||
| } | |||||
| std::vector<int> filter_shape = {new_out_channel, conv_param->kernel_h_, conv_param->kernel_w_, new_in_channel}; | |||||
| std::vector<int> bias_shape = {new_out_channel}; | |||||
| // create sub kernels | |||||
| std::vector<kernel::LiteKernel *> group_convs; | |||||
| for (int i = 0; i < group; ++i) { | |||||
| std::vector<lite::Tensor *> new_inputs; | |||||
| std::vector<lite::Tensor *> new_outputs; | |||||
| auto new_conv_parameter = CreateNewConvParameter(conv_param); | |||||
| if (new_conv_parameter == nullptr) { | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "Get new conv parameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| // create new input for each group | |||||
| auto input_data_type = inputs.front()->data_type(); | |||||
| MS_ASSERT(input_data_type == kNumberTypeInt8); | |||||
| auto in_tensor = CreateInputTensor(input_data_type, in_shape, infered_flag); | |||||
| if (in_tensor == nullptr) { | |||||
| delete new_conv_parameter; | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "create input tensor failed."; | |||||
| return nullptr; | |||||
| } | |||||
| new_inputs.emplace_back(in_tensor); | |||||
| // create new weight | |||||
| int copy_length = conv_param->kernel_h_ * conv_param->kernel_w_ * new_in_channel * new_out_channel; | |||||
| auto filter_tensor = | |||||
| CreateFilterTensorInt8(inputs.at(kWeightIndex)->data_type(), filter_shape, inputs, copy_length, i); | |||||
| if (filter_tensor == nullptr) { | |||||
| delete new_conv_parameter; | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "create filter tensor failed."; | |||||
| return nullptr; | |||||
| } | |||||
| new_inputs.emplace_back(filter_tensor); | |||||
| // if has bias, create new bias | |||||
| if (inputs.size() == 3) { | |||||
| auto bias_tensor = | |||||
| CreateBiasTensorInt8(inputs.at(kBiasIndex)->data_type(), bias_shape, inputs, new_out_channel, i); | |||||
| if (bias_tensor == nullptr) { | |||||
| delete new_conv_parameter; | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "create bias_tensor failed."; | |||||
| return nullptr; | |||||
| } | |||||
| new_inputs.emplace_back(bias_tensor); | |||||
| } | |||||
| // create new output tensor | |||||
| for (size_t j = 0; j < outputs.size(); ++j) { | |||||
| auto out_tensor = CreateOutputTensor(out_shape, outputs, infered_flag, j); | |||||
| if (out_tensor == nullptr) { | |||||
| delete new_conv_parameter; | |||||
| FreeMemory(group_convs, new_inputs, new_outputs); | |||||
| MS_LOG(ERROR) << "new out_tensor failed."; | |||||
| return nullptr; | |||||
| } | |||||
| new_outputs.emplace_back(out_tensor); | |||||
| } | |||||
| group_convs.emplace_back(CpuConvInt8KernelSelect( | |||||
| new_inputs, new_outputs, reinterpret_cast<OpParameter *>(new_conv_parameter), ctx, primitive)); | |||||
| } | |||||
| return new (std::nothrow) | |||||
| GroupConvolutionInt8CPUKernel(op_parameter, inputs, outputs, ctx, primitive, group_convs, group); | |||||
| } | |||||
| kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs, | kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const InnerContext *ctx, const kernel::KernelKey &desc, | const InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| @@ -249,27 +412,12 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> & | |||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D); | ||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| int kernel_h = conv_param->kernel_h_; | |||||
| int kernel_w = conv_param->kernel_w_; | |||||
| int stride_h = conv_param->stride_h_; | |||||
| int stride_w = conv_param->stride_w_; | |||||
| int dilation_h = conv_param->dilation_h_; | |||||
| int dilation_w = conv_param->dilation_w_; | |||||
| kernel::LiteKernel *kernel; | |||||
| if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { | |||||
| #ifdef ENABLE_ARM64 | |||||
| if (mindspore::lite::IsSupportSDot()) { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } else { | |||||
| kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| #else | |||||
| kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| #endif | |||||
| } else if (kernel_h == 1 && kernel_w == 1) { | |||||
| kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| kernel::LiteKernel *kernel = nullptr; | |||||
| if (conv_param->group_ == 1) { | |||||
| kernel = CpuConvInt8KernelSelect(inputs, outputs, opParameter, ctx, primitive); | |||||
| } else { | } else { | ||||
| kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| MS_ASSERT(conv_param->group_ > 1); | |||||
| kernel = CpuGroupConvInt8KernelCreator(inputs, outputs, opParameter, ctx, primitive, conv_param->group_); | |||||
| } | } | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/int8/group_convolution_int8.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Conv2D; | |||||
| namespace mindspore::kernel { | |||||
| void GroupConvolutionInt8CPUKernel::SeparateInput(int group_id) { | |||||
| int in_plane = conv_param_->input_h_ * conv_param_->input_w_; | |||||
| int sub_in_channel = conv_param_->input_channel_; | |||||
| int ori_in_channel = sub_in_channel * group_num_; | |||||
| auto sub_in_data = reinterpret_cast<int8_t *>(group_convs_.at(group_id)->in_tensors().front()->data_c()); | |||||
| int8_t *src_ptr = ori_in_data_ + group_id * sub_in_channel; | |||||
| int8_t *dst_ptr = sub_in_data; | |||||
| for (int i = 0; i < in_plane; ++i) { | |||||
| memcpy(dst_ptr, src_ptr, sub_in_channel * sizeof(int8_t)); | |||||
| src_ptr += ori_in_channel; | |||||
| dst_ptr += sub_in_channel; | |||||
| } | |||||
| } | |||||
| void GroupConvolutionInt8CPUKernel::PostConcat(int group_id) { | |||||
| int out_plane = conv_param_->output_h_ * conv_param_->output_w_; | |||||
| int sub_out_channel = conv_param_->output_channel_; | |||||
| int ori_out_channel = sub_out_channel * group_num_; | |||||
| auto sub_out_data = reinterpret_cast<int8_t *>(group_convs_.at(group_id)->out_tensors().front()->data_c()); | |||||
| int8_t *src_ptr = sub_out_data; | |||||
| int8_t *dst_ptr = ori_out_data_ + group_id * sub_out_channel; | |||||
| for (int i = 0; i < out_plane; ++i) { | |||||
| memcpy(dst_ptr, src_ptr, sub_out_channel * sizeof(int8_t)); | |||||
| src_ptr += sub_out_channel; | |||||
| dst_ptr += ori_out_channel; | |||||
| } | |||||
| } | |||||
| int GroupConvolutionInt8CPUKernel::Run() { | |||||
| ori_in_data_ = reinterpret_cast<int8_t *>(in_tensors().front()->data_c()); | |||||
| ori_out_data_ = reinterpret_cast<int8_t *>(out_tensors().front()->data_c()); | |||||
| for (int i = 0; i < group_num_; ++i) { | |||||
| // first, separate group conv input into several parts. This step must be in runtime stage. | |||||
| SeparateInput(i); | |||||
| // sun kernels run | |||||
| auto ret = group_convs_.at(i)->Run(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "sub kernel " << i << " execute failed."; | |||||
| return ret; | |||||
| } | |||||
| // post process, concat all outputs of sub-kernels into one output | |||||
| PostConcat(i); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_ | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/op_base.h" | |||||
| #include "src/runtime/kernel/arm/fp32/group_convolution_fp32.h" | |||||
| namespace mindspore::kernel { | |||||
| class GroupConvolutionInt8CPUKernel : public GroupConvolutionCPUKernel { | |||||
| public: | |||||
| GroupConvolutionInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive, | |||||
| std::vector<kernel::LiteKernel *> group_convs, const int group_num) | |||||
| : GroupConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive, group_convs, group_num) { | |||||
| } // opParameter(in channel, out channel) in this kernel has been split to groups, if | |||||
| // you want to get real params, multiply in channel / out channel with group num | |||||
| ~GroupConvolutionInt8CPUKernel() override { GroupConvolutionCPUKernel::FreeSubKernel(); } | |||||
| int Run() override; | |||||
| void SeparateInput(int group_id) override; | |||||
| void PostConcat(int group_id) override; | |||||
| private: | |||||
| int8_t *ori_in_data_ = nullptr; // do not free | |||||
| int8_t *ori_out_data_ = nullptr; // do not free | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_GROUP_CONVOLUTION_INT8_H_ | |||||