From: @mengyuanli Reviewed-by: @zhang_xue_tong,@hangangqiang Signed-off-by: @zhang_xue_tongtags/v1.2.0-rc1
| @@ -23,7 +23,6 @@ typedef struct GatherParameter { | |||||
| // Primitive parameter | // Primitive parameter | ||||
| OpParameter op_parameter_; | OpParameter op_parameter_; | ||||
| int axis_; | int axis_; | ||||
| int quant_type_; | |||||
| } GatherParameter; | } GatherParameter; | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_ | #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_ | ||||
| @@ -25,8 +25,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * | |||||
| const TensorC *indices = inputs[1]; | const TensorC *indices = inputs[1]; | ||||
| TensorC *output = outputs[0]; | TensorC *output = outputs[0]; | ||||
| output->data_type_ = input->data_type_; | output->data_type_ = input->data_type_; | ||||
| GatherParameter *param = (GatherParameter *)parameter; | |||||
| if (param->quant_type_ == QuantType_WeightQuant) { | |||||
| if (parameter->quant_type_ == QuantType_WeightQuant) { | |||||
| output->data_type_ = kNumberTypeFloat32; | output->data_type_ = kNumberTypeFloat32; | ||||
| } | } | ||||
| output->format_ = input->format_; | output->format_ = input->format_; | ||||
| @@ -80,6 +80,7 @@ typedef struct OpParameter { | |||||
| bool infer_flag_; | bool infer_flag_; | ||||
| int type_; | int type_; | ||||
| int thread_num_; | int thread_num_; | ||||
| int quant_type_; | |||||
| } OpParameter; | } OpParameter; | ||||
| typedef struct QuantArg { | typedef struct QuantArg { | ||||
| @@ -49,7 +49,7 @@ int MindrtExecutor::Prepare(const std::vector<kernel::LiteKernel *> &kernels) { | |||||
| for (size_t j = 0; j < outTensorSize; j++) { | for (size_t j = 0; j < outTensorSize; j++) { | ||||
| auto data = | auto data = | ||||
| std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->in_tensors()[j], static_cast<int>(j)); | |||||
| std::make_shared<OpData<Tensor>>(opActors_[i]->GetAID(), kernels[i]->out_tensors()[j], static_cast<int>(j)); | |||||
| outputData_.emplace_back(data); | outputData_.emplace_back(data); | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,19 +14,20 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/where_parameter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| namespace { | namespace { | ||||
| OpParameter *PopulateWhereParameter(const void *prim) { | OpParameter *PopulateWhereParameter(const void *prim) { | ||||
| OpParameter *where_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter))); | |||||
| WhereParameter *where_parameter = reinterpret_cast<WhereParameter *>(malloc(sizeof(WhereParameter))); | |||||
| if (where_parameter == nullptr) { | if (where_parameter == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc Where parameter failed."; | MS_LOG(ERROR) << "malloc Where parameter failed."; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(where_parameter, 0, sizeof(OpParameter)); | memset(where_parameter, 0, sizeof(OpParameter)); | ||||
| auto primitive = static_cast<const schema::Primitive *>(prim); | auto primitive = static_cast<const schema::Primitive *>(prim); | ||||
| where_parameter->type_ = primitive->value_type(); | |||||
| where_parameter->op_parameter_.type_ = primitive->value_type(); | |||||
| return reinterpret_cast<OpParameter *>(where_parameter); | return reinterpret_cast<OpParameter *>(where_parameter); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -135,6 +135,8 @@ int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_i | |||||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << PrimitiveTypeName(GetPrimitiveType(primitive)); | MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << PrimitiveTypeName(GetPrimitiveType(primitive)); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| parameter->quant_type_ = node->quant_type_; | |||||
| op_parameters_[node->output_indices_.at(0)] = parameter; | op_parameters_[node->output_indices_.at(0)] = parameter; | ||||
| parameter->infer_flag_ = !(*infer_shape_interrupt); | parameter->infer_flag_ = !(*infer_shape_interrupt); | ||||
| auto ret = KernelInferShape(inputs, &outputs, parameter); | auto ret = KernelInferShape(inputs, &outputs, parameter); | ||||
| @@ -81,7 +81,7 @@ class QuantParamHolder : public Value { | |||||
| } | } | ||||
| void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | ||||
| if (index > this->input_quant_param_.size()) { | |||||
| if (index >= this->input_quant_param_.size()) { | |||||
| std::vector<schema::QuantParamT> place_quant(1); | std::vector<schema::QuantParamT> place_quant(1); | ||||
| this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(), | this->input_quant_param_.insert(this->input_quant_param_.end(), index + 1 - input_quant_param_.size(), | ||||
| place_quant); | place_quant); | ||||
| @@ -94,7 +94,7 @@ class QuantParamHolder : public Value { | |||||
| } | } | ||||
| void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { | void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { | ||||
| if (index > this->output_quant_param_.size()) { | |||||
| if (index >= this->output_quant_param_.size()) { | |||||
| std::vector<schema::QuantParamT> place_quant(1); | std::vector<schema::QuantParamT> place_quant(1); | ||||
| this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(), | this->output_quant_param_.insert(this->output_quant_param_.end(), index + 1 - output_quant_param_.size(), | ||||
| place_quant); | place_quant); | ||||