| @@ -99,8 +99,9 @@ enum PaddingMode : byte { | |||
| } | |||
| table Pad { | |||
| paddingmode: PaddingMode; | |||
| paddings: [int]; | |||
| paddingMode: PaddingMode; | |||
| constantValue: float; | |||
| } | |||
| table Maximum { | |||
| @@ -45,6 +45,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) { | |||
| return new lite::FullConnection(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_Power: | |||
| return new lite::Power(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_Pad: | |||
| return new lite::Pad(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_Range: | |||
| return new lite::Range(const_cast<schema::Primitive *>(primitive)); | |||
| case schema::PrimitiveType_Mul: | |||
| @@ -43,7 +43,7 @@ | |||
| #include "src/runtime/kernel/arm/opclib/fp32/local_response_norm.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/expandDims.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/arithmetic_self.h" | |||
| #include "src/runtime/kernel/arm/opclib/pad.h" | |||
| #include "src/runtime/kernel/arm/opclib/pad_parameter.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/fill.h" | |||
| #include "src/runtime/kernel/arm/opclib/transpose.h" | |||
| #include "src/runtime/kernel/arm/opclib/split.h" | |||
| @@ -300,6 +300,42 @@ ConvParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) { | |||
| return parameter; | |||
| } | |||
| ConvParameter *PopulateDeconvParameter(const lite::Primitive *primitive) { | |||
| ConvParameter *parameter = new ConvParameter(); | |||
| auto conv_primitive = primitive->Value()->value_as_DeConv2D(); | |||
| parameter->kernel_h_ = conv_primitive->kernelH(); | |||
| parameter->kernel_w_ = conv_primitive->kernelW(); | |||
| parameter->stride_h_ = conv_primitive->strideH(); | |||
| parameter->stride_w_ = conv_primitive->strideW(); | |||
| auto deconv_lite_primitive = (lite::DeConv2D *)primitive; | |||
| MS_ASSERT(nullptr != deconvdw_lite_primitive); | |||
| parameter->pad_u_ = deconv_lite_primitive->PadUp(); | |||
| parameter->pad_d_ = deconv_lite_primitive->PadDown(); | |||
| parameter->pad_l_ = deconv_lite_primitive->PadLeft(); | |||
| parameter->pad_r_ = deconv_lite_primitive->PadRight(); | |||
| parameter->pad_h_ = deconv_lite_primitive->PadUp(); | |||
| parameter->pad_w_ = deconv_lite_primitive->PadLeft(); | |||
| parameter->dilation_h_ = conv_primitive->dilateH(); | |||
| parameter->dilation_w_ = conv_primitive->dilateW(); | |||
| auto act_type = conv_primitive->activationType(); | |||
| switch (act_type) { | |||
| case schema::ActivationType_RELU: | |||
| parameter->is_relu_ = true; | |||
| parameter->is_relu6_ = false; | |||
| break; | |||
| case schema::ActivationType_RELU6: | |||
| parameter->is_relu_ = false; | |||
| parameter->is_relu6_ = true; | |||
| break; | |||
| default: | |||
| parameter->is_relu_ = false; | |||
| parameter->is_relu6_ = false; | |||
| break; | |||
| } | |||
| return parameter; | |||
| } | |||
| SoftmaxParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) { | |||
| auto softmax_primitive = primitive->Value()->value_as_SoftMax(); | |||
| SoftmaxParameter *parameter = new (std::nothrow) SoftmaxParameter(); | |||
| @@ -335,19 +371,31 @@ ReduceParameter *PopulateReduceParameter(const lite::Primitive *primitive) { | |||
| } | |||
| PadParameter *PopulatePadParameter(const lite::Primitive *primitive) { | |||
| PadParameter *parameter = new (std::nothrow) PadParameter(); | |||
| if (parameter == nullptr) { | |||
| PadParameter *pad_param = new (std::nothrow) PadParameter(); | |||
| if (pad_param == nullptr) { | |||
| MS_LOG(ERROR) << "new PadParameter failed."; | |||
| return nullptr; | |||
| } | |||
| auto param = primitive->Value()->value_as_Pad(); | |||
| auto size = param->paddings()->size(); | |||
| parameter->ori_size_ = size; | |||
| auto valid_size = size <= 8 ? size : 8; | |||
| for (size_t i = 0; i < valid_size; i++) { | |||
| parameter->paddings[i] = (*(param->paddings()))[i]; | |||
| auto pad_node = primitive->Value()->value_as_Pad(); | |||
| pad_param->pad_mode_ = pad_node->paddingMode(); | |||
| if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) { | |||
| pad_param->constant_value_ = pad_node->constantValue(); | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid padding mode: " << pad_param->pad_mode_; | |||
| return nullptr; | |||
| } | |||
| return parameter; | |||
| auto size = pad_node->paddings()->size(); | |||
| if (size > MAX_PAD_SIZE) { | |||
| MS_LOG(ERROR) << "Invalid padding size: " << size; | |||
| return nullptr; | |||
| } | |||
| for (size_t i = 0; i < size; i++) { | |||
| pad_param->paddings_[MAX_PAD_SIZE - size + i] = (*(pad_node->paddings()))[i]; | |||
| } | |||
| return pad_param; | |||
| } | |||
| ActivationParameter *PopulateActivationParameter(const lite::Primitive *primitive) { | |||
| @@ -891,7 +939,7 @@ FlattenParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { | |||
| MS_LOG(ERROR) << "new FlattenParameter fail!"; | |||
| return nullptr; | |||
| } | |||
| return parameter; | |||
| return parameter; | |||
| } | |||
| StridedSliceParameter *PopulateStridedSliceParam(const lite::Primitive *primitive) { | |||
| @@ -932,6 +980,8 @@ OpParameter *PopulateParameter(const lite::Primitive *primitive) { | |||
| return reinterpret_cast<OpParameter *>(PopulateConvDwParameter(primitive)); | |||
| case schema::PrimitiveType_DeDepthwiseConv2D: | |||
| return reinterpret_cast<OpParameter *>(PopulateDeconvDwParameter(primitive)); | |||
| case schema::PrimitiveType_DeConv2D: | |||
| return reinterpret_cast<OpParameter *>(PopulateDeconvParameter(primitive)); | |||
| case schema::PrimitiveType_FusedBatchNorm: | |||
| return reinterpret_cast<OpParameter *>(PopulateFusedBatchNorm(primitive)); | |||
| case schema::PrimitiveType_FullConnection: | |||
| @@ -49,41 +49,51 @@ namespace mindspore::kernel { | |||
| ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { | |||
| if (bias_data_ != nullptr) { | |||
| free(bias_data_); | |||
| bias_data_ = nullptr; | |||
| } | |||
| if (nhwc4_input_ != nullptr) { | |||
| free(nhwc4_input_); | |||
| nhwc4_input_ = nullptr; | |||
| } | |||
| } | |||
| void ConvolutionBaseCPUKernel::FreeQuantParam() { | |||
| if (quant_args_ != nullptr) { | |||
| ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_; | |||
| if (conv_quant_arg_ == nullptr) { | |||
| return; | |||
| } | |||
| if (conv_quant_arg_->real_multiplier_ != nullptr) { | |||
| free(conv_quant_arg_->real_multiplier_); | |||
| conv_quant_arg_->real_multiplier_ = nullptr; | |||
| } | |||
| if (conv_quant_arg_->left_shift_ != nullptr) { | |||
| free(conv_quant_arg_->left_shift_); | |||
| conv_quant_arg_->left_shift_ = nullptr; | |||
| } | |||
| if (conv_quant_arg_->right_shift_ != nullptr) { | |||
| free(conv_quant_arg_->right_shift_); | |||
| conv_quant_arg_->right_shift_ = nullptr; | |||
| } | |||
| if (conv_quant_arg_->quant_multiplier_ != nullptr) { | |||
| free(conv_quant_arg_->quant_multiplier_); | |||
| conv_quant_arg_->quant_multiplier_ = nullptr; | |||
| } | |||
| if (conv_quant_arg_->out_act_min_ != nullptr) { | |||
| free(conv_quant_arg_->out_act_min_); | |||
| conv_quant_arg_->out_act_min_ = nullptr; | |||
| } | |||
| if (conv_quant_arg_->out_act_max_ != nullptr) { | |||
| free(conv_quant_arg_->out_act_max_); | |||
| conv_quant_arg_->out_act_max_ = nullptr; | |||
| } | |||
| if (conv_quant_arg_->quant_args_ != nullptr) { | |||
| for (int i = 0; i < 3; ++i) { | |||
| if (*(quant_args_ + i) != nullptr) { | |||
| free(*(quant_args_ + i)); | |||
| if (*(conv_quant_arg_->quant_args_ + i) != nullptr) { | |||
| free(*(conv_quant_arg_->quant_args_ + i)); | |||
| } | |||
| } | |||
| } | |||
| if (conv_quant_arg_ != nullptr) { | |||
| if (conv_quant_arg_->real_multiplier_ != nullptr) { | |||
| free(conv_quant_arg_->real_multiplier_); | |||
| } | |||
| if (conv_quant_arg_->left_shift_ != nullptr) { | |||
| free(conv_quant_arg_->left_shift_); | |||
| } | |||
| if (conv_quant_arg_->right_shift_ != nullptr) { | |||
| free(conv_quant_arg_->right_shift_); | |||
| } | |||
| if (conv_quant_arg_->quant_multiplier_ != nullptr) { | |||
| free(conv_quant_arg_->quant_multiplier_); | |||
| } | |||
| if (conv_quant_arg_->out_act_min_ != nullptr) { | |||
| free(conv_quant_arg_->out_act_min_); | |||
| } | |||
| if (conv_quant_arg_->out_act_max_ != nullptr) { | |||
| free(conv_quant_arg_->out_act_max_); | |||
| } | |||
| free(conv_quant_arg_); | |||
| } | |||
| } | |||
| int ConvolutionBaseCPUKernel::Init() { | |||
| @@ -116,11 +126,19 @@ int ConvolutionBaseCPUKernel::CheckLayout(lite::tensor::Tensor *input_tensor) { | |||
| } | |||
| int ConvolutionBaseCPUKernel::SetQuantParam() { | |||
| conv_quant_arg_ = new ConvQuantArg(); | |||
| quant_args_ = reinterpret_cast<QuantArg **>(malloc(3 * sizeof(QuantArg *))); | |||
| ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_; | |||
| conv_quant_arg_->quant_args_ = reinterpret_cast<QuantArg **>(malloc(3 * sizeof(QuantArg *))); | |||
| if (conv_quant_arg_->quant_args_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc quant_args_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| // per-tensor init | |||
| for (int j = 0; j < 3; ++j) { | |||
| quant_args_[j] = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg))); | |||
| conv_quant_arg_->quant_args_[j] = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg))); | |||
| if (conv_quant_arg_->quant_args_[j] == nullptr) { | |||
| MS_LOG(ERROR) << "malloc quant_args_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| auto input_tensor = inputs_.at(kInputIndex); | |||
| auto weight_tensor = inputs_.at(kWeightIndex); | |||
| @@ -129,16 +147,15 @@ int ConvolutionBaseCPUKernel::SetQuantParam() { | |||
| auto weight_quant_arg = weight_tensor->GetQuantParams().front(); | |||
| auto output_quant_arg = output_tensor->GetQuantParams().front(); | |||
| // input | |||
| quant_args_[0][0].zp_ = input_quant_arg.zeroPoint; | |||
| quant_args_[0][0].scale_ = input_quant_arg.scale; | |||
| conv_quant_arg_->quant_args_[0][0].zp_ = input_quant_arg.zeroPoint; | |||
| conv_quant_arg_->quant_args_[0][0].scale_ = input_quant_arg.scale; | |||
| // weight | |||
| quant_args_[1][0].zp_ = weight_quant_arg.zeroPoint; | |||
| quant_args_[1][0].scale_ = weight_quant_arg.scale; | |||
| conv_quant_arg_->quant_args_[1][0].zp_ = weight_quant_arg.zeroPoint; | |||
| conv_quant_arg_->quant_args_[1][0].scale_ = weight_quant_arg.scale; | |||
| // output | |||
| quant_args_[2][0].zp_ = output_quant_arg.zeroPoint; | |||
| quant_args_[2][0].scale_ = output_quant_arg.scale; | |||
| conv_quant_arg_->quant_args_[2][0].zp_ = output_quant_arg.zeroPoint; | |||
| conv_quant_arg_->quant_args_[2][0].scale_ = output_quant_arg.scale; | |||
| conv_quant_arg_->quant_args_ = quant_args_; | |||
| conv_quant_arg_->real_multiplier_ = reinterpret_cast<double *>(malloc(sizeof(double))); | |||
| conv_quant_arg_->left_shift_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); | |||
| conv_quant_arg_->right_shift_ = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t))); | |||
| @@ -151,7 +168,6 @@ int ConvolutionBaseCPUKernel::SetQuantParam() { | |||
| QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[0], &conv_quant_arg_->left_shift_[0], | |||
| &conv_quant_arg_->right_shift_[0]); | |||
| conv_param_->conv_quant_arg_ = *conv_quant_arg_; | |||
| ComputeQuantOutRange(conv_param_); | |||
| return RET_OK; | |||
| } | |||
| @@ -57,15 +57,12 @@ class ConvolutionBaseCPUKernel : public LiteKernel { | |||
| int thread_count_; | |||
| int tile_num_; | |||
| void *bias_data_ = nullptr; | |||
| void *nhwc4_input_; | |||
| void *nhwc4_input_ = nullptr; | |||
| const Context *ctx_; | |||
| ConvParameter *conv_param_; | |||
| ConvQuantArg *conv_quant_arg_ = nullptr; | |||
| QuantArg **quant_args_ = nullptr; | |||
| LayoutConvertor convert_func_; | |||
| }; | |||
| void ComputeQuantOutRange(ConvParameter *conv_param); | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_ | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp32/pad.h" | |||
| #include "src/runtime/kernel/arm/int8/pad_int8.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_factory.h" | |||
| #include "include/errorcode.h" | |||
| #include "include/context.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Pad; | |||
| namespace mindspore::kernel { | |||
| kernel::LiteKernel *CpuPadInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc) { | |||
| auto *kernel = new (std::nothrow) PadInt8CPUKernel(opParameter, inputs, outputs, ctx); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new PadCPUKernel failed."; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc) { | |||
| auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new PadCPUKernel failed."; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| kernel::LiteKernel *CpuPadKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::Context *ctx, const kernel::KernelKey &desc) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Concat); | |||
| auto input_tensor = inputs.at(kInputIndex); | |||
| auto data_type = input_tensor->data_type(); | |||
| kernel::LiteKernel *kernel = nullptr; | |||
| switch (data_type) { | |||
| case kNumberTypeInt8: | |||
| kernel = CpuPadInt8KernelCreator(inputs, outputs, opParameter, ctx, desc); | |||
| break; | |||
| case kNumberTypeFloat32: | |||
| kernel = CpuPadFp32KernelCreator(inputs, outputs, opParameter, ctx, desc); | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| delete kernel; | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, PrimitiveType_Pad, CpuPadKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -39,10 +39,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() { | |||
| free(tmp_ptr_); | |||
| tmp_ptr_ = nullptr; | |||
| } | |||
| if (bias_ptr_ != nullptr) { | |||
| free(bias_ptr_); | |||
| bias_ptr_ = nullptr; | |||
| } | |||
| if (weight_ptr_ != nullptr) { | |||
| free(weight_ptr_); | |||
| weight_ptr_ = nullptr; | |||
| @@ -64,15 +60,15 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { | |||
| int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { | |||
| if (inputs_.size() == 3) { | |||
| bias_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->col_ * C4NUM * sizeof(float))); | |||
| if (bias_ptr_ == nullptr) { | |||
| bias_data_ = malloc(matmul_param_->col_ * C4NUM * sizeof(float)); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; | |||
| return RET_ERROR; | |||
| } | |||
| memset(bias_ptr_, 0, matmul_param_->col_ * C4NUM * sizeof(float)); | |||
| memcpy(bias_ptr_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); | |||
| memset(bias_data_, 0, matmul_param_->col_ * C4NUM * sizeof(float)); | |||
| memcpy(bias_data_, inputs_[2]->Data(), conv_param_->output_channel_ * sizeof(float)); | |||
| } else { | |||
| bias_ptr_ = nullptr; | |||
| bias_data_ = nullptr; | |||
| } | |||
| weight_ptr_ = reinterpret_cast<float *>( | |||
| @@ -109,15 +105,15 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc tmp_ptr_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| c4_output_ = reinterpret_cast<float *>(malloc(outputs_[0]->ElementsC4Num() / conv_param_->output_batch_ * | |||
| sizeof(float))); | |||
| c4_output_ = | |||
| reinterpret_cast<float *>(malloc(outputs_[0]->ElementsC4Num() / conv_param_->output_batch_ * sizeof(float))); | |||
| if (c4_output_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc c4_output_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| c4_input_ = reinterpret_cast<float *>(malloc(inputs_[0]->ElementsC4Num() / conv_param_->input_batch_ * | |||
| sizeof(float))); | |||
| c4_input_ = | |||
| reinterpret_cast<float *>(malloc(inputs_[0]->ElementsC4Num() / conv_param_->input_batch_ * sizeof(float))); | |||
| if (c4_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc c4_input_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| @@ -189,9 +185,12 @@ int Convolution1x1CPUKernel::DoPostFunc(int task_id) { | |||
| return RET_OK; | |||
| } | |||
| float *cur_bias = | |||
| (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + task_id * thread_oc_stride_; | |||
| PostConvFuncFp32(c4_output_ + matmul_param_->row_ * thread_oc_stride_ * task_id, | |||
| output_ptr_ + task_id * thread_oc_stride_, bias_ptr_ + task_id * thread_oc_stride_, cur_oc, | |||
| matmul_param_->row_, conv_param_->output_channel_, conv_param_->is_relu_, conv_param_->is_relu6_); | |||
| output_ptr_ + task_id * thread_oc_stride_, cur_bias, cur_oc, matmul_param_->row_, | |||
| conv_param_->output_channel_, conv_param_->is_relu_, conv_param_->is_relu6_); | |||
| return RET_OK; | |||
| } | |||
| @@ -228,4 +227,3 @@ int Convolution1x1CPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -56,7 +56,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { | |||
| int thread_hw_stride_ = 0; | |||
| int thread_oc4_count_ = 0; | |||
| int thread_oc_stride_ = 0; | |||
| float *bias_ptr_ = nullptr; | |||
| float *weight_ptr_ = nullptr; | |||
| float *tmp_ptr_ = nullptr; | |||
| float *c4_input_ = nullptr; | |||
| @@ -169,10 +169,12 @@ int DeConvolutionCPUKernel::DoPostFunc(int task_id) { | |||
| return RET_OK; | |||
| } | |||
| float *cur_bias = | |||
| (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_co_stride_ * task_id; | |||
| DeConvPostFp32(tmp_output_ + thread_co_stride_ * task_id * input_plane * kernel_plane, | |||
| c4_output_ + thread_co_stride_ * task_id * output_plane, output_ptr_ + thread_co_stride_ * task_id, | |||
| reinterpret_cast<float *>(bias_data_) + thread_co_stride_ * task_id, cur_oc, input_plane, kernel_plane, | |||
| output_plane, conv_param_); | |||
| cur_bias, cur_oc, input_plane, kernel_plane, output_plane, conv_param_); | |||
| return RET_OK; | |||
| } | |||
| @@ -224,4 +226,3 @@ int DeConvolutionCPUKernel::Run() { | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -37,7 +37,7 @@ constexpr int kInputRank = 4; | |||
| constexpr int kPaddingsSize = 8; | |||
| } // namespace | |||
| int PadCPUKernel::CheckInputsOutputsParams() { | |||
| int PadCPUKernel::Init() { | |||
| if (inputs_.size() != kInputNum || outputs_.size() != kOutputNum) { | |||
| MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << inputs_.size() << ", output size should be" | |||
| << kOutputNum << ", got " << outputs_.size(); | |||
| @@ -71,42 +71,6 @@ int PadCPUKernel::CheckInputsOutputsParams() { | |||
| return RET_OK; | |||
| } | |||
| int PadCPUKernel::MaybeConvertInputLayout() { | |||
| auto input = inputs_.at(0); | |||
| auto input_format = input->GetFormat(); | |||
| if (input_format != exec_format_) { | |||
| auto input_type = input->data_type(); | |||
| layout_convertor_ = LayoutTransform(input_type, input_format, exec_format_); | |||
| if (layout_convertor_ == nullptr) { | |||
| MS_LOG(ERROR) << "Pad lack layout convertor from " << input_format << " to " << exec_format_; | |||
| return RET_ERROR; | |||
| } | |||
| exec_input_data_ = reinterpret_cast<float *>(malloc(input->DataSize() * sizeof(float))); | |||
| if (exec_input_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Pad malloc failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PadCPUKernel::Init() { | |||
| auto ret = CheckInputsOutputsParams(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| ret = MaybeConvertInputLayout(); | |||
| if (ret != RET_OK) { | |||
| return ret; | |||
| } | |||
| auto output = outputs_.at(0); | |||
| output->SetFormat(exec_format_); | |||
| return RET_OK; | |||
| } | |||
| int PadImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) { | |||
| auto padKernel = reinterpret_cast<PadCPUKernel *>(cdata); | |||
| int error_code = padKernel->RunImpl(task_id); | |||
| @@ -125,11 +89,8 @@ int PadCPUKernel::RunImpl(int task_id) { | |||
| auto output_data = reinterpret_cast<float *>(output->Data()); | |||
| auto input_shape = input->shape().data(); | |||
| auto output_shape = output->shape().data(); | |||
| if (exec_input_data_ != nullptr) { | |||
| Pad(exec_input_data_, output_data, input_shape, output_shape, paddings_.data(), task_id, context_->threadNum); | |||
| } else { | |||
| Pad(input_data, output_data, input_shape, output_shape, paddings_.data(), task_id, context_->threadNum); | |||
| } | |||
| Pad(input_data, output_data, input_shape, output_shape, paddings_.data(), task_id, context_->threadNum); | |||
| return RET_OK; | |||
| } | |||
| @@ -142,15 +103,6 @@ int PadCPUKernel::Run() { | |||
| // todo parallel memset to save time | |||
| memset(output_data, 0, output_size * sizeof(float)); | |||
| auto input = inputs_.at(0); | |||
| if (exec_input_data_ != nullptr) { | |||
| if (layout_convertor_ == nullptr) { | |||
| return RET_NULL_PTR; | |||
| } | |||
| layout_convertor_(inputs_.at(0), exec_input_data_, input->Batch(), input->Height() * input->Width(), | |||
| input->Channel()); | |||
| } | |||
| int error_code = LiteBackendParallelLaunch(PadImpl, this, context_->threadNum); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "Pad run error, error_code[" << error_code << "]"; | |||
| @@ -158,30 +110,4 @@ int PadCPUKernel::Run() { | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuPadFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc) { | |||
| if (opParameter == nullptr) { | |||
| MS_LOG(ERROR) << "Pad opParameter nullptr"; | |||
| return nullptr; | |||
| } | |||
| MS_ASSERT(desc.type == PrimitiveType_Pad); | |||
| auto *kernel = new (std::nothrow) PadCPUKernel(opParameter, inputs, outputs, ctx); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new PadCPUKernel failed."; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, PrimitiveType_Pad, CpuPadFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/kernel/arm/opclib/pad.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/pad.h" | |||
| #include "src/runtime/kernel/arm/base/layout_transform.h" | |||
| namespace mindspore::kernel { | |||
| @@ -29,31 +29,18 @@ class PadCPUKernel : public LiteKernel { | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||
| : LiteKernel(parameter, inputs, outputs), context_(ctx) {} | |||
| ~PadCPUKernel() { | |||
| if (exec_input_data_ != nullptr) { | |||
| free(exec_input_data_); | |||
| exec_input_data_ = nullptr; | |||
| } | |||
| } | |||
| ~PadCPUKernel() {} | |||
| int Init() override; | |||
| int ReSize() override { return 0; }; | |||
| int Run() override; | |||
| int RunImpl(int task_id); | |||
| private: | |||
| int CheckInputsOutputsParams(); | |||
| int MaybeConvertInputLayout(); | |||
| private: | |||
| std::vector<int> paddings_; | |||
| size_t paddings_size_; | |||
| const lite::Context *context_; | |||
| schema::Format exec_format_ = schema::Format_NHWC; | |||
| LayoutConvertor layout_convertor_ = nullptr; | |||
| float *exec_input_data_ = nullptr; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PAD_H_ | |||
| @@ -0,0 +1,119 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/int8/pad_int8.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_MEMORY_FAILED; | |||
| using mindspore::lite::RET_OK; | |||
| namespace mindspore::kernel { | |||
| void PadInt8CPUKernel::FreeQuantParam() { | |||
| if (pad_param_->pad_quant_arg_.in_quant_args_ != nullptr) { | |||
| free(pad_param_->pad_quant_arg_.in_quant_args_); | |||
| pad_param_->pad_quant_arg_.in_quant_args_ = nullptr; | |||
| } | |||
| if (pad_param_->pad_quant_arg_.out_quanr_args_ != nullptr) { | |||
| free(pad_param_->pad_quant_arg_.out_quanr_args_); | |||
| pad_param_->pad_quant_arg_.out_quanr_args_ = nullptr; | |||
| } | |||
| } | |||
| int PadInt8CPUKernel::SetQuantParam() { | |||
| PadQuantArg *pad_quant_args = &pad_param_->pad_quant_arg_; | |||
| pad_quant_args->in_quant_args_ = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg))); | |||
| if (pad_quant_args->in_quant_args_ == nullptr) { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| pad_quant_args->out_quanr_args_ = reinterpret_cast<QuantArg *>(malloc(sizeof(QuantArg))); | |||
| if (pad_quant_args->out_quanr_args_ == nullptr) { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| pad_quant_args->constant_value_ = reinterpret_cast<int8_t *>(malloc(sizeof(int8_t))); | |||
| if (pad_quant_args->constant_value_ == nullptr) { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| auto *input_tensor = inputs_.at(kInputIndex); | |||
| auto *out_tensor = outputs_.at(kOutputIndex); | |||
| auto in_quant_arg = input_tensor->GetQuantParams(); | |||
| auto out_quant_arg = out_tensor->GetQuantParams(); | |||
| pad_quant_args->in_quant_args_->zp_ = in_quant_arg.front().zeroPoint; | |||
| pad_quant_args->in_quant_args_->scale_ = in_quant_arg.front().scale; | |||
| pad_quant_args->out_quanr_args_->zp_ = out_quant_arg.front().zeroPoint; | |||
| pad_quant_args->out_quanr_args_->scale_ = out_quant_arg.front().scale; | |||
| if (pad_quant_args->in_quant_args_->scale_ != pad_quant_args->out_quanr_args_->scale_ || | |||
| pad_quant_args->in_quant_args_->zp_ != pad_quant_args->out_quanr_args_->zp_) { | |||
| MS_LOG(ERROR) << "Pad int8 op : scale & zp of output and input must be equal."; | |||
| return RET_ERROR; | |||
| } | |||
| pad_quant_args->constant_value_[0] = QuantizeToInt8( | |||
| pad_param_->constant_value_, pad_quant_args->in_quant_args_->scale_, pad_quant_args->in_quant_args_->zp_); | |||
| return RET_OK; | |||
| } | |||
| int PadInt8CPUKernel::InitPadParam() { | |||
| auto in_dims = inputs_[0]->shape(); | |||
| auto out_dims = outputs_[0]->shape(); | |||
| int ndims = in_dims.size(); | |||
| int in[] = {1, 1, 1, 1}; | |||
| int out[] = {1, 1, 1, 1}; | |||
| for (int i = 0; i < ndims; i++) { | |||
| in[DEFAULT_PAD_NDIMS - ndims + i] = in_dims[i]; | |||
| out[DEFAULT_PAD_NDIMS - ndims + i] = out_dims[i]; | |||
| } | |||
| memcpy(in_dims_, in, DEFAULT_PAD_NDIMS * sizeof(int)); | |||
| memcpy(out_dims_, out, DEFAULT_PAD_NDIMS * sizeof(int)); | |||
| return RET_OK; | |||
| } | |||
| int PadInt8CPUKernel::ReSize() { | |||
| InitPadParam(); | |||
| return RET_OK; | |||
| } | |||
| int PadInt8CPUKernel::Init() { | |||
| int error_code = InitPadParam(); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "InitPadParam failed. errorcode: " << error_code; | |||
| return error_code; | |||
| } | |||
| error_code = SetQuantParam(); | |||
| if (error_code != RET_OK) { | |||
| MS_LOG(ERROR) << "SetQuantParam failed. errorcode: " << error_code; | |||
| return error_code; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int PadInt8CPUKernel::Run() { | |||
| int8_t *in_data = reinterpret_cast<int8_t *>(inputs_[0]->Data()); | |||
| int8_t *out_data = reinterpret_cast<int8_t *>(outputs_[0]->Data()); | |||
| memset(out_data, pad_param_->pad_quant_arg_.constant_value_[0], outputs_[0]->ElementsNum() * sizeof(int8_t)); | |||
| PadConstant4D(in_data, out_data, in_dims_, out_dims_, pad_param_->paddings_); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ | |||
| #include <vector> | |||
| #include "include/errorcode.h" | |||
| #include "src/lite_kernel.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/runtime/kernel/arm/opclib/pad_parameter.h" | |||
| #include "src/runtime/kernel/arm/opclib/int8/pad.h" | |||
| namespace mindspore::kernel { | |||
| class PadInt8CPUKernel : public LiteKernel { | |||
| public: | |||
| explicit PadInt8CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx) | |||
| : LiteKernel(parameter, inputs, outputs) { | |||
| opParameter->thread_num_ = ctx->threadNum; | |||
| pad_param_ = reinterpret_cast<PadParameter *>(opParameter); | |||
| } | |||
| ~PadInt8CPUKernel() override { FreeQuantParam(); }; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| int SetQuantParam(); | |||
| int InitPadParam(); | |||
| void FreeQuantParam(); | |||
| private: | |||
| PadParameter *pad_param_; | |||
| int in_dims_[DEFAULT_PAD_NDIMS]; | |||
| int out_dims_[DEFAULT_PAD_NDIMS]; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_PAD_INT8_H_ | |||
| @@ -13,9 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/opclib/pad.h" | |||
| #include <float.h> | |||
| #include "src/runtime/kernel/arm/opclib/offset_utils.h" | |||
| #include "src/runtime/kernel/arm/opclib/fp32/pad.h" | |||
| void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, | |||
| const int *paddings, const int tid, const int thread_num) { | |||
| @@ -34,4 +33,3 @@ void Pad(const float *input_data, float *output_data, const int *input_shape, co | |||
| } | |||
| } | |||
| } | |||
| @@ -13,23 +13,19 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_H_ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_PAD_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_PAD_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include <memory.h> | |||
| #include <float.h> | |||
| #include "src/runtime/kernel/arm/opclib/offset_utils.h" | |||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||
| struct PadParameter { | |||
| OpParameter op_parameter_; | |||
| int paddings[8]; | |||
| size_t ori_size_; | |||
| }; | |||
| #include "src/runtime/kernel/arm/opclib/pad_parameter.h" | |||
| void Pad(const float *input_data, float *output_data, const int *input_shape, const int *output_shape, | |||
| const int *paddings, const int tid, const int thread_num); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_H_ | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_PAD_H_ | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/opclib/int8/pad.h" | |||
| void PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, | |||
| const int32_t *paddings) { | |||
| int32_t copy_size = in_dims[3]; | |||
| for (int n = 0; n < in_dims[0]; n++) { | |||
| for (int h = 0; h < in_dims[1]; h++) { | |||
| for (int w = 0; w < in_dims[2]; w++) { | |||
| const int8_t *in = in_data + offset(in_dims, n, h, w, 0); | |||
| int8_t *out = out_data + offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]); | |||
| memcpy(out, in, copy_size * sizeof(int8_t)); | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PAD_INT8_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PAD_INT8_H_ | |||
| #include <string.h> | |||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||
| #include "src/runtime/kernel/arm/opclib/offset_utils.h" | |||
| #include "src/runtime/kernel/arm/opclib/pad_parameter.h" | |||
| void PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, | |||
| const int32_t *paddings); | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_PAD_INT8_H_ | |||
| @@ -25,6 +25,10 @@ inline int offset(const int *shape, const int dim0, const int dim1, const int di | |||
| return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3; | |||
| } | |||
| inline int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2) { | |||
| return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3]; | |||
| } | |||
| inline int offset4d(const int *shape, const int *dims) { return offset(shape, dims[0], dims[1], dims[2], dims[3]); } | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_OFFSET_UTILS_H_ | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_PARAMETER_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_PARAMETER_H_ | |||
| #include "src/runtime/kernel/arm/opclib/op_base.h" | |||
| #define MAX_PAD_SIZE 8 | |||
| #define DEFAULT_PAD_NDIMS 4 | |||
| struct PadParameter { | |||
| OpParameter op_parameter_; | |||
| PadQuantArg pad_quant_arg_; | |||
| int paddings_[MAX_PAD_SIZE] = {0}; | |||
| int pad_mode_; | |||
| float constant_value_; | |||
| }; | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_PAD_PARAMETER_H_ | |||
| @@ -58,6 +58,12 @@ struct FcQuantArg { | |||
| int32_t quant_multiplier; | |||
| }; | |||
| struct PadQuantArg { | |||
| QuantArg *in_quant_args_ = nullptr; | |||
| QuantArg *out_quanr_args_ = nullptr; | |||
| int8_t *constant_value_ = nullptr; | |||
| }; | |||
| struct MulQuantArg { | |||
| QuantArg in_quant_args_[2]; | |||
| QuantArg out_quant_arg_; | |||