| @@ -197,7 +197,7 @@ union PrimitiveType { | |||
| enum QuantType: int { | |||
| QUANT_NONE, | |||
| AwareTrainning, | |||
| AwareTraining, | |||
| WeightQuant, | |||
| PostTraining | |||
| } | |||
| @@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||
| // add quant param | |||
| node->quantType = primitiveT_value->GetQuantType(); | |||
| if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) { | |||
| if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) { | |||
| MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | |||
| // activation | |||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | |||
| @@ -202,14 +202,12 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||
| auto activate_index = node->inputIndex[i]; | |||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | |||
| if (tensor_input->quantParams.empty()) { | |||
| std::unique_ptr<schema::QuantParamT> input_quant_param = | |||
| std::make_unique<schema::QuantParamT>(input_quant_params[i]); | |||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale | |||
| << " zp: " << input_quant_param->zeroPoint; | |||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | |||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) { | |||
| tensor_input->dataType = kNumberTypeInt8; | |||
| for (auto input_quant_param : input_quant_params[i]) { | |||
| std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = | |||
| std::make_unique<schema::QuantParamT>(input_quant_param); | |||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale | |||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||
| } | |||
| } | |||
| } | |||
| @@ -221,15 +219,18 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||
| if (output_quant_params.empty()) { | |||
| MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | |||
| } else { | |||
| if (tensor_output->quantParams.empty()) { | |||
| std::unique_ptr<schema::QuantParamT> output_quant_param = | |||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||
| MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale | |||
| << " zp: " << output_quant_param->zeroPoint; | |||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | |||
| if (tensor_output->quantParams.empty()) { | |||
| for (auto output_quant_param : output_quant_params[0]) { | |||
| std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = | |||
| std::make_unique<schema::QuantParamT>(output_quant_param); | |||
| MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param_ptr->scale | |||
| << " zp: " << output_quant_param_ptr->zeroPoint; | |||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); | |||
| } | |||
| } | |||
| } | |||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||
| if (node->quantType != schema::QuantType_AwareTraining && | |||
| !(node_type == schema::PrimitiveType_QuantDTypeCast && | |||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | |||
| tensor_output->dataType = kNumberTypeInt8; | |||
| } | |||
| @@ -322,18 +323,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta | |||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||
| paramTensor->data.resize(paramValue->tensor_size()); | |||
| memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); | |||
| for (auto &ite : paramValue->quant_param()) { | |||
| auto quantPar = std::make_unique<schema::QuantParamT>(); | |||
| quantPar->scale = ite->scale; | |||
| quantPar->zeroPoint = ite->zeroPoint; | |||
| quantPar->min = ite->min; | |||
| quantPar->max = ite->max; | |||
| quantPar->narrowRange = ite->narrowRange; | |||
| quantPar->inited = ite->inited; | |||
| quantPar->numBits = ite->numBits; | |||
| paramTensor->quantParams.emplace_back(std::move(quantPar)); | |||
| paramTensor->dataType = paramValue->tensor_type(); | |||
| } | |||
| } | |||
| nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); | |||
| fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); | |||
| @@ -225,7 +225,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, | |||
| PopulaterConv2DSingleGroup(prim, primitive, group); | |||
| } | |||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | |||
| if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) { | |||
| if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { | |||
| std::vector<std::vector<schema::QuantParamT>> vecQuantParam; | |||
| PopulaterQuantParam(prim, &vecQuantParam); | |||
| primitiveTValuePtr->SetInputQuantParam(vecQuantParam); | |||
| @@ -89,13 +89,15 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| } | |||
| auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); | |||
| // add quant parameter | |||
| if (cNode->quantType == schema::QuantType_AwareTrainning) { | |||
| if (cNode->quantType == schema::QuantType_AwareTraining) { | |||
| primTValue->SetQuantType(cNode->quantType); | |||
| for (int index : cNode->inputIndex) { | |||
| primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); | |||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||
| primTValue->AddInputQuantParam(quant_params); | |||
| } | |||
| for (int index : cNode->outputIndex) { | |||
| primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); | |||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||
| primTValue->AddOutputQuantParam(quant_params); | |||
| } | |||
| } | |||
| cNode->primitive = nullptr; | |||
| @@ -49,17 +49,17 @@ class PrimitiveTValue : public Value { | |||
| void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) { | |||
| } | |||
| void AddInputQuantParam(schema::QuantParamT quant_param) { | |||
| void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { | |||
| this->input_quant_param_.emplace_back(quant_param); | |||
| } | |||
| std::vector<schema::QuantParamT> GetInputQuantParams() const { | |||
| std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { | |||
| return input_quant_param_; | |||
| } | |||
| void AddOutputQuantParam(schema::QuantParamT quant_param) { | |||
| void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { | |||
| this->output_quant_param_.emplace_back(quant_param); | |||
| } | |||
| std::vector<schema::QuantParamT> GetOutputQuantParams() const { | |||
| std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { | |||
| return output_quant_param_; | |||
| } | |||
| @@ -69,8 +69,8 @@ class PrimitiveTValue : public Value { | |||
| protected: | |||
| schema::PrimitiveT *primitive = nullptr; | |||
| std::vector<schema::QuantParamT> input_quant_param_; | |||
| std::vector<schema::QuantParamT> output_quant_param_; | |||
| std::vector<std::vector<schema::QuantParamT>> input_quant_param_; | |||
| std::vector<std::vector<schema::QuantParamT>> output_quant_param_; | |||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | |||
| }; | |||
| } // namespace mindspore::lite | |||
| @@ -130,7 +130,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { | |||
| auto type = flags->quantType; | |||
| switch (type) { | |||
| case mindspore::schema::QuantType_AwareTrainning: { | |||
| case mindspore::schema::QuantType_AwareTraining: { | |||
| // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); | |||
| break; | |||
| } | |||
| @@ -33,7 +33,7 @@ Flags::Flags() { | |||
| "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); | |||
| AddFlag(&Flags::inferenceType, "inferenceType", | |||
| "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", ""); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", ""); | |||
| AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT"); | |||
| AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); | |||
| AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); | |||
| @@ -98,8 +98,8 @@ int Flags::Init(int argc, const char **argv) { | |||
| std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | |||
| return 1; | |||
| } | |||
| if (this->quantTypeIn == "AwareTrainning") { | |||
| this->quantType = QuantType_AwareTrainning; | |||
| if (this->quantTypeIn == "AwareTraining") { | |||
| this->quantType = QuantType_AwareTraining; | |||
| } else if (this->quantTypeIn == "WeightQuant") { | |||
| this->quantType = QuantType_WeightQuant; | |||
| } else if (this->quantTypeIn == "PostTraining") { | |||
| @@ -107,7 +107,7 @@ int Flags::Init(int argc, const char **argv) { | |||
| } else if (this->quantTypeIn.empty()) { | |||
| this->quantType = QuantType_QUANT_NONE; | |||
| } else { | |||
| std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining"; | |||
| std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining"; | |||
| return 1; | |||
| } | |||
| @@ -27,7 +27,7 @@ namespace lite { | |||
| using mindspore::schema::QuantType; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::schema::QuantType_AwareTrainning; | |||
| using mindspore::schema::QuantType_AwareTraining; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| @@ -68,8 +68,8 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ | |||
| void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| auto type = flags->quantType; | |||
| switch (type) { | |||
| case QuantType::QuantType_AwareTrainning: { | |||
| MS_LOG(INFO) << "create AwareTrainningQuantizer!"; | |||
| case QuantType::QuantType_AwareTraining: { | |||
| MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | |||
| fbQuantizer = | |||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | |||
| break; | |||
| @@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| return status; | |||
| } | |||
| if (!(this->graphDefT->fmkType == converter::FmkType_TF && | |||
| this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) { | |||
| this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) { | |||
| status = mQuantizer->GenerateQuantParam(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenerateQuantParam failed"; | |||
| @@ -173,7 +173,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| formatTransOptimizer.AddPass(formatTransPass); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| // if (ctx.quantType == QuantType_AwareTrainning) { | |||
| // if (ctx.quantType == QuantType_AwareTraining) { | |||
| // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); | |||
| // } | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| @@ -193,7 +193,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| // insert quantNode and deQuantNode | |||
| if (ctx.quantType == QuantType_AwareTrainning) { | |||
| if (ctx.quantType == QuantType_AwareTraining) { | |||
| Optimizer quantNodeOptimizer; | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| @@ -136,7 +136,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert transNode before and after existNode | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) { | |||
| if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | |||
| continue; | |||
| } | |||
| auto &node = *iter; | |||
| @@ -208,7 +208,7 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte | |||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| transNode->primitive->value.value = quantDTypeCastParam; | |||
| transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; | |||
| transNode->quantType = QuantType_AwareTrainning; | |||
| transNode->quantType = QuantType_AwareTraining; | |||
| if (nodeType == kInt8ToFP32) { | |||
| quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; | |||
| quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; | |||
| @@ -103,7 +103,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| FormatTransNodeType beforeNodeType, afterNodeType; | |||
| if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc | |||
| // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use | |||
| // if (quantType == QuantType_AwareTraining) { // AwareTraining op use | |||
| // nhwc | |||
| // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only | |||
| // support nhwc | |||
| @@ -120,7 +120,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| // beforeNodeType = kNCHW2NHWC; | |||
| // afterNodeType = kNHWC2NCHW; | |||
| } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw | |||
| // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc | |||
| // if (quantType == QuantType_AwareTraining) { // AwareTraining op use nhwc | |||
| // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc | |||
| // continue; | |||
| // } | |||
| @@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) { | |||
| MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) { | |||
| if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) { | |||
| status = QuantDataFormatTrans(graphNode); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | |||
| @@ -96,7 +96,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_MS) { | |||
| switch (node->quantType) { | |||
| case QuantType_AwareTrainning: { | |||
| case QuantType_AwareTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_HWCK; | |||
| } else { | |||
| @@ -123,7 +123,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_TF) { | |||
| switch (node->quantType) { | |||
| case QuantType_AwareTrainning: { | |||
| case QuantType_AwareTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | |||
| weightTensor->format = schema::Format_HWCK; | |||
| } else { | |||
| @@ -148,7 +148,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| } else if (fmkType == converter::FmkType_TFLITE) { | |||
| switch (node->quantType) { | |||
| case QuantType_QUANT_NONE: | |||
| case QuantType_AwareTrainning: | |||
| case QuantType_AwareTraining: | |||
| case QuantType_PostTraining: { | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KHWC; | |||
| @@ -170,7 +170,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| return 0; | |||
| } else if (fmkType == converter::FmkType_ONNX) { | |||
| switch (node->quantType) { | |||
| case QuantType_AwareTrainning: { | |||
| case QuantType_AwareTraining: { | |||
| // sum up from current onnx quant models | |||
| if (opType == schema::PrimitiveType_Conv2D) { | |||
| weightTensor->format = schema::Format_KHWC; | |||
| @@ -312,7 +312,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const | |||
| } | |||
| } | |||
| if (findQuantParams == needQuantParams) { | |||
| dst_op->quantType = schema::QuantType_AwareTrainning; | |||
| dst_op->quantType = schema::QuantType_AwareTraining; | |||
| } else { | |||
| dst_op->quantType = schema::QuantType_QUANT_NONE; | |||
| } | |||
| @@ -324,7 +324,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||
| MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } else { | |||
| node->quantType = schema::QuantType_AwareTrainning; | |||
| node->quantType = schema::QuantType_AwareTraining; | |||
| } | |||
| } | |||
| } | |||
| @@ -337,7 +337,7 @@ STATUS AwareQuantizer::DoQuantize() { | |||
| if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||
| continue; | |||
| } | |||
| if (node->quantType != schema::QuantType_AwareTrainning) { | |||
| if (node->quantType != schema::QuantType_AwareTraining) { | |||
| continue; | |||
| } | |||
| STATUS status; | |||
| @@ -584,7 +584,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { | |||
| } | |||
| } | |||
| if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||
| node->quantType = schema::QuantType_AwareTrainning; | |||
| node->quantType = schema::QuantType_AwareTraining; | |||
| } else { | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } | |||
| @@ -509,7 +509,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M | |||
| quant_param.min = max_min->min; | |||
| quant_param.numBits = bit_num; | |||
| quant_param.narrowRange = false; | |||
| lite_primitive->AddInputQuantParam(quant_param); | |||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | |||
| lite_primitive->AddInputQuantParam(quant_params); | |||
| // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT)); | |||
| return RET_OK; | |||
| } | |||
| @@ -526,7 +527,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||
| quant_param.min = max_min->min; | |||
| quant_param.numBits = bit_num; | |||
| quant_param.narrowRange = false; | |||
| lite_primitive->AddOutputQuantParam(quant_param); | |||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | |||
| lite_primitive->AddOutputQuantParam(quant_params); | |||
| // p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT)); | |||
| return RET_OK; | |||
| } | |||
| @@ -569,7 +571,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input | |||
| auto quant_params = input->GetInputQuantParams(); | |||
| size_t sizeX = quant_params.size(); | |||
| for (size_t i = 0; i < sizeX; i++) { | |||
| input_scales.emplace_back(quant_params[i].scale); | |||
| input_scales.emplace_back(quant_params[i].front().scale); | |||
| } | |||
| size_t sizeY = weight_param->quant_param().size(); | |||
| if (sizeX != sizeY) { | |||
| @@ -31,7 +31,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector | |||
| auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); | |||
| primTValue->SetQuantType(schema::QuantType_PostTraining); | |||
| for (auto &quant_param : quant_params) { | |||
| primTValue->AddInputQuantParam(quant_param); | |||
| std::vector<schema::QuantParamT> quant_params_in = {quant_param}; | |||
| primTValue->AddInputQuantParam(quant_params_in); | |||
| } | |||
| return NewValueNode(primTValue); | |||
| } | |||
| @@ -53,7 +54,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| if (first) { | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | |||
| auto value_node = | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams()); | |||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); | |||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | |||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | |||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | |||
| @@ -84,11 +85,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||
| if (curnode_quant_type == schema::QuantType_PostTraining && | |||
| input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | |||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | |||
| primitiveT_value->GetInputQuantParams()); | |||
| primitiveT_value->GetInputQuantParams().front()); | |||
| } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | |||
| input_cnode_quant_type == schema::QuantType_PostTraining) { | |||
| value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | |||
| input_cnode_primitiveT_value->GetInputQuantParams()); | |||
| input_cnode_primitiveT_value->GetInputQuantParams().front()); | |||
| } | |||
| if (value_node == nullptr) { | |||
| MS_LOG(WARNING) << "value_node is null! " | |||