Merge pull request !3815 from xutianchun/quant_0731 (tags/v0.7.0-beta)
| @@ -98,7 +98,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| } | } | ||||
| node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); | node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); | ||||
| primitiveT_value->SetPrimitiveT(nullptr); | |||||
| std::vector<schema::TensorT *> outputs; | std::vector<schema::TensorT *> outputs; | ||||
| SetOpInputNode(cnode, metaGraphT.get(), node.get()); | SetOpInputNode(cnode, metaGraphT.get(), node.get()); | ||||
| SetOpOutputNode(outputs, metaGraphT.get(), node.get()); | SetOpOutputNode(outputs, metaGraphT.get(), node.get()); | ||||
| @@ -113,24 +112,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | auto input_quant_params = primitiveT_value->GetInputQuantParams(); | ||||
| if (input_quant_params.empty()) { | if (input_quant_params.empty()) { | ||||
| MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty"; | MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty"; | ||||
| continue; | |||||
| } else { | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_params[0]); | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | |||||
| } | } | ||||
| std::unique_ptr<schema::QuantParamT> input_quant_param = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_params[0]); | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | |||||
| // output | // output | ||||
| auto output_index = node->outputIndex[0]; | auto output_index = node->outputIndex[0]; | ||||
| auto tensor_output = metaGraphT->allTensors[output_index].get(); | auto tensor_output = metaGraphT->allTensors[output_index].get(); | ||||
| auto output_quant_params = primitiveT_value->GetOutputQuantParams(); | auto output_quant_params = primitiveT_value->GetOutputQuantParams(); | ||||
| if (output_quant_params.empty()) { | if (output_quant_params.empty()) { | ||||
| MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | ||||
| continue; | |||||
| } else { | |||||
| std::unique_ptr<schema::QuantParamT> output_quant_param = | |||||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | |||||
| } | } | ||||
| std::unique_ptr<schema::QuantParamT> output_quant_param = | |||||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | |||||
| // // TensorType | // // TensorType | ||||
| // valuePtr = primitive->GetAttr(kInputTensorDataType); | // valuePtr = primitive->GetAttr(kInputTensorDataType); | ||||
| // if (valuePtr != nullptr) { | // if (valuePtr != nullptr) { | ||||
| @@ -26,8 +26,8 @@ namespace mindspore::lite { | |||||
| class PrimitiveTValue : public Value { | class PrimitiveTValue : public Value { | ||||
| public: | public: | ||||
| explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {} | explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {} | ||||
| ~PrimitiveTValue() override { delete this->primitive; } | |||||
| // Not responsible for freeing primitive; whoever allocated the dynamic memory is responsible for freeing it. | |||||
| ~PrimitiveTValue() override = default; | |||||
| MS_DECLARE_PARENT(PrimitiveTValue, Value) | MS_DECLARE_PARENT(PrimitiveTValue, Value) | ||||
| @@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) { | |||||
| MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; | MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; | ||||
| return status; | return status; | ||||
| } | } | ||||
| if (this->quantType == QuantType_AwareTrainning) { | |||||
| if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) { | |||||
| status = QuantDataFormatTrans(graphNode); | status = QuantDataFormatTrans(graphNode); | ||||
| if (status != 0) { | if (status != 0) { | ||||
| MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | ||||
| @@ -147,7 +147,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| } else if (fmkType == converter::FmkType_TFLITE) { | } else if (fmkType == converter::FmkType_TFLITE) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| case QuantType_QUANT_NONE: | case QuantType_QUANT_NONE: | ||||
| case QuantType_AwareTrainning: { | |||||
| case QuantType_AwareTrainning: | |||||
| case QuantType_PostTraining: { | |||||
| if (opType == schema::PrimitiveType_Conv2D) { | if (opType == schema::PrimitiveType_Conv2D) { | ||||
| weightTensor->format = schema::Format_KHWC; | weightTensor->format = schema::Format_KHWC; | ||||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| @@ -292,13 +292,32 @@ STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data, | |||||
| } | } | ||||
| STATUS Calibrator::ComputeThreshold() { | STATUS Calibrator::ComputeThreshold() { | ||||
| for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { | |||||
| for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) { | |||||
| DivergInfo *info = iter->second.get(); | DivergInfo *info = iter->second.get(); | ||||
| info->ComputeThreshold(); | info->ComputeThreshold(); | ||||
| } | } | ||||
| for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) { | |||||
| // Node A's input may be node B's output; in that case there is no need to re-compute node A's input quant param, which is the same as node B's output quant param. | |||||
| for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { | |||||
| DivergInfo *info = iter->second.get(); | DivergInfo *info = iter->second.get(); | ||||
| info->ComputeThreshold(); | |||||
| auto cnode = info->cnode; | |||||
| bool already_computed = false; | |||||
| auto input = cnode->input(1); | |||||
| if (input->isa<mindspore::CNode>()) { | |||||
| auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input); | |||||
| for (const auto &output_diverg_info : output_diverg_info_) { | |||||
| auto output_diverg_cnode = output_diverg_info.second->cnode; | |||||
| if (output_diverg_cnode == input_cnode) { | |||||
| *info = *(output_diverg_info.second); | |||||
| info->cnode = cnode; | |||||
| already_computed = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!already_computed) { | |||||
| info->ComputeThreshold(); | |||||
| } | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||