| @@ -197,7 +197,7 @@ union PrimitiveType { | |||||
| enum QuantType: int { | enum QuantType: int { | ||||
| QUANT_NONE, | QUANT_NONE, | ||||
| AwareTrainning, | |||||
| AwareTraining, | |||||
| WeightQuant, | WeightQuant, | ||||
| PostTraining | PostTraining | ||||
| } | } | ||||
| @@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| // add quant param | // add quant param | ||||
| node->quantType = primitiveT_value->GetQuantType(); | node->quantType = primitiveT_value->GetQuantType(); | ||||
| if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) { | |||||
| if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) { | |||||
| MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | ||||
| // activation | // activation | ||||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | auto input_quant_params = primitiveT_value->GetInputQuantParams(); | ||||
| @@ -202,14 +202,12 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| auto activate_index = node->inputIndex[i]; | auto activate_index = node->inputIndex[i]; | ||||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | auto tensor_input = metaGraphT->allTensors[activate_index].get(); | ||||
| if (tensor_input->quantParams.empty()) { | if (tensor_input->quantParams.empty()) { | ||||
| std::unique_ptr<schema::QuantParamT> input_quant_param = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_params[i]); | |||||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale | |||||
| << " zp: " << input_quant_param->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param)); | |||||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) { | |||||
| tensor_input->dataType = kNumberTypeInt8; | |||||
| for (auto input_quant_param : input_quant_params[i]) { | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_param); | |||||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale | |||||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -221,15 +219,18 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| if (output_quant_params.empty()) { | if (output_quant_params.empty()) { | ||||
| MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | ||||
| } else { | } else { | ||||
| if (tensor_output->quantParams.empty()) { | |||||
| std::unique_ptr<schema::QuantParamT> output_quant_param = | |||||
| std::make_unique<schema::QuantParamT>(output_quant_params[0]); | |||||
| MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale | |||||
| << " zp: " << output_quant_param->zeroPoint; | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param)); | |||||
| for (auto output_quant_param : output_quant_params[0]) { | |||||
| if (tensor_output->quantParams.empty()) { | |||||
| std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(output_quant_param); | |||||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << output_quant_param_ptr->scale | |||||
| << " zp: " << output_quant_param_ptr->zeroPoint; | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| if (!(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| if (node->quantType != schema::QuantType_AwareTraining && | |||||
| !(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | ||||
| tensor_output->dataType = kNumberTypeInt8; | tensor_output->dataType = kNumberTypeInt8; | ||||
| } | } | ||||
| @@ -322,18 +323,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | paramTensor->nodeType = schema::NodeType_ValueNode; | ||||
| paramTensor->data.resize(paramValue->tensor_size()); | paramTensor->data.resize(paramValue->tensor_size()); | ||||
| memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); | memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); | ||||
| for (auto &ite : paramValue->quant_param()) { | |||||
| auto quantPar = std::make_unique<schema::QuantParamT>(); | |||||
| quantPar->scale = ite->scale; | |||||
| quantPar->zeroPoint = ite->zeroPoint; | |||||
| quantPar->min = ite->min; | |||||
| quantPar->max = ite->max; | |||||
| quantPar->narrowRange = ite->narrowRange; | |||||
| quantPar->inited = ite->inited; | |||||
| quantPar->numBits = ite->numBits; | |||||
| paramTensor->quantParams.emplace_back(std::move(quantPar)); | |||||
| paramTensor->dataType = paramValue->tensor_type(); | |||||
| } | |||||
| } | } | ||||
| nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); | nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); | ||||
| fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); | fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); | ||||
| @@ -225,7 +225,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, | |||||
| PopulaterConv2DSingleGroup(prim, primitive, group); | PopulaterConv2DSingleGroup(prim, primitive, group); | ||||
| } | } | ||||
| primitiveTValuePtr->SetPrimitiveT(primitive.release()); | primitiveTValuePtr->SetPrimitiveT(primitive.release()); | ||||
| if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) { | |||||
| if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { | |||||
| std::vector<std::vector<schema::QuantParamT>> vecQuantParam; | std::vector<std::vector<schema::QuantParamT>> vecQuantParam; | ||||
| PopulaterQuantParam(prim, &vecQuantParam); | PopulaterQuantParam(prim, &vecQuantParam); | ||||
| primitiveTValuePtr->SetInputQuantParam(vecQuantParam); | primitiveTValuePtr->SetInputQuantParam(vecQuantParam); | ||||
| @@ -89,13 +89,15 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { | |||||
| } | } | ||||
| auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); | auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); | ||||
| // add quant parameter | // add quant parameter | ||||
| if (cNode->quantType == schema::QuantType_AwareTrainning) { | |||||
| if (cNode->quantType == schema::QuantType_AwareTraining) { | |||||
| primTValue->SetQuantType(cNode->quantType); | primTValue->SetQuantType(cNode->quantType); | ||||
| for (int index : cNode->inputIndex) { | for (int index : cNode->inputIndex) { | ||||
| primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); | |||||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||||
| primTValue->AddInputQuantParam(quant_params); | |||||
| } | } | ||||
| for (int index : cNode->outputIndex) { | for (int index : cNode->outputIndex) { | ||||
| primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); | |||||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||||
| primTValue->AddOutputQuantParam(quant_params); | |||||
| } | } | ||||
| } | } | ||||
| cNode->primitive = nullptr; | cNode->primitive = nullptr; | ||||
| @@ -49,17 +49,17 @@ class PrimitiveTValue : public Value { | |||||
| void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) { | void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) { | ||||
| } | } | ||||
| void AddInputQuantParam(schema::QuantParamT quant_param) { | |||||
| void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { | |||||
| this->input_quant_param_.emplace_back(quant_param); | this->input_quant_param_.emplace_back(quant_param); | ||||
| } | } | ||||
| std::vector<schema::QuantParamT> GetInputQuantParams() const { | |||||
| std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { | |||||
| return input_quant_param_; | return input_quant_param_; | ||||
| } | } | ||||
| void AddOutputQuantParam(schema::QuantParamT quant_param) { | |||||
| void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { | |||||
| this->output_quant_param_.emplace_back(quant_param); | this->output_quant_param_.emplace_back(quant_param); | ||||
| } | } | ||||
| std::vector<schema::QuantParamT> GetOutputQuantParams() const { | |||||
| std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { | |||||
| return output_quant_param_; | return output_quant_param_; | ||||
| } | } | ||||
| @@ -69,8 +69,8 @@ class PrimitiveTValue : public Value { | |||||
| protected: | protected: | ||||
| schema::PrimitiveT *primitive = nullptr; | schema::PrimitiveT *primitive = nullptr; | ||||
| std::vector<schema::QuantParamT> input_quant_param_; | |||||
| std::vector<schema::QuantParamT> output_quant_param_; | |||||
| std::vector<std::vector<schema::QuantParamT>> input_quant_param_; | |||||
| std::vector<std::vector<schema::QuantParamT>> output_quant_param_; | |||||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | ||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -130,7 +130,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||||
| void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { | void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { | ||||
| auto type = flags->quantType; | auto type = flags->quantType; | ||||
| switch (type) { | switch (type) { | ||||
| case mindspore::schema::QuantType_AwareTrainning: { | |||||
| case mindspore::schema::QuantType_AwareTraining: { | |||||
| // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); | // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -33,7 +33,7 @@ Flags::Flags() { | |||||
| "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); | "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); | ||||
| AddFlag(&Flags::inferenceType, "inferenceType", | AddFlag(&Flags::inferenceType, "inferenceType", | ||||
| "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); | "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); | ||||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", ""); | |||||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", ""); | |||||
| AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT"); | AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT"); | ||||
| AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); | AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); | ||||
| AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); | AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); | ||||
| @@ -98,8 +98,8 @@ int Flags::Init(int argc, const char **argv) { | |||||
| std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; | ||||
| return 1; | return 1; | ||||
| } | } | ||||
| if (this->quantTypeIn == "AwareTrainning") { | |||||
| this->quantType = QuantType_AwareTrainning; | |||||
| if (this->quantTypeIn == "AwareTraining") { | |||||
| this->quantType = QuantType_AwareTraining; | |||||
| } else if (this->quantTypeIn == "WeightQuant") { | } else if (this->quantTypeIn == "WeightQuant") { | ||||
| this->quantType = QuantType_WeightQuant; | this->quantType = QuantType_WeightQuant; | ||||
| } else if (this->quantTypeIn == "PostTraining") { | } else if (this->quantTypeIn == "PostTraining") { | ||||
| @@ -107,7 +107,7 @@ int Flags::Init(int argc, const char **argv) { | |||||
| } else if (this->quantTypeIn.empty()) { | } else if (this->quantTypeIn.empty()) { | ||||
| this->quantType = QuantType_QUANT_NONE; | this->quantType = QuantType_QUANT_NONE; | ||||
| } else { | } else { | ||||
| std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining"; | |||||
| std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining"; | |||||
| return 1; | return 1; | ||||
| } | } | ||||
| @@ -27,7 +27,7 @@ namespace lite { | |||||
| using mindspore::schema::QuantType; | using mindspore::schema::QuantType; | ||||
| using mindspore::schema::QuantType_PostTraining; | using mindspore::schema::QuantType_PostTraining; | ||||
| using mindspore::schema::QuantType_QUANT_NONE; | using mindspore::schema::QuantType_QUANT_NONE; | ||||
| using mindspore::schema::QuantType_AwareTrainning; | |||||
| using mindspore::schema::QuantType_AwareTraining; | |||||
| using mindspore::schema::QuantType_WeightQuant; | using mindspore::schema::QuantType_WeightQuant; | ||||
| using mindspore::schema::QuantType_PostTraining; | using mindspore::schema::QuantType_PostTraining; | ||||
| using mindspore::schema::QuantType_PostTraining; | using mindspore::schema::QuantType_PostTraining; | ||||
| @@ -68,8 +68,8 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ | |||||
| void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | ||||
| auto type = flags->quantType; | auto type = flags->quantType; | ||||
| switch (type) { | switch (type) { | ||||
| case QuantType::QuantType_AwareTrainning: { | |||||
| MS_LOG(INFO) << "create AwareTrainningQuantizer!"; | |||||
| case QuantType::QuantType_AwareTraining: { | |||||
| MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | |||||
| fbQuantizer = | fbQuantizer = | ||||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | ||||
| break; | break; | ||||
| @@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| return status; | return status; | ||||
| } | } | ||||
| if (!(this->graphDefT->fmkType == converter::FmkType_TF && | if (!(this->graphDefT->fmkType == converter::FmkType_TF && | ||||
| this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) { | |||||
| this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) { | |||||
| status = mQuantizer->GenerateQuantParam(); | status = mQuantizer->GenerateQuantParam(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "GenerateQuantParam failed"; | MS_LOG(ERROR) << "GenerateQuantParam failed"; | ||||
| @@ -173,7 +173,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| formatTransOptimizer.AddPass(formatTransPass); | formatTransOptimizer.AddPass(formatTransPass); | ||||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | ||||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | ||||
| // if (ctx.quantType == QuantType_AwareTrainning) { | |||||
| // if (ctx.quantType == QuantType_AwareTraining) { | |||||
| // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); | // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); | ||||
| // } | // } | ||||
| status = formatTransOptimizer.Run(graphDefT); | status = formatTransOptimizer.Run(graphDefT); | ||||
| @@ -193,7 +193,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| } | } | ||||
| // insert quantNode and deQuantNode | // insert quantNode and deQuantNode | ||||
| if (ctx.quantType == QuantType_AwareTrainning) { | |||||
| if (ctx.quantType == QuantType_AwareTraining) { | |||||
| Optimizer quantNodeOptimizer; | Optimizer quantNodeOptimizer; | ||||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | ||||
| if (dTypeTransPass == nullptr) { | if (dTypeTransPass == nullptr) { | ||||
| @@ -136,7 +136,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| // insert transNode before and after existNode | // insert transNode before and after existNode | ||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) { | |||||
| if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto &node = *iter; | auto &node = *iter; | ||||
| @@ -208,7 +208,7 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte | |||||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | transNode->primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| transNode->primitive->value.value = quantDTypeCastParam; | transNode->primitive->value.value = quantDTypeCastParam; | ||||
| transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; | transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; | ||||
| transNode->quantType = QuantType_AwareTrainning; | |||||
| transNode->quantType = QuantType_AwareTraining; | |||||
| if (nodeType == kInt8ToFP32) { | if (nodeType == kInt8ToFP32) { | ||||
| quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; | quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; | ||||
| quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; | quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; | ||||
| @@ -103,7 +103,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| FormatTransNodeType beforeNodeType, afterNodeType; | FormatTransNodeType beforeNodeType, afterNodeType; | ||||
| if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc | if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc | ||||
| // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use | |||||
| // if (quantType == QuantType_AwareTraining) { // AwareTraining op use | |||||
| // nhwc | // nhwc | ||||
| // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only | // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only | ||||
| // support nhwc | // support nhwc | ||||
| @@ -120,7 +120,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||||
| // beforeNodeType = kNCHW2NHWC; | // beforeNodeType = kNCHW2NHWC; | ||||
| // afterNodeType = kNHWC2NCHW; | // afterNodeType = kNHWC2NCHW; | ||||
| } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw | } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw | ||||
| // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc | |||||
| // if (quantType == QuantType_AwareTraining) { // AwareTraining op use nhwc | |||||
| // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc | // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc | ||||
| // continue; | // continue; | ||||
| // } | // } | ||||
| @@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) { | |||||
| MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; | MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; | ||||
| return status; | return status; | ||||
| } | } | ||||
| if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) { | |||||
| if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) { | |||||
| status = QuantDataFormatTrans(graphNode); | status = QuantDataFormatTrans(graphNode); | ||||
| if (status != 0) { | if (status != 0) { | ||||
| MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | ||||
| @@ -96,7 +96,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| return 0; | return 0; | ||||
| } else if (fmkType == converter::FmkType_MS) { | } else if (fmkType == converter::FmkType_MS) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| case QuantType_AwareTrainning: { | |||||
| case QuantType_AwareTraining: { | |||||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| weightTensor->format = schema::Format_HWCK; | weightTensor->format = schema::Format_HWCK; | ||||
| } else { | } else { | ||||
| @@ -123,7 +123,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| return 0; | return 0; | ||||
| } else if (fmkType == converter::FmkType_TF) { | } else if (fmkType == converter::FmkType_TF) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| case QuantType_AwareTrainning: { | |||||
| case QuantType_AwareTraining: { | |||||
| if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| weightTensor->format = schema::Format_HWCK; | weightTensor->format = schema::Format_HWCK; | ||||
| } else { | } else { | ||||
| @@ -148,7 +148,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| } else if (fmkType == converter::FmkType_TFLITE) { | } else if (fmkType == converter::FmkType_TFLITE) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| case QuantType_QUANT_NONE: | case QuantType_QUANT_NONE: | ||||
| case QuantType_AwareTrainning: | |||||
| case QuantType_AwareTraining: | |||||
| case QuantType_PostTraining: { | case QuantType_PostTraining: { | ||||
| if (opType == schema::PrimitiveType_Conv2D) { | if (opType == schema::PrimitiveType_Conv2D) { | ||||
| weightTensor->format = schema::Format_KHWC; | weightTensor->format = schema::Format_KHWC; | ||||
| @@ -170,7 +170,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| return 0; | return 0; | ||||
| } else if (fmkType == converter::FmkType_ONNX) { | } else if (fmkType == converter::FmkType_ONNX) { | ||||
| switch (node->quantType) { | switch (node->quantType) { | ||||
| case QuantType_AwareTrainning: { | |||||
| case QuantType_AwareTraining: { | |||||
| // sum up from current onnx quant models | // sum up from current onnx quant models | ||||
| if (opType == schema::PrimitiveType_Conv2D) { | if (opType == schema::PrimitiveType_Conv2D) { | ||||
| weightTensor->format = schema::Format_KHWC; | weightTensor->format = schema::Format_KHWC; | ||||
| @@ -312,7 +312,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const | |||||
| } | } | ||||
| } | } | ||||
| if (findQuantParams == needQuantParams) { | if (findQuantParams == needQuantParams) { | ||||
| dst_op->quantType = schema::QuantType_AwareTrainning; | |||||
| dst_op->quantType = schema::QuantType_AwareTraining; | |||||
| } else { | } else { | ||||
| dst_op->quantType = schema::QuantType_QUANT_NONE; | dst_op->quantType = schema::QuantType_QUANT_NONE; | ||||
| } | } | ||||
| @@ -324,7 +324,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||||
| MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | ||||
| node->quantType = schema::QuantType_QUANT_NONE; | node->quantType = schema::QuantType_QUANT_NONE; | ||||
| } else { | } else { | ||||
| node->quantType = schema::QuantType_AwareTrainning; | |||||
| node->quantType = schema::QuantType_AwareTraining; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -337,7 +337,7 @@ STATUS AwareQuantizer::DoQuantize() { | |||||
| if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (node->quantType != schema::QuantType_AwareTrainning) { | |||||
| if (node->quantType != schema::QuantType_AwareTraining) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| STATUS status; | STATUS status; | ||||
| @@ -584,7 +584,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { | |||||
| } | } | ||||
| } | } | ||||
| if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | ||||
| node->quantType = schema::QuantType_AwareTrainning; | |||||
| node->quantType = schema::QuantType_AwareTraining; | |||||
| } else { | } else { | ||||
| node->quantType = schema::QuantType_QUANT_NONE; | node->quantType = schema::QuantType_QUANT_NONE; | ||||
| } | } | ||||
| @@ -509,7 +509,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M | |||||
| quant_param.min = max_min->min; | quant_param.min = max_min->min; | ||||
| quant_param.numBits = bit_num; | quant_param.numBits = bit_num; | ||||
| quant_param.narrowRange = false; | quant_param.narrowRange = false; | ||||
| lite_primitive->AddInputQuantParam(quant_param); | |||||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | |||||
| lite_primitive->AddInputQuantParam(quant_params); | |||||
| // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT)); | // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT)); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -526,7 +527,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||||
| quant_param.min = max_min->min; | quant_param.min = max_min->min; | ||||
| quant_param.numBits = bit_num; | quant_param.numBits = bit_num; | ||||
| quant_param.narrowRange = false; | quant_param.narrowRange = false; | ||||
| lite_primitive->AddOutputQuantParam(quant_param); | |||||
| std::vector<schema::QuantParamT> quant_params = {quant_param}; | |||||
| lite_primitive->AddOutputQuantParam(quant_params); | |||||
| // p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT)); | // p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT)); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -569,7 +571,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input | |||||
| auto quant_params = input->GetInputQuantParams(); | auto quant_params = input->GetInputQuantParams(); | ||||
| size_t sizeX = quant_params.size(); | size_t sizeX = quant_params.size(); | ||||
| for (size_t i = 0; i < sizeX; i++) { | for (size_t i = 0; i < sizeX; i++) { | ||||
| input_scales.emplace_back(quant_params[i].scale); | |||||
| input_scales.emplace_back(quant_params[i].front().scale); | |||||
| } | } | ||||
| size_t sizeY = weight_param->quant_param().size(); | size_t sizeY = weight_param->quant_param().size(); | ||||
| if (sizeX != sizeY) { | if (sizeX != sizeY) { | ||||
| @@ -31,7 +31,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector | |||||
| auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); | auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); | ||||
| primTValue->SetQuantType(schema::QuantType_PostTraining); | primTValue->SetQuantType(schema::QuantType_PostTraining); | ||||
| for (auto &quant_param : quant_params) { | for (auto &quant_param : quant_params) { | ||||
| primTValue->AddInputQuantParam(quant_param); | |||||
| std::vector<schema::QuantParamT> quant_params_in = {quant_param}; | |||||
| primTValue->AddInputQuantParam(quant_params_in); | |||||
| } | } | ||||
| return NewValueNode(primTValue); | return NewValueNode(primTValue); | ||||
| } | } | ||||
| @@ -53,7 +54,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| if (first) { | if (first) { | ||||
| if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { | ||||
| auto value_node = | auto value_node = | ||||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams()); | |||||
| NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; | ||||
| auto quant_cast_cnode = graph->NewCNode(op_inputs); | auto quant_cast_cnode = graph->NewCNode(op_inputs); | ||||
| quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); | ||||
| @@ -84,11 +85,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { | |||||
| if (curnode_quant_type == schema::QuantType_PostTraining && | if (curnode_quant_type == schema::QuantType_PostTraining && | ||||
| input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | input_cnode_quant_type == schema::QuantType_QUANT_NONE) { | ||||
| value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, | ||||
| primitiveT_value->GetInputQuantParams()); | |||||
| primitiveT_value->GetInputQuantParams().front()); | |||||
| } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && | ||||
| input_cnode_quant_type == schema::QuantType_PostTraining) { | input_cnode_quant_type == schema::QuantType_PostTraining) { | ||||
| value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, | ||||
| input_cnode_primitiveT_value->GetInputQuantParams()); | |||||
| input_cnode_primitiveT_value->GetInputQuantParams().front()); | |||||
| } | } | ||||
| if (value_node == nullptr) { | if (value_node == nullptr) { | ||||
| MS_LOG(WARNING) << "value_node is null! " | MS_LOG(WARNING) << "value_node is null! " | ||||