| @@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||
| // add quant param | |||
| node->quantType = primitiveT_value->GetQuantType(); | |||
| if (node->quantType == schema::QuantType_PostTraining) { | |||
| if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) { | |||
| MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | |||
| // activation | |||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | |||
| @@ -60,6 +60,17 @@ void AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||
| param_value->set_tensor_addr(tensor_data); | |||
| param_value->set_tensor_size(size); | |||
| } | |||
| if (tensor->quantParams.size() > 0) { | |||
| std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>(); | |||
| quantParam->scale = tensor->quantParams[0]->scale; | |||
| quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; | |||
| quantParam->min = tensor->quantParams[0]->min; | |||
| quantParam->max = tensor->quantParams[0]->max; | |||
| quantParam->narrowRange = tensor->quantParams[0]->narrowRange; | |||
| quantParam->numBits = tensor->quantParams[0]->numBits; | |||
| quantParam->inited = tensor->quantParams[0]->inited; | |||
| param_value->set_quant_param(quantParam); | |||
| } | |||
| parameter->set_default_param(param_value); | |||
| AddNode(i, parameter); | |||
| } | |||
| @@ -77,6 +88,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { | |||
| flag = true; | |||
| } | |||
| auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); | |||
| // add quant parameter | |||
| if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) { | |||
| primTValue->SetQuantType(cNode->quantType); | |||
| for (int index : cNode->inputIndex) { | |||
| primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); | |||
| } | |||
| for (int index : cNode->outputIndex) { | |||
| primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); | |||
| } | |||
| } | |||
| cNode->primitive = nullptr; | |||
| auto value_node = NewValueNode(primTValue); | |||
| @@ -28,7 +28,7 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| OpDefCopyer GetSimpleOpCopyer() { | |||
| return [](std::unique_ptr<CNodeT> &inCNode) -> std::unique_ptr<CNodeT> { | |||
| return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> { | |||
| std::unique_ptr<CNodeT> newCNode(new CNodeT); | |||
| newCNode->name = inCNode->name; | |||
| @@ -421,9 +421,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||
| } | |||
| preTensor->refCount = 0; | |||
| preTensor->data.clear(); | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| auto toAddNode = opDefCopyer(toAddNodeIn); | |||
| auto toAddNode = opDefCopyer(toAddNodeIn.get()); | |||
| if (toAddNode == nullptr) { | |||
| MS_LOG(ERROR) << "copy toAddNodeIn failed"; | |||
| *errorCode = RET_NULL_PTR; | |||
| @@ -456,9 +460,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si | |||
| // MS_LOG(ERROR)("Copy TensorT failed"); | |||
| return graphT->nodes.end(); | |||
| } | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| auto toAddNode = opDefCopyer(toAddNodeIn); | |||
| auto toAddNode = opDefCopyer(toAddNodeIn.get()); | |||
| if (toAddNode == nullptr) { | |||
| // MS_LOG(ERROR)("copy toAddNodeIn failed"); | |||
| *errorCode = RET_NULL_PTR; | |||
| @@ -505,9 +513,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| *errorCode = RET_NULL_PTR; | |||
| return graphT->nodes.end(); | |||
| } | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| auto toAddNode = opDefCopyer(toAddNodeIn); | |||
| auto toAddNode = opDefCopyer(toAddNodeIn.get()); | |||
| if (toAddNode == nullptr) { | |||
| // MS_LOG(ERROR)("copy toAddNodeIn failed"); | |||
| *errorCode = RET_NULL_PTR; | |||
| @@ -540,9 +552,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| *errorCode = RET_NULL_PTR; | |||
| return graphT->nodes.end(); | |||
| } | |||
| if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { | |||
| postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; | |||
| toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; | |||
| } | |||
| graphT->allTensors.emplace_back(std::move(toAddTensor)); | |||
| size_t toAddTensorIdx = graphT->allTensors.size() - 1; | |||
| auto toAddNode = opDefCopyer(toAddNodeIn); | |||
| auto toAddNode = opDefCopyer(toAddNodeIn.get()); | |||
| if (toAddNode == nullptr) { | |||
| // MS_LOG(ERROR)("copy toAddNodeIn failed"); | |||
| *errorCode = RET_NULL_PTR; | |||
| @@ -36,7 +36,7 @@ enum InsertPlace { kBefore, kAfter }; | |||
| using NodeIter = std::vector<std::unique_ptr<schema::CNodeT>>::iterator; | |||
| using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT>(std::unique_ptr<schema::CNodeT> &)>; | |||
| using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT> (schema::CNodeT *)>; | |||
| OpDefCopyer GetSimpleOpCopyer(); | |||
| @@ -19,8 +19,29 @@ | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| auto &quantParams = tensor->quantParams; | |||
| if (!quantParams.empty()) { | |||
| return std::move(CopyQuantParamT(quantParams.front())); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam) { | |||
| MS_ASSERT(srcQuantParam != nullptr); | |||
| std::unique_ptr<schema::QuantParamT> dstQuantParam = std::make_unique<schema::QuantParamT>(); | |||
| dstQuantParam->inited = srcQuantParam->inited; | |||
| dstQuantParam->scale = srcQuantParam->scale; | |||
| dstQuantParam->zeroPoint = srcQuantParam->zeroPoint; | |||
| dstQuantParam->min = srcQuantParam->min; | |||
| dstQuantParam->max = srcQuantParam->max; | |||
| dstQuantParam->narrowRange = srcQuantParam->narrowRange; | |||
| dstQuantParam->numBits = srcQuantParam->numBits; | |||
| return std::move(dstQuantParam); | |||
| } | |||
| std::unique_ptr<QuantParamT> CopyQuantParamArrayT(const std::unique_ptr<QuantParamT> &srcQuantParamArray) { | |||
| MS_ASSERT(srcQuantParamArray != nullptr); | |||
| auto dstQuantParamArrayT = std::unique_ptr<QuantParamT>(new (std::nothrow) QuantParamT()); | |||
| @@ -164,6 +185,9 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &oldTenso | |||
| newTensor->refCount = oldTensor->refCount; | |||
| newTensor->nodeType = oldTensor->nodeType; | |||
| newTensor->data = oldTensor->data; | |||
| if (!oldTensor->quantParams.empty()) { | |||
| newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor))); | |||
| } | |||
| return std::move(newTensor); | |||
| } | |||
| @@ -186,6 +210,4 @@ size_t GetShapeSize(const std::vector<int32_t> &shape) { | |||
| } | |||
| return shapeSize; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| @@ -38,6 +38,9 @@ using schema::FusedBatchNormT; | |||
| using schema::Format_NCHW; | |||
| using schema::Format_NHWC; | |||
| using STATUS = int; | |||
| std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor); | |||
| size_t GetElementSize(const TensorT &tensor); | |||
| size_t GetElementSize(const TypeId &dataType); | |||
| @@ -50,6 +53,8 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &); | |||
| size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx); | |||
| std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam); | |||
| std::unique_ptr<schema::QuantParamT> \ | |||
| CopyQuantParamArrayT(const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray); | |||
| @@ -101,6 +101,7 @@ target_link_libraries(converter_lite PRIVATE | |||
| node_mid | |||
| graph_pass_mid | |||
| fusion_mid | |||
| quantizer_mid | |||
| protobuf | |||
| quantizer_mid | |||
| pthread | |||
| @@ -77,7 +77,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| MS_ASSERT(nullptr != modelParser); | |||
| const std::string modelFile = flag->modelFile; | |||
| const std::string weightFile = flag->weightFile; | |||
| auto meta_graph = modelParser->Parse(modelFile, weightFile); | |||
| auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; | |||
| return nullptr; | |||
| @@ -118,6 +118,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| // transform | |||
| transform->SetGraphDef(meta_graph); | |||
| transform->CreateQuantizer(flag); | |||
| auto status = transform->Transform(*flag); | |||
| if (status != 0) { | |||
| MS_LOG(ERROR) << "FBTransform model failed " << status; | |||
| @@ -125,6 +126,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { | |||
| } | |||
| return meta_graph; | |||
| } | |||
| void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { | |||
| auto type = flags->quantType; | |||
| switch (type) { | |||
| @@ -132,17 +134,18 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags * | |||
| // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); | |||
| break; | |||
| } | |||
| case mindspore::schema::QuantType_WeightQuant: { | |||
| MS_LOG(INFO) << "create WeightQuantizer!"; | |||
| mQuantizer.reset( | |||
| new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum)); | |||
| break; | |||
| } | |||
| case mindspore::schema::QuantType_PostTraining: { | |||
| MS_LOG(INFO) << "create PostTrainningQuantizer!"; | |||
| mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); | |||
| break; | |||
| } | |||
| // case mindspore::schema::QuantType_WeightQuant: { | |||
| // MS_LOG(INFO) << "create WeightQuantizer!"; | |||
| // mQuantizer.reset( | |||
| // new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, | |||
| // flags->bitNum)); | |||
| // break; | |||
| // } | |||
| // case mindspore::schema::QuantType_PostTraining: { | |||
| // MS_LOG(INFO) << "create PostTrainningQuantizer!"; | |||
| // mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); | |||
| // break; | |||
| // } | |||
| case mindspore::schema::QuantType_QUANT_NONE: | |||
| MS_LOG(INFO) << "Not do quantization for model!"; | |||
| break; | |||
| @@ -14,8 +14,12 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "tools/converter/converter_flags.h" | |||
| #include <regex> | |||
| #include <string> | |||
| #include "ir/dtype/type_id.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -70,9 +74,11 @@ int Flags::Init(int argc, const char **argv) { | |||
| return 1; | |||
| } | |||
| if (this->inputInferenceTypeIn == "FLOAT") { | |||
| this->inputInferenceType = 0; | |||
| this->inputInferenceType = TypeId::kNumberTypeFloat; | |||
| } else if (this->inputInferenceTypeIn == "UINT8") { | |||
| this->inputInferenceType = 1; | |||
| this->inputInferenceType = TypeId::kNumberTypeUInt8; | |||
| } else if (this->inputInferenceTypeIn == "INT8") { | |||
| this->inputInferenceType = TypeId::kNumberTypeInt8; | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str(); | |||
| return 1; | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include "tools/common/flag_parser.h" | |||
| #include "ir/dtype/type_id.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| @@ -66,7 +67,7 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||
| // used for parse aware training | |||
| std::string inputInferenceTypeIn; | |||
| // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT; | |||
| int inputInferenceType = 0; | |||
| TypeId inputInferenceType = TypeId::kNumberTypeFloat; | |||
| std::string stdDev; | |||
| std::string mean; | |||
| // used for post-training-weight | |||
| @@ -16,11 +16,13 @@ | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "schema/model_generated.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/common/op_utils.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" | |||
| @@ -28,7 +30,7 @@ | |||
| #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" | |||
| #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" | |||
| // #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" | |||
| // | |||
| // #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h" | |||
| @@ -52,18 +54,45 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" | |||
| #include "tools/converter/quantizer/aware_quantizer.h" | |||
| #include "tools/converter/converter.h" | |||
| using std::string; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace mindspore::lite { | |||
| GraphDefTransform::GraphDefTransform() = default; | |||
| GraphDefTransform::~GraphDefTransform() = default; | |||
| void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } | |||
| void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| auto type = flags->quantType; | |||
| switch (type) { | |||
| case QuantType::QuantType_AwareTrainning: { | |||
| MS_LOG(INFO) << "create AwareTrainningQuantizer!"; | |||
| fbQuantizer = | |||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | |||
| break; | |||
| } | |||
| // case QuantType::QuantType_WeightQuant: { | |||
| // MS_LOGI("create WeightQuantizer!"); | |||
| // mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize)); | |||
| // break; | |||
| // } | |||
| // case QuantType_PostTraining: { | |||
| // MS_LOGI("create PostTrainningQuantizer!"); | |||
| // mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile)); | |||
| // break; | |||
| // } | |||
| // case QuantType::QuantType_QUANT_NONE: | |||
| // MS_LOGD("Not do quantization for model!"); | |||
| // break; | |||
| default: | |||
| // MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str()); | |||
| break; | |||
| } | |||
| } | |||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| STATUS status; | |||
| // // constant folding | |||
| @@ -133,6 +162,53 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| } | |||
| { | |||
| Optimizer unusedOpRemoveOptimizer; | |||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // topological sorting | |||
| { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| // generate and infer quant parameters | |||
| { | |||
| if (mQuantizer != nullptr) { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| if (!(this->graphDefT->fmkType == converter::FmkType_TF && | |||
| this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) { | |||
| status = mQuantizer->GenerateQuantParam(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenerateQuantParam failed"; | |||
| return status; | |||
| } | |||
| status = mQuantizer->DetermineNodeQuantType(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DetermineNodeQuant failed"; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // format transform | |||
| if (ctx.formatTrans) { | |||
| Optimizer formatTransOptimizer; | |||
| @@ -156,13 +232,30 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| } | |||
| { | |||
| Optimizer unusedOpRemoveOptimizer; | |||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||
| unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); | |||
| status = unusedOpRemoveOptimizer.Run(graphDefT); | |||
| // do quantization | |||
| if (fbQuantizer != nullptr) { | |||
| status = fbQuantizer->DoQuantize(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||
| return status; | |||
| } | |||
| } | |||
| // insert quantNode and deQuantNode | |||
| if (ctx.quantType == QuantType_AwareTrainning) { | |||
| Optimizer quantNodeOptimizer; | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||
| return RET_ERROR; | |||
| } | |||
| dTypeTransPass->SetInputDataDType(ctx.inputInferenceType); | |||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| status = quantNodeOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; | |||
| MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| @@ -178,6 +271,4 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| } // namespace mindspore::lite | |||
| @@ -17,8 +17,9 @@ | |||
| #ifndef MS_GRAPHDEF_TRANSFORM_H | |||
| #define MS_GRAPHDEF_TRANSFORM_H | |||
| #include <memory> | |||
| #include "tools/converter/optimizer.h" | |||
| // #include "quantizer/quantizer.h" | |||
| #include "tools/converter/quantizer/quantizer.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/common/storage.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| @@ -42,7 +43,8 @@ class GraphDefTransform { | |||
| schema::MetaGraphT *graphDefT = nullptr; | |||
| Optimizer *optimizer = nullptr; | |||
| // std::unique_ptr<Quantizer> mQuantizer; | |||
| std::unique_ptr<quant::Quantizer> mQuantizer; | |||
| std::unique_ptr<quant::FbQuantizer> fbQuantizer; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -53,7 +53,7 @@ class MatMulBiasAddFusionPass : public FusionPass { | |||
| bool transB = false; | |||
| size_t id = 0; | |||
| OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr<CNodeT> &inOpDef) -> std::unique_ptr<CNodeT> { | |||
| OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> { | |||
| std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT); | |||
| if (newOpDef == nullptr) { | |||
| MS_LOG(ERROR) << "new OpDefT failed"; | |||
| @@ -1,5 +1,6 @@ | |||
| add_library(graph_pass_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc | |||
| @@ -0,0 +1,235 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" | |||
| #include <string> | |||
| #include "tools/common/converter_op_utils.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define kMinInputNum 1 | |||
| #define kOutputNum 1 | |||
| STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto status = DoModelInputDTypeTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoModelInputDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| status = DoModelOutputDTypeTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| status = DoNodeInoutDTypeTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // modify inputTensor first | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| for (auto graphInIdx : graphInIdxes) { | |||
| MS_ASSERT(graph->allTensors.size() > graphInIdx); | |||
| auto &graphInTensor = graph->allTensors.at(graphInIdx); | |||
| graphInTensor->dataType = TypeId::kNumberTypeUInt8; | |||
| } | |||
| if (this->inputDataDType == TypeId::kNumberTypeInt8) { | |||
| return RET_OK; | |||
| } | |||
| if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { | |||
| MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; | |||
| return RET_ERROR; | |||
| } | |||
| // insert fp2int8 node | |||
| for (auto graphInIdx : graphInIdxes) { | |||
| MS_ASSERT(graphInIdx < graph->allTensors.size()); | |||
| auto &tensor = graph->allTensors.at(graphInIdx); | |||
| if (tensor->dims.size() != kNHWCDimNumber) { | |||
| continue; | |||
| } | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto nodeName = node->name; | |||
| for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { | |||
| if (node->inputIndex.at(inputIndexIdx) == graphInIdx) { | |||
| STATUS status = RET_OK; | |||
| // insert dtype cast node between input tensor and input node | |||
| if (inputDataDType == TypeId::kNumberTypeFloat) { | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status); | |||
| } else { | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed"; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (inputDataDType == TypeId::kNumberTypeInt8) { | |||
| return RET_OK; | |||
| } | |||
| MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat); | |||
| auto &graphOutIdxes = graph->outputIndex; | |||
| for (auto graphOutIdx : graphOutIdxes) { | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto nodeName = node->name; | |||
| MS_ASSERT(node != nullptr); | |||
| for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) { | |||
| if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) { | |||
| // insert transNode | |||
| STATUS status = RET_OK; | |||
| if (inputDataDType == TypeId::kNumberTypeFloat) { | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status); | |||
| } else { | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; | |||
| return status; | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert transNode before and after existNode | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) { | |||
| continue; | |||
| } | |||
| auto &node = *iter; | |||
| if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { | |||
| continue; | |||
| } | |||
| bool needInsertPost = true; | |||
| if (GetCNodeTType(**iter) == PrimitiveType_Shape) { | |||
| needInsertPost = false; | |||
| } | |||
| auto nodeName = node->name; | |||
| if (node->inputIndex.size() < kMinInputNum) { | |||
| MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; | |||
| return RET_ERROR; | |||
| } | |||
| STATUS status; | |||
| // insert pre | |||
| for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { | |||
| MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | |||
| auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | |||
| continue; | |||
| } | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (needInsertPost) { | |||
| for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| (*iter)->quantType = QuantType_QUANT_NONE; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, | |||
| size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { | |||
| MS_ASSERT((*existNodeIter) != nullptr); | |||
| auto existNodeName = (*existNodeIter)->name; | |||
| std::string tileName; | |||
| if (place == kBefore) { | |||
| tileName = existNodeName + "_pre"; | |||
| } else { | |||
| tileName = existNodeName + "_post"; | |||
| } | |||
| auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT); | |||
| if (transNode == nullptr) { | |||
| MS_LOG(ERROR) << "new TransNode failed"; | |||
| *errorCode = RET_ERROR; | |||
| return graph->nodes.end(); | |||
| } | |||
| auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT; | |||
| if (quantDTypeCastParam == nullptr) { | |||
| MS_LOG(ERROR) << "new quantDTypeCastParam failed"; | |||
| *errorCode = RET_ERROR; | |||
| return graph->nodes.end(); | |||
| } | |||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| transNode->primitive->value.value = quantDTypeCastParam; | |||
| transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; | |||
| transNode->quantType = QuantType_AwareTrainning; | |||
| if (nodeType == kInt8ToFP32) { | |||
| quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; | |||
| quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; | |||
| transNode->name = "int8toft32_" + tileName + std::to_string(id++); | |||
| } else if (nodeType == kFP32ToInt8) { | |||
| quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32; | |||
| quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8; | |||
| transNode->name = "ft32toint8_" + tileName + std::to_string(id++); | |||
| } else if (nodeType == kUInt8ToInt8) { | |||
| quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8; | |||
| quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8; | |||
| transNode->name = "uint8toint8_" + tileName + std::to_string(id++); | |||
| } else if (nodeType == kInt8ToUInt8) { | |||
| quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; | |||
| quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8; | |||
| transNode->name = "int8touint8_" + tileName + std::to_string(id++); | |||
| } | |||
| transNode->primitive->value.value = quantDTypeCastParam; | |||
| return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, castOpCopyer); | |||
| } | |||
| void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H | |||
| #define MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }; | |||
| class DTypeTransPass : public GraphPass { | |||
| public: | |||
| DTypeTransPass() : id(0) {} | |||
| ~DTypeTransPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| void SetInputDataDType(TypeId dataType); | |||
| private: | |||
| STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); | |||
| STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); | |||
| STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); | |||
| NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||
| DTypeTransNodeType nodeType, STATUS *errorCode); | |||
| private: | |||
| size_t id; | |||
| TypeId inputDataDType = TypeId::kNumberTypeFloat; | |||
| OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> { | |||
| std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT); | |||
| if (newCNode == nullptr) { | |||
| MS_LOG(ERROR) << "new CNodeT failed"; | |||
| return nullptr; | |||
| } | |||
| newCNode->name = inCNode->name; | |||
| newCNode->quantType = inCNode->quantType; | |||
| newCNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| newCNode->primitive->value.type = inCNode->primitive->value.type; | |||
| auto oldQuantDTypeCastParam = inCNode->primitive->value.AsQuantDTypeCast(); | |||
| auto QuantDTypeCastParam = new (std::nothrow) QuantDTypeCastT; | |||
| if (QuantDTypeCastParam == nullptr) { | |||
| MS_LOG(ERROR) << "new QuantDTypeCast failed"; | |||
| return nullptr; | |||
| } | |||
| QuantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT; | |||
| QuantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT; | |||
| newCNode->primitive->value.value = QuantDTypeCastParam; | |||
| return std::move(newCNode); | |||
| }; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H | |||
| @@ -209,6 +209,9 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| return 0; | |||
| } | |||
| // inference needed filterFormat: | |||
| // conv deconv depth dedepth | |||
| // uint8 KHWC KHWC KHWC KHWC | |||
| int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| MS_ASSERT(graphNode != nullptr); | |||
| auto &subGraph = graphNode->subGraph; | |||
| @@ -227,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| auto &weightTensor = subGraph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT | |||
| STATUS status = RET_OK; | |||
| if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK | |||
| if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC | |||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format | |||
| @@ -236,58 +239,51 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| } else { | |||
| MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format | |||
| << weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | |||
| } | |||
| } else if (weightTensor->format == schema::Format_KHWC) { // from onnx | |||
| return RET_OK; | |||
| // if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| // status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK); | |||
| // } else { | |||
| // status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK); | |||
| // } | |||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | |||
| return 0; | |||
| } else { | |||
| } else if (weightTensor->format != schema::Format_KHWC) { | |||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | |||
| return -1; | |||
| } | |||
| if (status == 0) { | |||
| node->primitive->value.AsConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = schema::Format_HWCK; | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : " | |||
| << (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"), | |||
| node->name.c_str(); | |||
| MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : " | |||
| << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str(); | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK | |||
| } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC | |||
| if (weightTensor->format == schema::Format_CKHW) { // from caffe | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, | |||
| weightTensor->dataType; | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK); | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kCKHW2KHWC); | |||
| } else if (weightTensor->dataType == kNumberTypeUInt8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2KHWC); | |||
| } else { | |||
| MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, | |||
| weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCKHW2HWCK); | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); | |||
| } | |||
| } else if (weightTensor->format == schema::Format_HWCK) { // from tf | |||
| return 0; | |||
| } else if (weightTensor->format == schema::Format_CHWK) { // from onnx | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| if (weightTensor->dataType == kNumberTypeInt8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); | |||
| MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC"; | |||
| } else if (weightTensor->dataType == kNumberTypeUInt8) { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2KHWC); | |||
| } else { | |||
| MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format | |||
| << "datatype: " << weightTensor->dataType; | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | |||
| } | |||
| } else if (weightTensor->format == schema::Format_KCHW) { | |||
| if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { | |||
| status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK); | |||
| } else { | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK); | |||
| } | |||
| } else { | |||
| } else if (weightTensor->format != schema::Format_KHWC) { | |||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | |||
| return -1; | |||
| } | |||
| @@ -295,14 +291,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { | |||
| node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } else { | |||
| MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : " | |||
| << (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"), | |||
| node->name.c_str(); | |||
| MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "CKHW") | |||
| << "To KHWC failed, node : " << node->name.c_str(); | |||
| // todo(00445839): consider varible weight condition | |||
| } | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK | |||
| node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; | |||
| weightTensor->format = schema::Format_CKHW; | |||
| } else { // weight should be HWCK | |||
| node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC; | |||
| weightTensor->format = schema::Format_KHWC; | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -354,7 +349,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { | |||
| if (graphNode->subGraph->fmkType == converter::FmkType_MS) { | |||
| weightTensor->format = schema::Format_CKHW; | |||
| } | |||
| if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms | |||
| if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC); | |||
| } else if (weightTensor->format == schema::Format_KCHW) { | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | |||
| @@ -374,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC | |||
| if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms | |||
| status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC); | |||
| } else if (weightTensor->format == schema::Format_CHWK) { // from tf | |||
| status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC); | |||
| } else if (weightTensor->format == schema::Format_KHWC) { // from tf | |||
| status = RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; | |||
| return -1; | |||
| @@ -40,7 +40,8 @@ class ModelParser { | |||
| } | |||
| return Fb2Anf(Parse(modelFile, weightFile)); | |||
| } | |||
| virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0; | |||
| virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) = 0; | |||
| public: | |||
| static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { | |||
| @@ -31,7 +31,8 @@ CaffeModelParser::~CaffeModelParser() {} | |||
| const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"}; | |||
| schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { | |||
| schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| std::unique_ptr<schema::MetaGraphT> graph(new schema::MetaGraphT()); | |||
| if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { | |||
| @@ -91,7 +92,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const | |||
| // ConvertCaffeBatchNorm(graph.get()); | |||
| return graph.release(); | |||
| // return Fb2Anf(graph.release()); | |||
| // return Fb2Anf(graph.release()); | |||
| } | |||
| STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, | |||
| @@ -33,7 +33,8 @@ class CaffeModelParser : public ModelParser { | |||
| virtual ~CaffeModelParser(); | |||
| MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; | |||
| MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT); | |||
| @@ -37,7 +37,8 @@ class OnnxModelParser : public ModelParser { | |||
| public: | |||
| OnnxModelParser(); | |||
| virtual ~OnnxModelParser(); | |||
| MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; | |||
| MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type); | |||
| @@ -20,7 +20,6 @@ | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/storage.h" | |||
| #include "flatbuffers/flatbuffers.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/common/file_utils.h" | |||
| namespace mindspore { | |||
| @@ -60,42 +59,64 @@ STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema: | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor, | |||
| schema::TensorT *tensor) { | |||
| std::unique_ptr<schema::QuantParamT> quant_param(new QuantParamT()); | |||
| if (!tflite_tensor->quantization->scale.empty()) { | |||
| quant_param->scale = tflite_tensor->quantization->scale[0]; | |||
| } | |||
| STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op) { | |||
| auto dst_op = tfliteOpMap.at(tflite_op.get()); | |||
| if (!tflite_tensor->quantization->zero_point.empty()) { | |||
| quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; | |||
| } | |||
| std::vector<uint32_t> quant_params_index; | |||
| quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end()); | |||
| quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end()); | |||
| for (const auto &index : quant_params_index) { | |||
| const auto &tflite_tensor = tflite_subgraph->tensors[index]; | |||
| if (tflite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor with id = " << index <<" is null"; | |||
| return RET_ERROR; | |||
| } | |||
| // change quant param min to 0 to fit ms-lite ops | |||
| if (tensor->dataType == TypeId::kNumberTypeInt8) { | |||
| quant_param->zeroPoint = quant_param->zeroPoint - 128; | |||
| } | |||
| if (!tflite_tensor->quantization->min.empty()) { | |||
| quant_param->min = tflite_tensor->quantization->min[0]; | |||
| } | |||
| if (!tflite_tensor->quantization->max.empty()) { | |||
| quant_param->max = tflite_tensor->quantization->max[0]; | |||
| } | |||
| quant_param->inited = true; | |||
| tensor->quantParams.clear(); | |||
| tensor->quantParams.emplace_back(std::move(quant_param)); | |||
| } | |||
| STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| schema::CNodeT *op, TensorCache *tensor_cache) { | |||
| MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size()); | |||
| for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) { | |||
| const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)]; | |||
| if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && | |||
| tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { | |||
| continue; | |||
| } | |||
| std::unique_ptr<schema::QuantParamT> quant_param(new schema::QuantParamT()); | |||
| if (!tflite_tensor->quantization->scale.empty()) { | |||
| quant_param->scale = tflite_tensor->quantization->scale[0]; | |||
| } | |||
| if (!tflite_tensor->quantization->zero_point.empty()) { | |||
| quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; | |||
| auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i)); | |||
| if (inTensor == nullptr) { | |||
| MS_LOG(ERROR) << "Parse tflite quant params inTensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (!tflite_tensor->quantization->min.empty()) { | |||
| quant_param->min = tflite_tensor->quantization->min[0]; | |||
| SetMsTensorFromTflite(tflite_tensor, inTensor); | |||
| } | |||
| for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) { | |||
| const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)]; | |||
| if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && | |||
| tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { | |||
| continue; | |||
| } | |||
| if (!tflite_tensor->quantization->max.empty()) { | |||
| quant_param->max = tflite_tensor->quantization->max[0]; | |||
| auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i)); | |||
| if (outTensor == nullptr) { | |||
| MS_LOG(ERROR) << "Parse tflite quant params outTensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| SetMsTensorFromTflite(tflite_tensor, outTensor); | |||
| } | |||
| dst_op->quantType = schema::QuantType_AwareTrainning; | |||
| return RET_OK; | |||
| } | |||
| @@ -105,11 +126,15 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT | |||
| for (const auto &index : tflite_op->outputs) { | |||
| const auto &tflite_tensor = tflite_subgraph->tensors[index]; | |||
| if (tflite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor with id = " << index <<" is null"; | |||
| MS_LOG(ERROR) << "tensor with id = " << index << " is null"; | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT()); | |||
| tensor->dataType = GetTfliteDataType(tflite_tensor->type); | |||
| // change dataType to int8 to fit ms-lite op | |||
| if (tensor->dataType == TypeId::kNumberTypeUInt8) { | |||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||
| } | |||
| tensor->dims = tflite_tensor->shape; | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); | |||
| @@ -120,7 +145,8 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT | |||
| STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) { | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op, | |||
| TensorCache *tensor_cache) { | |||
| auto op_type = GetTfliteNodeType(tflite_op, tflite_model); | |||
| std::vector<int32_t> op_inputs(tflite_op->inputs); | |||
| if (op_type == "DeConv2D") { | |||
| @@ -130,12 +156,11 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t | |||
| for (const auto &tflite_index : op_inputs) { | |||
| const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; | |||
| if (tflite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null"; | |||
| MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null"; | |||
| return RET_ERROR; | |||
| } | |||
| auto tensor_name = tflite_tensor->name; | |||
| auto op = tfliteOpMap[tflite_op.get()]; | |||
| unsigned int index = tensorCache->FindTensor(tensor_name); | |||
| unsigned int index = tensor_cache->FindTensor(tensor_name); | |||
| if (index != -1) { | |||
| op->inputIndex.push_back(index); | |||
| } | |||
| @@ -146,19 +171,20 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t | |||
| STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| schema::MetaGraphT *subGraph, | |||
| mindspore::lite::TensorCache *tensorCache) { | |||
| schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache, | |||
| const QuantType &quantType) { | |||
| auto i = 0; | |||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||
| auto opType = GetTfliteNodeType(tflite_op, tflite_model); | |||
| std::unique_ptr<schema::CNodeT> op(new schema::CNodeT); | |||
| op->name = opType + "-" + std::to_string(i++); | |||
| op->quantType = quantType; | |||
| MS_LOG(INFO) << "parse op: " << op->name.c_str(); | |||
| auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); | |||
| if (node_parser == nullptr) { | |||
| MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str(); | |||
| MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str(); | |||
| continue; | |||
| // return RET_NULL_PTR; | |||
| } | |||
| @@ -172,7 +198,19 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_ | |||
| status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!"; | |||
| MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed"; | |||
| return RET_ERROR; | |||
| } | |||
| status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed"; | |||
| return RET_ERROR; | |||
| } | |||
| status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -189,8 +227,10 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT> | |||
| const auto &tflite_tensor = tflite_subgraph->tensors[index]; | |||
| std::unique_ptr<schema::TensorT> tensor(new schema::TensorT()); | |||
| tensor->format = schema::Format_NHWC; | |||
| tensor->dataType = GetTfliteDataType(tflite_tensor->type); | |||
| tensor->nodeType = schema::NodeType_ValueNode; | |||
| tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8 | |||
| ? GetTfliteDataType(tflite_tensor->type) | |||
| : TypeId::kNumberTypeInt8; | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| tensor->dims = tflite_tensor->shape; | |||
| tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); | |||
| } | |||
| @@ -212,7 +252,8 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache & | |||
| } | |||
| } | |||
| MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { | |||
| MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { | |||
| MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite"; | |||
| return nullptr; | |||
| @@ -224,7 +265,6 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st | |||
| MS_LOG(ERROR) << "read tflite model failed"; | |||
| return nullptr; | |||
| } | |||
| if (tflite_model->subgraphs.size() != 1) { | |||
| MS_LOG(ERROR) << "read tflite model subgraphs failed"; | |||
| return nullptr; | |||
| @@ -238,30 +278,15 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st | |||
| // set dst subGraph op attr and tensor_cache. | |||
| std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT); | |||
| subGraph->name = "MS_model converted by TF-Lite"; | |||
| auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache); | |||
| auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseOp failed."; | |||
| return nullptr; | |||
| } | |||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||
| auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache); | |||
| if (status_tmp != RET_OK) { | |||
| MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!"; | |||
| } | |||
| } | |||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||
| auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op); | |||
| if (statusTmp != RET_OK) { | |||
| MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!"; | |||
| } | |||
| } | |||
| SetGraphTensorIndex(tensorCache, subGraph.get()); | |||
| SetAllTensors(tensorCache, subGraph.get()); | |||
| return subGraph.release(); | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -40,22 +40,25 @@ class TfliteModelParser : public ModelParser { | |||
| virtual ~TfliteModelParser(); | |||
| MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile); | |||
| MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType = QuantType_QUANT_NONE) override; | |||
| private: | |||
| std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf); | |||
| void SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor); | |||
| void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); | |||
| void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, | |||
| schema::MetaGraphT *subGraphDef); | |||
| void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); | |||
| STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph, | |||
| TensorCache *tensor_cache); | |||
| TensorCache *tensor_cache, const QuantType &quantType); | |||
| STATUS ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op, | |||
| TensorCache *tensor_cache); | |||
| std::string GetTfliteNodeType(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| const std::unique_ptr<tflite::ModelT> &tflite_model); | |||
| @@ -63,13 +66,13 @@ class TfliteModelParser : public ModelParser { | |||
| STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); | |||
| STATUS SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||
| schema::CNodeT *op, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op, | |||
| TensorCache *tensorCache); | |||
| STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model, | |||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache); | |||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op, | |||
| TensorCache *tensor_cache); | |||
| std::map<std::string, schema::CNodeT *> opMap; | |||
| std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap; | |||
| @@ -4,7 +4,9 @@ include_directories(${3RD_DIR}/flatbuffers/include) | |||
| include_directories(${3RD_DIR}/opencv/build/include/opencv4) | |||
| add_library(quantizer_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | |||
| @@ -0,0 +1,594 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/quantizer/aware_quantizer.h" | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/converter/quantizer/calc_quant_param.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/common/converter_op_utils.h" | |||
| #include "tools/common/node_util.h" | |||
| using std::string; | |||
| using std::vector; | |||
| namespace mindspore::lite::quant { | |||
| struct InputArray { | |||
| std::unique_ptr<QuantParamT> quantParam; | |||
| float mMin = 0.0f; | |||
| float mMax = 0.0f; | |||
| bool narrowRange = false; | |||
| int numBits = 8; | |||
| TypeId dataType = TypeId::kTypeUnknown; | |||
| InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) { | |||
| this->dataType = dataType; | |||
| constexpr float qmin = 0; | |||
| constexpr float qmax = 255; | |||
| mMin = (qmin - mean) / stdDev; | |||
| mMax = (qmax - mean) / stdDev; | |||
| } | |||
| STATUS InitQuantParam() { | |||
| this->quantParam = std::make_unique<schema::QuantParamT>(); | |||
| auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits); | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensorIdx) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto &tensor = graph->allTensors.at(inputTensorIdx); | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (!tensor->quantParams.empty()) { | |||
| auto param = GetTensorQuantParam(tensor); | |||
| if (param != nullptr && param->inited) { | |||
| MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam"; | |||
| return RET_OK; | |||
| } | |||
| tensor->quantParams.clear(); | |||
| } | |||
| std::unique_ptr<schema::QuantParamT> tmpQuantParam(new QuantParamT()); | |||
| tmpQuantParam->inited = this->quantParam->inited; | |||
| tmpQuantParam->scale = this->quantParam->scale; | |||
| tmpQuantParam->zeroPoint = this->quantParam->zeroPoint; | |||
| tmpQuantParam->min = this->quantParam->min; | |||
| tmpQuantParam->max = this->quantParam->max; | |||
| tensor->quantParams.push_back(std::move(tmpQuantParam)); | |||
| return RET_OK; | |||
| } | |||
| }; | |||
// Ops that pass values through without changing their distribution: DoQuantize
// copies the first input tensor's dataType to the output tensor for these.
const std::array<schema::PrimitiveType, 7> AwareQuantizer::propagatedOps = {
    {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape,
     schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation,
     schema::PrimitiveType_DetectionPostProcess}};
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, | |||
| const string &meanValues) | |||
| : FbQuantizer(graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| string::size_type sz; | |||
| const float stdValue = std::stof(stdValues, &sz); | |||
| sz = 0; | |||
| const float mean = std::stof(meanValues, &sz); | |||
| if (inputInferType == "FLOAT") { | |||
| mInputArray = new InputArray(mean, stdValue); | |||
| } else { | |||
| mInputArray = new InputArray(mean, stdValue, TypeId::kNumberTypeUInt8); | |||
| } | |||
| mInputArray->InitQuantParam(); | |||
| } | |||
| STATUS AwareQuantizer::RemoveFakeQuant() { | |||
| // for (auto &subGraph : graphDefT->subgraphs) { | |||
| // auto status = GenerateDefaultQuantParam(subGraph.get()); | |||
| // if (status != RET_OK) { | |||
| // MS_LOGE("GenerateDefaultQuantParam failed: %d", status); | |||
| // return RET_ERROR; | |||
| // } | |||
| // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) { | |||
| // auto *node = (*iter).get(); | |||
| // if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) { | |||
| // continue; | |||
| // } | |||
| // auto inputIndexes = node->inputIndex; | |||
| // if (inputIndexes.size() != 3) { | |||
| // MS_LOGE("invalid fakequant node's input tensors count!"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // bool narrorRange; | |||
| // int numBits; | |||
| // if (GetCNodeTType(*node) == OpT_FakeQuantWithMinMaxVars) { | |||
| // narrorRange = node->attr.AsFakeQuantWithMinMaxVars()->narrowRange; | |||
| // numBits = node->attr.AsFakeQuantWithMinMaxVars()->numBits; | |||
| // } | |||
| // if (GetCNodeTType(*node) == OpT_FakeQuantWithMinMax) { | |||
| // narrorRange = false; | |||
| // numBits = 8; | |||
| // } | |||
| // | |||
| // TensorDefT *tensor0 = subGraph->allTensors.at(inputIndexes[0]).get(); | |||
| // TensorDefT *tensor1 = subGraph->allTensors.at(inputIndexes[1]).get(); | |||
| // TensorDefT *tensor2 = subGraph->allTensors.at(inputIndexes[2]).get(); | |||
| // MS_ASSERT(tensor0 != nullptr); | |||
| // MS_ASSERT(tensor1 != nullptr); | |||
| // MS_ASSERT(tensor2 != nullptr); | |||
| // // calculate quant param | |||
| // MS_ASSERT(tensor1->dataType == DataType_DT_FLOAT); | |||
| // MS_ASSERT(tensor2->dataType == DataType_DT_FLOAT); | |||
| // auto *minData = reinterpret_cast<const float *>(tensor1->data.data()); | |||
| // auto *maxData = reinterpret_cast<const float *>(tensor2->data.data()); | |||
| // MS_ASSERT(minData != nullptr); | |||
| // MS_ASSERT(maxData != nullptr); | |||
| // std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT()); | |||
| // if (quantParam == nullptr) { | |||
| // MS_LOGE("new quantParam failed"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // auto realMin = (double)minData[0]; | |||
| // auto realMax = (double)maxData[0]; | |||
| // status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits); | |||
| // if (status != RET_OK) { | |||
| // MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str()); | |||
| // return RET_ERROR; | |||
| // } | |||
| // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) { | |||
| // CalFakeNode(tensor0, quantParam.get()); | |||
| // } | |||
| // std::unique_ptr<QuantParamArrayT> quantParamArray(new (std::nothrow) QuantParamArrayT()); | |||
| // if (quantParamArray == nullptr) { | |||
| // MS_LOGE("new quantParamArray failed"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // quantParamArray->param.push_back(std::move(quantParam)); | |||
| // auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray); | |||
| // if (quantParamArrayCopy == nullptr) { | |||
| // MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str()); | |||
| // return RET_ERROR; | |||
| // } | |||
| // node->quantParam.emplace_back(std::move(quantParamArrayCopy)); | |||
| // node->quantParam.emplace_back(nullptr); // secondInTensor and thirdInTensor are weightTensors who have no | |||
| // preNode node->quantParam.emplace_back(nullptr); node->quantParam.emplace_back(std::move(quantParamArray)); | |||
| // | |||
| // // BroadCast fakeQuantNode QuantParam | |||
| // status = BroadCastQuantParam(subGraph, *iter); | |||
| // if (status != RET_OK) { | |||
| // MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status); | |||
| // return status; | |||
| // } | |||
| // // save post node index for SetAttrToConvolution | |||
| // auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node); | |||
| // // remove fakequantwithminmax node | |||
| // status = IsolateNode(subGraph.get(), node); | |||
| // if (status != RET_OK) { | |||
| // MS_LOGE("in aware quant IsolateNode failed!"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // // set filter param to node | |||
| // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) { | |||
| // auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get(); | |||
| // if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D || | |||
| // GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) { | |||
| // auto status = SetAttrToConvolution(subGraph.get(), postNode); | |||
| // if (status != RET_OK) { | |||
| // MS_LOGE("in aware quant SetAttrToConvolution failed!"); | |||
| // return RET_ERROR; | |||
| // } | |||
| // } | |||
| // } | |||
| // } | |||
| // | |||
| // // remove IsolatedNode | |||
| // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) { | |||
| // if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) { | |||
| // iter = subGraph->nodes.erase(iter); | |||
| // } else { | |||
| // iter++; | |||
| // } | |||
| // } | |||
| // // set graphInputNode inputTensor quantParams | |||
| // MS_ASSERT(subGraph->inputIndex.size() == 1); | |||
| // for (auto graphInputIndex : subGraph->inputIndex) { | |||
| // auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex); | |||
| // for (auto nodeIdx : linkedPostIdx) { | |||
| // MS_ASSERT(subGraph->nodes.size() > nodeIdx); | |||
| // mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get()); | |||
| // } | |||
| // } | |||
| // } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) { | |||
| MS_ASSERT(subGraph != nullptr); | |||
| for (const auto &tensor : subGraph->allTensors) { | |||
| if (!tensor->quantParams.empty()) { | |||
| continue; | |||
| } | |||
| std::unique_ptr<schema::QuantParamT> defaultQuantParam(new QuantParamT()); | |||
| tensor->quantParams.emplace_back(std::move(defaultQuantParam)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||
| // MS_ASSERT(subGraph != nullptr); | |||
| // MS_ASSERT(node != nullptr); | |||
| // auto inputIndexes = node->inputIndex; | |||
| // MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D || | |||
| // GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D); | |||
| // if (inputIndexes.size() < 2) { | |||
| // MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size()); | |||
| // return RET_ERROR; | |||
| // } | |||
| // TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get(); | |||
| // MS_ASSERT(filterTensor != nullptr); | |||
| // auto filterDims = filterTensor->dims; | |||
| // MS_ASSERT(filterDims.size() == 4); | |||
| // if (GetCNodeTType(*node) == OpT_Conv2D) { | |||
| // if (node->fmkType == FmkType_MS) { | |||
| // node->attr.AsConv2D()->channelOut = (int32_t)filterDims[0]; | |||
| // node->attr.AsConv2D()->channelIn = (int32_t)filterDims[1]; | |||
| // node->attr.AsConv2D()->kernelH = (int32_t)filterDims[2]; | |||
| // node->attr.AsConv2D()->kernelW = (int32_t)filterDims[3]; | |||
| // } else if (node->fmkType == FmkType_TF) { | |||
| // node->attr.AsConv2D()->kernelH = (int32_t)filterDims[0]; | |||
| // node->attr.AsConv2D()->kernelW = (int32_t)filterDims[1]; | |||
| // node->attr.AsConv2D()->channelIn = (int32_t)filterDims[2]; | |||
| // node->attr.AsConv2D()->channelOut = (int32_t)filterDims[3]; | |||
| // } else { | |||
| // MS_LOGE("Unsupport"); | |||
| // } | |||
| // } | |||
| // if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) { | |||
| // if (node->fmkType == FmkType_MS) { | |||
| // node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0]; | |||
| // node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1]; | |||
| // node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2]; | |||
| // node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3]; | |||
| // } else if (node->fmkType == FmkType_TF) { | |||
| // node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0]; | |||
| // node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1]; | |||
| // node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2]; | |||
| // node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3]; | |||
| // } else { | |||
| // MS_LOGE("Unsupport"); | |||
| // } | |||
| // } | |||
| // if (GetCNodeTType(*node) == OpT_DeConv2D) { | |||
| // MS_ASSERT(false); | |||
| // } | |||
| // if (GetCNodeTType(*node) == OpT_DeDepthwiseConv2D) { | |||
| // MS_ASSERT(false); | |||
| // } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::GenerateQuantParam() { | |||
| // todo why? | |||
| MS_ASSERT(graph->inputIndex.size() == 1); | |||
| // set graphInputNode input | |||
| for (auto graphInputIndex : graph->inputIndex) { | |||
| auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetInputArrayQP failed"; | |||
| return status; | |||
| } | |||
| } | |||
| auto status = GenerateDefaultQuantParam(graph.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenerateDefaultQuantParam failed"; | |||
| return status; | |||
| } | |||
| auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| MS_ASSERT(node != nullptr); | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { | |||
| MS_ASSERT(false); | |||
| } | |||
| auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); | |||
| if (quantParamCalcer == nullptr) { | |||
| MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() | |||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||
| node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); | |||
| } else { | |||
| status = quantParamCalcer->Calc(graph.get(), *node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } else { | |||
| node->quantType = schema::QuantType_AwareTrainning; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::DoQuantize() { | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||
| continue; | |||
| } | |||
| if (node->quantType != schema::QuantType_AwareTrainning) { | |||
| continue; | |||
| } | |||
| STATUS status; | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D || | |||
| GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) { | |||
| auto inputIndexes = node->inputIndex; | |||
| if (inputIndexes.size() < 2) { | |||
| MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; | |||
| return RET_ERROR; | |||
| } | |||
| // quant weight | |||
| status = QuantConvWeight(graph.get(), node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvWeight failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| // quant bias | |||
| if (inputIndexes.size() == 3) { | |||
| status = QuantConvBias(graph.get(), node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvBias failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { | |||
| status = QuantDetectionPostProcessConstTensor(graph.get(), node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) { | |||
| status = QuantAddConstTensor(graph.get(), node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantAddConstTensor failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| const auto nodeType = GetCNodeTType(*node); | |||
| auto find = std::find(propagatedOps.begin(), propagatedOps.end(), nodeType); | |||
| if (find != propagatedOps.end()) { | |||
| auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get(); | |||
| auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get(); | |||
| MS_ASSERT(inputTensor != nullptr); | |||
| MS_ASSERT(outputTensor != nullptr); | |||
| outputTensor->dataType = inputTensor->dataType; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| for (size_t i = 0; i < node->inputIndex.size(); i++) { | |||
| auto inTensorIdx = node->inputIndex.at(i); | |||
| MS_ASSERT(graph->allTensors.size() > inTensorIdx); | |||
| auto &inTensor = graph->allTensors.at(inTensorIdx); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| if (inTensor->refCount == 999) { | |||
| switch (inTensor->dataType) { | |||
| case TypeId::kNumberTypeFloat: { | |||
| auto quantParam = GetTensorQuantParam(inTensor); | |||
| MS_ASSERT(quantParam != nullptr); | |||
| MS_ASSERT(quantParam->inited); | |||
| auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); | |||
| vector<uint8_t> qDatas(constTensorShapeSize); | |||
| void *inData = inTensor->data.data(); | |||
| auto *castedInData = static_cast<float *>(inData); | |||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||
| qDatas[j] = QuantizeData<uint8_t>(castedInData[j], quantParam.get()); | |||
| } | |||
| inTensor->data = std::move(qDatas); | |||
| inTensor->dataType = kNumberTypeUInt8; | |||
| } break; | |||
| case kNumberTypeUInt8: | |||
| break; | |||
| default: | |||
| // MS_LOGE("Unsupported dataType: %d", inTensor->dataType); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||
| MS_ASSERT(subGraph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); | |||
| MS_ASSERT(constTensor != nullptr); | |||
| const auto *constData = reinterpret_cast<const float *>(constTensor->data.data()); | |||
| if (constTensor->refCount == 999 && constTensor->dataType == TypeId::kNumberTypeFloat) { | |||
| size_t constTensorShapeSize = GetShapeSize(*constTensor); | |||
| std::unique_ptr<QuantParamT> quantParam = GetTensorQuantParam(constTensor); | |||
| if (quantParam == nullptr) { | |||
| // MS_LOGE("new QuantParamT failed"); | |||
| return RET_NULL_PTR; | |||
| } | |||
| vector<uint8_t> qDatas(constTensorShapeSize); | |||
| for (size_t j = 0; j < constTensorShapeSize; j++) { | |||
| float rawData = constData[j]; | |||
| qDatas[j] = QuantizeData<uint8_t>(rawData, quantParam.get()); | |||
| } | |||
| constTensor->data = std::move(qDatas); | |||
| constTensor->dataType = TypeId::kNumberTypeUInt8; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto inputIndexes = node->inputIndex; | |||
| MS_ASSERT(inputIndexes.size() >= 3); | |||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0)); | |||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1)); | |||
| MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2)); | |||
| auto &biasTensor = graph->allTensors.at(inputIndexes.at(2)); | |||
| MS_ASSERT(biasTensor != nullptr); | |||
| if (biasTensor->dataType != TypeId::kNumberTypeFloat) { | |||
| // MS_LOGD("conv %s's bias data is not float", node->name.c_str()); | |||
| return RET_OK; | |||
| } | |||
| if (biasTensor->dataType == TypeId::kNumberTypeInt32) { | |||
| return RET_OK; | |||
| } | |||
| if (biasTensor->dataType != TypeId::kNumberTypeFloat) { | |||
| // MS_LOGE("conv %s's bias data is not float", node->name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| auto &inputTensor = graph->allTensors.at(inputIndexes.at(0)); | |||
| auto &weightTensor = graph->allTensors.at(inputIndexes.at(1)); | |||
| MS_ASSERT(inputTensor != nullptr); | |||
| MS_ASSERT(weightTensor != nullptr); | |||
| auto inputScale = inputTensor->quantParams.front()->scale; | |||
| auto weightScale = weightTensor->quantParams.front()->scale; | |||
| auto scale = inputScale * weightScale; | |||
| // set bias quant param | |||
| std::unique_ptr<QuantParamT> biasQuantParam = GetTensorQuantParam(biasTensor); | |||
| if (biasQuantParam == nullptr) { | |||
| // MS_LOGE("new QuantParamT failed"); | |||
| return RET_ERROR; | |||
| } | |||
| biasQuantParam->inited = true; | |||
| biasQuantParam->scale = scale; | |||
| biasQuantParam->zeroPoint = 0; | |||
| biasQuantParam->numBits = 8; | |||
| biasQuantParam->narrowRange = false; | |||
| biasQuantParam->min = 0.0; | |||
| biasQuantParam->max = 0.0; | |||
| // quant bias data | |||
| auto bShapeSize = GetShapeSize(*(biasTensor.get())); | |||
| auto *qDatas = new (std::nothrow) int32_t[bShapeSize]; | |||
| if (qDatas == nullptr) { | |||
| // MS_LOGE("new qDatas failed"); | |||
| return RET_ERROR; | |||
| } | |||
| void *biasData = biasTensor->data.data(); | |||
| auto *rawDatas = static_cast<float *>(biasData); | |||
| for (size_t i = 0; i < bShapeSize; ++i) { | |||
| qDatas[i] = (int32_t)std::round(rawDatas[i] / scale); | |||
| } | |||
| biasTensor->dataType = TypeId::kNumberTypeInt32; | |||
| biasTensor->data.clear(); | |||
| biasTensor->data.resize(bShapeSize * sizeof(int32_t)); | |||
| auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas, bShapeSize * sizeof(int32_t)); | |||
| if (ret != EOK) { | |||
| // MS_LOGE("memcpy_s failed: %d", ret); | |||
| return RET_ERROR; | |||
| } | |||
| delete[] qDatas; | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { | |||
| MS_ASSERT(subGraph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); | |||
| auto inputIndexes = node->inputIndex; | |||
| MS_ASSERT(inputIndexes.size() >= 2); | |||
| MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); | |||
| auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1)); | |||
| if (weightTensor->dataType == TypeId::kNumberTypeInt8) { | |||
| return RET_OK; | |||
| } | |||
| if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) { | |||
| MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; | |||
| return RET_ERROR; | |||
| } | |||
| size_t wShapeSize = GetShapeSize(*(weightTensor.get())); | |||
| void *oriWeightData = weightTensor->data.data(); | |||
| MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); | |||
| vector<int8_t> qDatas(wShapeSize); | |||
| auto weightQauntParam = GetTensorQuantParam(weightTensor); | |||
| if (weightTensor->dataType == TypeId::kNumberTypeFloat) { // normal awareing quant | |||
| auto *weightData = static_cast<float *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||
| } | |||
| } else { // tflite awareing quant | |||
| auto *weightData = static_cast<uint8_t *>(oriWeightData); | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = (int32_t)weightData[j] - 128; | |||
| } | |||
| weightQauntParam->zeroPoint -= 128; | |||
| weightTensor->quantParams.clear(); | |||
| weightTensor->quantParams.emplace_back(weightQauntParam.release()); | |||
| } | |||
| ::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize); | |||
| weightTensor->dataType = TypeId::kNumberTypeInt8; | |||
| return RET_OK; | |||
| } | |||
| STATUS AwareQuantizer::DetermineNodeQuantType() { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| bool canQuant = true; | |||
| for (auto &inTensorIdx : node->inputIndex) { | |||
| MS_ASSERT(graph->allTensors.size() > inTensorIdx); | |||
| auto &inTensor = graph->allTensors.at(inTensorIdx); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || | |||
| !inTensor->quantParams.front()->inited) { | |||
| canQuant = false; | |||
| break; | |||
| } | |||
| } | |||
| if (canQuant) { | |||
| for (auto &outTensorIdx : node->outputIndex) { | |||
| MS_ASSERT(graph->allTensors.size() > outTensorIdx); | |||
| auto &outTensor = graph->allTensors.at(outTensorIdx); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || | |||
| !outTensor->quantParams.front()->inited) { | |||
| canQuant = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||
| node->quantType = schema::QuantType_AwareTrainning; | |||
| } else { | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::lite::quant | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#ifndef MS_AWARE_QUANTIZER_H
#define MS_AWARE_QUANTIZER_H
#include <array>
#include <string>
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
#include "include/errorcode.h"
namespace mindspore::lite::quant {
// Defined in the .cc: holds the graph-input quant params derived from the
// user-supplied mean/std normalization values.
struct InputArray;
// Flatbuffer-graph quantizer for aware-trained models: computes per-tensor
// quant params, quantizes const tensors (conv weights/bias, anchors, const
// Add inputs) and tags each node with its resulting QuantType.
class AwareQuantizer : public FbQuantizer {
 public:
  // stdValues/meanValues are parsed with std::stof; inputInferType "FLOAT"
  // keeps a float graph input, anything else implies a uint8 input.
  AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues,
                 const std::string &meanValues);
  ~AwareQuantizer() { delete (mInputArray); }
  STATUS RemoveFakeQuant() override;
  STATUS GenerateQuantParam() override;
  STATUS DetermineNodeQuantType() override;
  STATUS DoQuantize() override;  // override;
 private:
  // RemoveFakeQuant
  STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
  STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph);
  STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node);
  STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
  STATUS QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node);
  STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node);
  // NOTE(review): inputScale appears unused in the visible implementation —
  // confirm before removing (would be an ABI/layout change only).
  float inputScale = 0.0f;
  InputArray *mInputArray;
  // Ops whose output dataType follows their input (see DoQuantize).
  static const std::array<schema::PrimitiveType, 7> propagatedOps;
};
}  // namespace mindspore::lite::quant
#endif
| @@ -0,0 +1,504 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/quantizer/calc_quant_param.h" | |||
| #include <cfloat> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "schema/inner/ops_generated.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore::lite { | |||
// Compute a quant param for a const (weight) tensor directly from its float
// data. int32/uint8 const tensors are left as-is.
STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) {
  MS_ASSERT(quantParam != nullptr);
  // int32 weight no need to quant
  if (tensor.dataType == TypeId::kNumberTypeInt32 || tensor.dataType == TypeId::kNumberTypeUInt8) {
    return RET_OK;
  }
  if (tensor.dataType != TypeId::kNumberTypeFloat) {
    // MS_LOGW("Const Tensor without quantParam should has float dataType, in fact: %d", tensor.dataType);
    return RET_ERROR;
  }
  const auto *constData = reinterpret_cast<const float *>(tensor.data.data());
  size_t constTensorShapeSize = GetShapeSize(tensor);
  // Both bounds start at 0 so the derived range always contains zero — a
  // requirement for zero-point (asymmetric) quantization.
  float min = 0.0f;
  float max = 0.0f;
  // find min and max
  for (size_t i = 0; i < constTensorShapeSize; i++) {
    min = std::min(min, constData[i]);
    max = std::max(max, constData[i]);
  }
  // Degenerate all-zero tensor: widen to [0, 1] to avoid a zero scale.
  if (min == 0.0f && max == 0.0f) {
    max = 1.0f;
  }
  // Quantization is lossless only if every value sits exactly on min or max.
  bool isQuantExact = true;
  for (size_t i = 0; i < constTensorShapeSize; i++) {
    isQuantExact &= (constData[i] == min || constData[i] == max);
  }
  if (!isQuantExact) {
    // //MS_LOGD("compute quantParam for const tensor may be a cause of poor inference accuracy");
  }
  return quant::CalQuantizationParams(quantParam, min, max);
}
| // init inTensor quantParam from preNode if possable | |||
| // init outTensor quantParam from postNode if possable | |||
// Base pass for all calcers: counts how many input/output tensors of `node`
// already carry inited quant params (inputParamDone/outputParamDone), and
// computes params on the fly for const (weight) inputs.
int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
  MS_ASSERT(node.inputIndex.size() > 0);
  MS_ASSERT(node.quantParam.size() == node.inputIndex.size() + node.outputIndex.size());
  inputParamDone = 0;
  auto inputTensorSize = node.inputIndex.size();
  for (size_t i = 0; i < inputTensorSize; i++) {
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
    auto &tensor = graph->allTensors.at(node.inputIndex.at(i));
    MS_ASSERT(tensor != nullptr);
    auto quantParam = GetTensorQuantParam(tensor);
    if (quantParam->inited) {  // inited
      inputParamDone++;
      continue;
    }
    MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
    MS_ASSERT(tensor != nullptr);
    // NOTE(review): refCount is compared against the NodeType enum value —
    // this reads as a "const tensor" sentinel convention; confirm against the
    // schema before changing it.
    if (tensor->refCount == schema::NodeType_ValueNode && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
      // Const input that is not a graph input: derive its param from the data.
      auto status = ComputeConstQuantParam((*tensor), quantParam.get());
      if (status != RET_OK) {
        // MS_LOGW("ComputeConstQuantParam failed: %d", status);
        return status;
      }
      tensor->quantParams.front() = std::move(quantParam);
      inputParamDone++;
      continue;
    }
  }
  outputParamDone = 0;
  for (unsigned int i : node.outputIndex) {
    MS_ASSERT(graph->allTensors.size() > i);
    auto &tensor = graph->allTensors.at(i);
    MS_ASSERT(tensor != nullptr);
    auto quantParam = GetTensorQuantParam(tensor);
    MS_ASSERT(quantParam != nullptr);
    if (quantParam->inited) {  // inited
      outputParamDone++;
      continue;
    }
    // A const output tensor (refCount sentinel 999) should never reach here
    // uninited.
    if (tensor->refCount == 999) {
      MS_ASSERT(false);
    }
  }
  return RET_OK;
}
| int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||
| auto status = QuantParamCalcer::Calc(subGraph, node); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); | |||
| return status; | |||
| } | |||
| if (inputParamDone != node.inputIndex.size()) { | |||
| MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name.c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| if (outputParamDone != node.outputIndex.size()) { | |||
| MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name.c_str(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto status = QuantParamCalcer::Calc(graph, node); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); | |||
| return status; | |||
| } | |||
| if (inputParamDone != node.inputIndex.size()) { | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(0)); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.at(0)); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outputQuantParam = GetTensorQuantParam(outTensor); | |||
| MS_ASSERT(outputQuantParam != nullptr); | |||
| if (!outputQuantParam->inited) { | |||
| // MS_LOGW("Can not determine inputTensor quantParam from outputTensor for node %s", node.name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| for (unsigned int i : node.inputIndex) { | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); | |||
| auto &inTensor = graph->allTensors.at(i); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| auto inQuantParam = GetTensorQuantParam(inTensor); | |||
| if (inQuantParam->inited) { | |||
| continue; | |||
| } | |||
| inTensor->quantParams.front() = std::move(inQuantParam); | |||
| } | |||
| } | |||
| if (outputParamDone != node.outputIndex.size()) { | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); | |||
| auto &inTensor = graph->allTensors.at(node.inputIndex.at(0)); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| auto inQuantParam = GetTensorQuantParam(inTensor); | |||
| if (!inQuantParam->inited) { | |||
| // MS_LOGW("Can not determine outputTensor quantParam from inputTensor for node %s", node.name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| for (size_t i = 0; i < node.outputIndex.size(); i++) { | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(i)); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| if (outQuantParam->inited) { | |||
| continue; | |||
| } | |||
| // todo copy quant params | |||
| outTensor->quantParams.front() = std::move(outQuantParam); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| class CalcConcat : public QuantParamCalcer { | |||
| public: | |||
| CalcConcat() = default; | |||
| int Calc(MetaGraphT *graph, const CNodeT &node) override { | |||
| MS_ASSERT(node.outputIndex.size() == 1); | |||
| auto status = QuantParamCalcer::Calc(graph, node); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); | |||
| return status; | |||
| } | |||
| if (inputParamDone != node.inputIndex.size()) { | |||
| // MS_LOGW("Can not determine concat inputTensor quantParam, node %s", node.name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| if (outputParamDone != 1) { | |||
| MS_ASSERT(outputParamDone == 0); | |||
| float minMin = FLT_MAX; | |||
| float maxMax = FLT_MIN; | |||
| bool narrowRange = false; | |||
| int numBits = -1; | |||
| for (size_t i = 0; i < node.inputIndex.size(); i++) { | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); | |||
| auto &inTensor = graph->allTensors.at(i); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| auto inQuantParam = GetTensorQuantParam(inTensor); | |||
| MS_ASSERT(inQuantParam != nullptr); | |||
| if (!inQuantParam->inited) { | |||
| return RET_ERROR; | |||
| } | |||
| if (numBits == -1) { | |||
| narrowRange = inQuantParam->narrowRange; | |||
| numBits = inQuantParam->numBits; | |||
| } else { | |||
| MS_ASSERT(narrowRange == quantParam->narrowRange); | |||
| MS_ASSERT(numBits == quantParam->numBits); | |||
| } | |||
| if (minMin > inQuantParam->min) { | |||
| minMin = inQuantParam->min; | |||
| } | |||
| if (maxMax < inQuantParam->max) { | |||
| maxMax = inQuantParam->max; | |||
| } | |||
| } | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("in aware quantization run CalQuantizationParams failed!"); | |||
| return RET_ERROR; | |||
| } | |||
| outputParamDone++; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| }; | |||
// Derives the Add output quant param when one operand is a constant scalar
// "bias": the other operand's [min, max] range is shifted by the bias value.
class CalcAdd : public QuantParamCalcer {
 public:
  CalcAdd() = default;
  int Calc(MetaGraphT *graph, const CNodeT &node) override {
    MS_ASSERT(node.inputIndex.size() == 2);
    MS_ASSERT(node.outputIndex.size() == 1);
    auto status = QuantParamCalcer::Calc(graph, node);
    if (status != RET_OK) {
      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
      return status;
    }
    if (inputParamDone != 2) {
      // both input quant params must be known before the output can be derived
      // MS_LOGW("Can not determine add inputTensor quantParam, node %s", node.name.c_str());
      return RET_ERROR;
    }
    if (outputParamDone != 1) {
      MS_ASSERT(outputParamDone == 0);
      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
      MS_ASSERT(outTensor != nullptr);
      auto outQuantParam = GetTensorQuantParam(outTensor);
      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
      auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
      MS_ASSERT(tensor0 != nullptr);
      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
      auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
      MS_ASSERT(tensor1 != nullptr);
      // Pick which operand is the constant scalar bias. refCount == 999
      // appears to mark a constant tensor (see the sentinel check in the base
      // Calc) -- TODO confirm. The two assignments below are defaults that are
      // always overwritten (or the function returns) in the if/else chain.
      auto biasTensor = &tensor0;
      auto paramTensor = &tensor1;
      if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) {
        biasTensor = &tensor0;
        paramTensor = &tensor1;
      } else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
        biasTensor = &tensor1;
        paramTensor = &tensor0;
      } else {
        // neither operand is a scalar constant: cannot derive the output here
        // MS_LOGW("Can not determine add outputTensor quantParam, node %s", node.name.c_str());
        return RET_ERROR;
      }
      auto quantParam = GetTensorQuantParam(*paramTensor);
      MS_ASSERT(quantParam != nullptr);
      MS_ASSERT(quantParam->inited);
      auto min = quantParam->min;
      auto max = quantParam->max;
      {
        if ((*biasTensor)->dataType == TypeId::kNumberTypeFloat) {
          // float bias: data must hold exactly one float
          MS_ASSERT((*biasTensor)->data.size() == sizeof(float) / sizeof(uint8_t));
          void *oriTensorData = (*biasTensor)->data.data();
          auto *bias = static_cast<float *>(oriTensorData);
          // output range = param range shifted by the bias value
          status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
          if (status != RET_OK) {
            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
            return RET_ERROR;
          }
        } else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) {
          // uint8 bias: a single raw byte.
          // NOTE(review): the quantized byte is added without dequantizing --
          // presumably intentional for this format, but worth confirming.
          MS_ASSERT((*biasTensor)->data.size() == 1);
          void *oriTensorData = (*biasTensor)->data.data();
          auto *bias = static_cast<uint8_t *>(oriTensorData);
          status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
          if (status != RET_OK) {
            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
            return RET_ERROR;
          }
        } else {
          // MS_LOGW("Unsupported tensor dataType: %d", (*biasTensor)->dataType);
          return RET_ERROR;
        }
      }
    }
    return RET_OK;
  }
};
| class CalcRealDiv : public QuantParamCalcer { | |||
| public: | |||
| CalcRealDiv() = default; | |||
| int Calc(MetaGraphT *graph, const CNodeT &node) override { | |||
| MS_ASSERT(node.inputIndex.size() == 2); | |||
| MS_ASSERT(node.outputIndex.size() == 1); | |||
| auto status = QuantParamCalcer::Calc(graph, node); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); | |||
| return status; | |||
| } | |||
| if (inputParamDone != 2) { | |||
| // MS_LOGW("Can not determine realdiv inputTensor quantParam, node %s", node.name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| if (outputParamDone != 1) { | |||
| MS_ASSERT(outputParamDone == 0); | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); | |||
| auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0)); | |||
| MS_ASSERT(tensor0 != nullptr); | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1)); | |||
| auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1)); | |||
| MS_ASSERT(tensor1 != nullptr); | |||
| if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { | |||
| auto quantParam = GetTensorQuantParam(tensor1); | |||
| auto min = quantParam->min; | |||
| auto max = quantParam->max; | |||
| { | |||
| if (tensor1->dataType == TypeId::kNumberTypeFloat) { | |||
| MS_ASSERT(tensor1->data.size() == sizeof(float) / sizeof(uint8_t)); | |||
| void *oriTensorData = tensor1->data.data(); | |||
| auto *div = static_cast<float *>(oriTensorData); | |||
| MS_ASSERT(*div != 0); | |||
| status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div)); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("in aware quantization run CalQuantizationParams failed!"); | |||
| return RET_ERROR; | |||
| } | |||
| } else if (tensor1->dataType == TypeId::kNumberTypeUInt8) { | |||
| MS_ASSERT(tensor1->data.size() == 1); | |||
| void *oriTensorData = tensor1->data.data(); | |||
| auto *div = static_cast<uint8_t *>(oriTensorData); | |||
| status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max + (*div)); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("in aware quantization run CalQuantizationParams failed!"); | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| // MS_LOGW("Unsupported tensor dataType: %d", tensor1->dataType); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else { | |||
| // MS_LOGW("Can not determine realDiv outputTensor quantParam, node %s", node.name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| }; | |||
| class CalcToSet : public QuantParamCalcer { | |||
| public: | |||
| CalcToSet(float min, float max) : min(min), max(max) {} | |||
| int Calc(MetaGraphT *graph, const CNodeT &node) override { | |||
| MS_ASSERT(node.inputIndex.size() == 1); | |||
| MS_ASSERT(node.outputIndex.size() == 1); | |||
| auto status = QuantParamCalcer::Calc(graph, node); | |||
| if (status != RET_OK) { | |||
| // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); | |||
| return status; | |||
| } | |||
| // input | |||
| if (inputParamDone != node.inputIndex.size()) { | |||
| // MS_LOGW("Can not determine inputTensor quantParam, node %s", node.name.c_str()); | |||
| return RET_ERROR; | |||
| } | |||
| // output | |||
| std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT()); | |||
| if (quantParam == nullptr) { | |||
| // MS_LOGW("new QuantParamT failed"); | |||
| return RET_ERROR; | |||
| } | |||
| quantParam->scale = (max - min) / 256; | |||
| MS_ASSERT(quantParam->scale != 0); | |||
| quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale)); | |||
| quantParam->min = min; | |||
| quantParam->max = max; | |||
| quantParam->inited = true; | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| outTensor->quantParams.front() = std::move(quantParam); | |||
| return RET_OK; | |||
| } | |||
| protected: | |||
| float min; | |||
| float max; | |||
| }; | |||
| class CalcActivation : public QuantParamCalcer { | |||
| public: | |||
| CalcActivation() = default; | |||
| int Calc(MetaGraphT *subGraph, const CNodeT &node) override { | |||
| MS_ASSERT(node.inputIndex.size() == 1); | |||
| MS_ASSERT(node.outputIndex.size() == 1); | |||
| MS_ASSERT(node.attr.AsActivation() != nullptr); | |||
| if (node.primitive->value.AsActivation()->type == schema::ActivationType_SIGMOID) { | |||
| auto calcToSet = CalcToSet(0, 1); | |||
| return calcToSet.Calc(subGraph, node); | |||
| } else { | |||
| auto calCommon = CommonCalcer(); | |||
| return calCommon.Calc(subGraph, node); | |||
| } | |||
| } | |||
| }; | |||
| QuantParamCalcRegister::QuantParamCalcRegister() { | |||
| bool hasError = false; | |||
| auto baseCalcer = new (std::nothrow) QuantParamCalcer(); | |||
| if (baseCalcer == nullptr) { | |||
| // MS_LOGW("new QuantParamCalcer failed"); | |||
| hasError = true; | |||
| } | |||
| auto commonCalcer = new (std::nothrow) CommonCalcer(); | |||
| if (commonCalcer == nullptr) { | |||
| // MS_LOGW("new commonCalcer failed"); | |||
| hasError = true; | |||
| } | |||
| auto linearCalcer = new (std::nothrow) LinearCalcer(); | |||
| if (linearCalcer == nullptr) { | |||
| // MS_LOGW("new linearCalcer failed"); | |||
| hasError = true; | |||
| } | |||
| if (!hasError) { | |||
| _registerMap[schema::PrimitiveType_Concat] = new CalcConcat(); | |||
| _registerMap[schema::PrimitiveType_Activation] = new CalcActivation(); | |||
| _registerMap[schema::PrimitiveType_Add] = new CalcAdd(); | |||
| _registerMap[schema::PrimitiveType_Mul] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Resize] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Shape] = linearCalcer; // todo if shape influence postNode's output quantParam | |||
| _registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1); | |||
| _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv(); | |||
| _registerMap[schema::PrimitiveType_Reduce] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Mean] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_MatMul] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer; | |||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | |||
| // todo | |||
| // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | |||
| // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. | |||
| // if quantTransNode is inserted after detection_postprocess node, there will be some errors | |||
| _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer; | |||
| } | |||
| } | |||
// Returns the process-wide singleton registry. The function-local static is
// constructed on first call; under C++11 "magic static" rules that
// initialization is thread-safe.
QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() {
  static QuantParamCalcRegister instance;
  return &instance;
}
| QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) { | |||
| auto it = _registerMap.find(opType); | |||
| if (it != _registerMap.end()) { | |||
| return it->second; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace mindspore::lite | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef CALC_QUANT_PARAM_H | |||
| #define CALC_QUANT_PARAM_H | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| static constexpr int CONVLUTION_INPUT_NUM = 3; | |||
// Base quantization-parameter calculator. The base Calc pass counts how many
// of the node's input/output tensors already carry initialized quant params
// (inputParamDone / outputParamDone); subclasses use those counts to decide
// whether and how to fill in the missing ones.
class QuantParamCalcer {
 public:
  virtual ~QuantParamCalcer() = default;
  // Updates inputParamDone/outputParamDone for `node`; returns RET_OK on success.
  virtual int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node);
 protected:
  // Presumably derives quant params from a constant tensor's raw data --
  // implementation not visible here, TODO confirm.
  STATUS ComputeConstQuantParam(const schema::TensorT &tensor, schema::QuantParamT *quantParam);
 protected:
  size_t inputParamDone = 0;   // count of input tensors with inited quant params
  size_t outputParamDone = 0;  // count of output tensors with inited quant params
};
// Calcer for ops that require every input and output quant param to already
// be determined by the base pass; Calc fails with RET_ERROR otherwise.
class CommonCalcer : public QuantParamCalcer {
 public:
  CommonCalcer() = default;
  ~CommonCalcer() override = default;
  int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
};
// Calcer for value-preserving ops (reshape, transpose, pooling, ...): an
// uninitialized quant param on one side is filled from the opposite side.
class LinearCalcer : public QuantParamCalcer {
 public:
  LinearCalcer() = default;
  ~LinearCalcer() override = default;
  int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
// Singleton registry mapping primitive (op) types to their quant-param calcer.
class QuantParamCalcRegister {
 public:
  virtual ~QuantParamCalcRegister() = default;
  // Returns the calcer registered for opType, or nullptr if none is registered.
  QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType);
  static QuantParamCalcRegister *GetInstance();
 private:
  QuantParamCalcRegister();
  // Values are owned raw pointers that live for the whole process
  // (singleton is never destroyed before exit); some calcers are shared
  // between several map entries.
  std::unordered_map<schema::PrimitiveType, QuantParamCalcer *> _registerMap;
};
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -39,126 +39,127 @@ QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThr | |||
| : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {} | |||
| bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { | |||
| size_t i = 0; | |||
| for (i = 0; i < mConvTypes.size(); i++) { | |||
| if (node->fullname_with_scope().find(mConvTypes[i]) == 0) { | |||
| break; | |||
| } | |||
| size_t i = 0; | |||
| for (i = 0; i < mConvTypes.size(); i++) { | |||
| if (node->fullname_with_scope().find(mConvTypes[i]) == 0) { | |||
| break; | |||
| } | |||
| } | |||
| if ((i == mConvTypes.size()) || (node->size() < 3)) { | |||
| return false; | |||
| } | |||
| if ((i == mConvTypes.size()) || (node->size() < 3)) { | |||
| return false; | |||
| } | |||
| auto inputNode = node->input(2); | |||
| if (!inputNode->isa<Parameter>()) { | |||
| return false; | |||
| } | |||
| auto paramNode = inputNode->cast<ParameterPtr>(); | |||
| auto abstract_base = paramNode->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| return false; | |||
| } | |||
| auto inputNode = node->input(2); | |||
| if (!inputNode->isa<Parameter>()) { | |||
| return false; | |||
| } | |||
| auto paramNode = inputNode->cast<ParameterPtr>(); | |||
| auto abstract_base = paramNode->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| return false; | |||
| } | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) { | |||
| MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); | |||
| return false; | |||
| } | |||
| auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape(); | |||
| size_t shapeSize = 1; | |||
| for (auto dim : weight_shape) { | |||
| shapeSize = shapeSize * dim; | |||
| } | |||
| if (shapeSize < mWeightSize) { | |||
| MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; | |||
| return false; | |||
| } | |||
| if (weight_shape[0] <= mConvWeightQuantChannelThreshold) { | |||
| MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0]; | |||
| return false; | |||
| } | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) { | |||
| MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); | |||
| return false; | |||
| } | |||
| auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape(); | |||
| size_t shapeSize = 1; | |||
| for (auto dim : weight_shape) { | |||
| shapeSize = shapeSize * dim; | |||
| } | |||
| if (shapeSize < mWeightSize) { | |||
| MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; | |||
| return false; | |||
| } | |||
| if (weight_shape[0] <= mConvWeightQuantChannelThreshold) { | |||
| MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0]; | |||
| return false; | |||
| } | |||
| return true; | |||
| return true; | |||
| } | |||
| bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||
| if (primitiveT_value == nullptr) { | |||
| MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| auto type = primitiveT_value->GetPrimitiveT()->value.type; | |||
| MS_LOG(INFO) << "Primitive type: " << type; | |||
| static const std::vector<schema::PrimitiveType> uint8OpList = { | |||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_Activation}; | |||
| return IsContain(uint8OpList, type); | |||
| auto type = primitiveT_value->GetPrimitiveT()->value.type; | |||
| MS_LOG(INFO) << "Primitive type: " << type; | |||
| static const std::vector<schema::PrimitiveType> uint8OpList = { | |||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, | |||
| schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_Activation}; | |||
| return IsContain(uint8OpList, type); | |||
| } | |||
| bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const { | |||
| size_t i = 0; | |||
| for (i = 0; i < mMulTypes.size(); i++) { | |||
| if (node->fullname_with_scope().find(mMulTypes[i]) == 0) { | |||
| break; | |||
| } | |||
| } | |||
| if (i == mMulTypes.size()) { | |||
| return false; | |||
| size_t i = 0; | |||
| for (i = 0; i < mMulTypes.size(); i++) { | |||
| if (node->fullname_with_scope().find(mMulTypes[i]) == 0) { | |||
| break; | |||
| } | |||
| } | |||
| if (i == mMulTypes.size()) { | |||
| return false; | |||
| } | |||
| if (node->size() < 3) { | |||
| MS_LOG(INFO) << "input size less!"; | |||
| return false; | |||
| } | |||
| if (node->size() < 3) { | |||
| MS_LOG(INFO) << "input size less!"; | |||
| return false; | |||
| } | |||
| auto inputNode1 = node->input(1); | |||
| auto inputNode2 = node->input(2); | |||
| if (inputNode1 == nullptr || inputNode2 == nullptr) { | |||
| MS_LOG(INFO) << "mul input is nullptr!"; | |||
| return false; | |||
| } | |||
| auto inputNode1 = node->input(1); | |||
| auto inputNode2 = node->input(2); | |||
| if (inputNode1 == nullptr || inputNode2 == nullptr) { | |||
| MS_LOG(INFO) << "mul input is nullptr!"; | |||
| return false; | |||
| } | |||
| ParameterPtr paramNode = nullptr; | |||
| if (inputNode1->isa<Parameter>()) { | |||
| paramNode = inputNode1->cast<ParameterPtr>(); | |||
| } else if (inputNode2->isa<Parameter>()) { | |||
| paramNode = inputNode2->cast<ParameterPtr>(); | |||
| } | |||
| ParameterPtr paramNode = nullptr; | |||
| if (inputNode1->isa<Parameter>()) { | |||
| paramNode = inputNode1->cast<ParameterPtr>(); | |||
| } else if (inputNode2->isa<Parameter>()) { | |||
| paramNode = inputNode2->cast<ParameterPtr>(); | |||
| } | |||
| if (paramNode == nullptr) { | |||
| MS_LOG(INFO) << "invalid paramNode!"; | |||
| return false; | |||
| } | |||
| if (paramNode == nullptr) { | |||
| MS_LOG(INFO) << "invalid paramNode!"; | |||
| return false; | |||
| } | |||
| auto abstract_base = paramNode->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(INFO) << "abstract is nullptr"; | |||
| return false; | |||
| } | |||
| auto abstract_base = paramNode->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(INFO) << "abstract is nullptr"; | |||
| return false; | |||
| } | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) { | |||
| MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); | |||
| return false; | |||
| } | |||
| auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape(); | |||
| size_t shapeSize = 1; | |||
| for (auto dim : weight_shape) { | |||
| shapeSize = shapeSize * dim; | |||
| } | |||
| if (shapeSize < mWeightSize) { | |||
| MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; | |||
| return false; | |||
| } | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) { | |||
| MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); | |||
| return false; | |||
| } | |||
| auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape(); | |||
| size_t shapeSize = 1; | |||
| for (auto dim : weight_shape) { | |||
| shapeSize = shapeSize * dim; | |||
| } | |||
| if (shapeSize < mWeightSize) { | |||
| MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; | |||
| return false; | |||
| } | |||
| return true; | |||
| return true; | |||
| } | |||
| void CalFakeNode(const AnfNodePtr &inTensor) { | |||
| @@ -190,56 +191,119 @@ void CalFakeNode(const AnfNodePtr &inTensor) { | |||
| // } | |||
| } | |||
| STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, | |||
| double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| if (mMin > 0.0f) { | |||
| MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; | |||
| mMin = 0.0f; | |||
| } | |||
| if (mMax < 0.0f) { | |||
| MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; | |||
| mMax = 0.0f; | |||
| } | |||
| if (mMin > mMax) { | |||
| MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (mMin == mMax) { | |||
| if (mMin != 0.0f) { | |||
| MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; | |||
| return RET_ERROR; | |||
| } | |||
| quantParam->inited = true; | |||
| quantParam->min = mMin; | |||
| quantParam->max = mMax; | |||
| quantParam->scale = 0.0f; | |||
| quantParam->zeroPoint = 0; | |||
| quantParam->narrowRange = narrowRange; | |||
| quantParam->numBits = num_bits; | |||
| return RET_OK; | |||
| STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax, bool narrowRange, | |||
| int quant_max, int quant_min, int num_bits) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| if (mMin > 0.0f) { | |||
| MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; | |||
| mMin = 0.0f; | |||
| } | |||
| if (mMax < 0.0f) { | |||
| MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; | |||
| mMax = 0.0f; | |||
| } | |||
| if (mMin > mMax) { | |||
| MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (mMin == mMax) { | |||
| if (mMin != 0.0f) { | |||
| MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; | |||
| return RET_ERROR; | |||
| } | |||
| auto quantMinFloat = static_cast<double>(quant_min); | |||
| auto quantMaxFloat = static_cast<double>(quant_max); | |||
| double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); | |||
| const double zeroPointFromMin = quantMinFloat - mMin / scale; | |||
| // const double zeroPointFromMax = quantMaxFloat - mMax / scale; | |||
| int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin)); | |||
| // The zero point should always be in the range of quantized value, | |||
| // [qmin, qmax]. | |||
| MS_ASSERT(zeroPoint >= quantMin); | |||
| MS_ASSERT(zeroPoint <= quantMax); | |||
| quantParam->inited = true; | |||
| quantParam->min = mMin; | |||
| quantParam->max = mMax; | |||
| quantParam->scale = scale; | |||
| quantParam->zeroPoint = zeroPoint; | |||
| quantParam->scale = 0.0f; | |||
| quantParam->zeroPoint = 0; | |||
| quantParam->narrowRange = narrowRange; | |||
| quantParam->numBits = num_bits; | |||
| return RET_OK; | |||
| } | |||
| auto quantMinFloat = static_cast<double>(quant_min); | |||
| auto quantMaxFloat = static_cast<double>(quant_max); | |||
| double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); | |||
| const double zeroPointFromMin = quantMinFloat - mMin / scale; | |||
| // const double zeroPointFromMax = quantMaxFloat - mMax / scale; | |||
| int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin)); | |||
| // The zero point should always be in the range of quantized value, | |||
| // [qmin, qmax]. | |||
| MS_ASSERT(zeroPoint >= quantMin); | |||
| MS_ASSERT(zeroPoint <= quantMax); | |||
| quantParam->inited = true; | |||
| quantParam->min = mMin; | |||
| quantParam->max = mMax; | |||
| quantParam->scale = scale; | |||
| quantParam->zeroPoint = zeroPoint; | |||
| quantParam->narrowRange = narrowRange; | |||
| quantParam->numBits = num_bits; | |||
| return RET_OK; | |||
| } | |||
// Computes asymmetric quantization parameters (scale / zeroPoint) for the
// range [mMin, mMax] with a numBits-wide quantized domain, writing the
// result into *quantParam. The range is first clamped so that it contains 0
// (real 0 must be exactly representable).
// Returns RET_OK, RET_PARAM_INVALID (min > max) or RET_ERROR.
STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax,
                             bool narrowRange, int numBits) {
  MS_ASSERT(quantParam != nullptr);
  if (mMin > 0.0f) {
    MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
    mMin = 0.0f;
  }
  if (mMax < 0.0f) {
    MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision";
    mMax = 0.0f;
  }
  if (mMin > mMax) {
    MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax;
    return RET_PARAM_INVALID;
  }
  // degenerate range: everything quantizes to zero
  if (mMin == mMax) {
    if (mMin != 0.0f) {
      MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other";
      return RET_ERROR;
    }
    quantParam->inited = true;
    quantParam->min = mMin;
    quantParam->max = mMax;
    quantParam->scale = 0.0f;
    quantParam->zeroPoint = 0;
    quantParam->narrowRange = narrowRange;
    quantParam->numBits = numBits;
    return RET_OK;
  }
  // quantized domain: [1, 2^numBits - 1] when narrowRange, else [0, 2^numBits - 1]
  int quantMin = narrowRange ? 1 : 0;
  int quantMax = (1 << (unsigned int)numBits) - 1;
  auto quantMinFloat = static_cast<double>(quantMin);
  auto quantMaxFloat = static_cast<double>(quantMax);
  double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
  // two candidate zero points, anchored at either end of the range
  const double zeroPointFromMin = quantMinFloat - mMin / scale;
  const double zeroPointFromMax = quantMaxFloat - mMax / scale;
  // pick the candidate with the smaller magnitude-sum error
  const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
  const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
  const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
  int zeroPoint;
  // clamp into the quantized domain before rounding
  if (zpDouble < quantMinFloat) {
    zeroPoint = quantMin;
  } else if (zpDouble > quantMaxFloat) {
    zeroPoint = quantMax;
  } else {
    zeroPoint = static_cast<int32_t>(std::round(zpDouble));
  }
  // The zero point should always be in the range of quantized value,
  // [qmin, qmax].
  MS_ASSERT(zeroPoint >= quantMin);
  MS_ASSERT(zeroPoint <= quantMax);
  quantParam->inited = true;
  quantParam->min = mMin;
  quantParam->max = mMax;
  quantParam->scale = scale;
  quantParam->zeroPoint = zeroPoint;
  quantParam->narrowRange = narrowRange;
  quantParam->numBits = numBits;
  return RET_OK;
}
| STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, | |||
| @@ -292,14 +356,14 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ | |||
| weightPtr->set_quant_param(quantParam); | |||
| } | |||
| auto ret = memcpy_s(const_cast<float*>(rawDatas), weightPtr->tensor_size(), | |||
| qDatas.data(), shapeSize * sizeof(int8_t)); | |||
| auto ret = | |||
| memcpy_s(const_cast<float *>(rawDatas), weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| if (quantType == QuantType_WeightQuant) { | |||
| PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum); | |||
| PostBitPack(const_cast<float *>(rawDatas), shapeSize, bitNum); | |||
| } | |||
| weightPtr->set_tensor_type(kNumberTypeInt8); | |||
| @@ -338,14 +402,13 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ | |||
| qDatas[i] = quant_max; | |||
| } else if (quant_data < quant_min) { | |||
| qDatas[i] = quant_min; | |||
| } else { | |||
| } else { | |||
| qDatas[i] = static_cast<int8_t>(quant_data); | |||
| } | |||
| } | |||
| weightPtr->set_quant_param(quantParam); | |||
| auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), | |||
| qDatas.data(), shapeSize * sizeof(int8_t)); | |||
| auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| return RET_ERROR; | |||
| @@ -358,34 +421,32 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ | |||
| weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); | |||
| } | |||
| return RET_OK; | |||
| return RET_OK; | |||
| } | |||
| STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { | |||
| auto *rawDatas = reinterpret_cast<uint8_t *>(weight); | |||
| vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize); | |||
| vector<uint8_t> qDatas_packed; | |||
| if (bitNum < 8 && bitNum > 1) { | |||
| BitPack weight_bitpack(bitNum); | |||
| weight_bitpack.BitPacking(qDatas, qDatas_packed); | |||
| if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (bitNum == 8) { | |||
| if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; | |||
| return RET_ERROR; | |||
| auto *rawDatas = reinterpret_cast<uint8_t *>(weight); | |||
| vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize); | |||
| vector<uint8_t> qDatas_packed; | |||
| if (bitNum < 8 && bitNum > 1) { | |||
| BitPack weight_bitpack(bitNum); | |||
| weight_bitpack.BitPacking(qDatas, qDatas_packed); | |||
| if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else if (bitNum == 8) { | |||
| if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { | |||
| MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| return RET_OK; | |||
| } | |||
| } // namespace quant | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -62,6 +62,41 @@ class QuantStrategy { | |||
| STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax, | |||
| bool narrowRange, int quant_max, int quant_min, int num_bits); | |||
| STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, | |||
| bool narrowRange = false, int numBits = UINT8_QUANTIZATION); | |||
| template <typename T> | |||
| T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| MS_ASSERT(quantParam->inited); | |||
| const auto scale = quantParam->scale; | |||
| const auto zeroPoint = quantParam->zeroPoint; | |||
| const auto numBit = quantParam->numBits; | |||
| const auto narrowRange = quantParam->narrowRange; | |||
| const double maxLimit = static_cast<float>((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; | |||
| double minLimit; | |||
| if (narrowRange) { | |||
| minLimit = static_cast<float>(1 - zeroPoint) * scale; | |||
| } else { | |||
| minLimit = static_cast<float>(0 - zeroPoint) * scale; | |||
| } | |||
| return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { | |||
| double tmp = 0.0f; | |||
| if (originData > maxLimit) { | |||
| tmp = maxLimit; | |||
| } else if (originData < minLimit) { | |||
| tmp = minLimit; | |||
| } else { | |||
| tmp = originData; | |||
| } | |||
| auto quantData = static_cast<T>(std::round(tmp / scale + zeroPoint)); | |||
| if (quantData == 0 && narrowRange) { | |||
| quantData++; | |||
| } | |||
| return quantData; | |||
| }(); | |||
| } | |||
| template <typename T> | |||
| T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) { | |||
| MS_ASSERT(quantParam != nullptr); | |||
| @@ -15,22 +15,19 @@ | |||
| */ | |||
| #include "mindspore/lite/tools/converter/quantizer/quantizer.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace quant { | |||
| Quantizer::Quantizer(FuncGraphPtr graph) : funcGraph(graph) { | |||
| if (funcGraph == nullptr) { | |||
| return; | |||
| } | |||
| } | |||
| namespace mindspore::lite::quant { | |||
| // Default hook implementations: the base Quantizer rewrites nothing and | |||
| // reports success; subclasses override the hooks they need. | |||
| STATUS Quantizer::GenerateQuantParam() { return RET_OK; } | |||
| STATUS Quantizer::RemoveFakeQuant() { return RET_OK; } | |||
| STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; } | |||
| } // namespace quant | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| // Default hook implementations for the flatbuffer-graph quantizer: no-ops | |||
| // that report success; subclasses override the hooks they need. | |||
| STATUS FbQuantizer::GenerateQuantParam() { return RET_OK; } | |||
| STATUS FbQuantizer::RemoveFakeQuant() { return RET_OK; } | |||
| STATUS FbQuantizer::DetermineNodeQuantType() { return RET_OK; } | |||
| } // namespace mindspore::lite::quant | |||
| @@ -18,48 +18,63 @@ | |||
| #define MS_QUANTIZER_H | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "include/errorcode.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/anf.h" | |||
| #include "include/model.h" | |||
| #include "base/base.h" | |||
| #include "src/param_value_lite.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| namespace quant { | |||
| namespace mindspore::lite::quant { | |||
| using STATUS = int; | |||
// Quantization strategies supported by the converter.
// NOTE(review): the numeric values presumably mirror schema::QuantType so the
// two can be converted by value — confirm against the generated schema.
// (Diff residue had duplicated every enumerator, which is an ill-formed
// redefinition; the list below appears exactly once.)
enum QuantType {
  QuantType_QUANT_NONE = 0,     // no quantization
  QuantType_AwareTraining = 1,  // quantization-aware training
  QuantType_WeightQuant = 2,    // weight-only quantization
  QuantType_PostTraining = 3,   // post-training quantization
  QuantType_MIN = QuantType_QUANT_NONE,
  QuantType_MAX = QuantType_PostTraining
};
| class Quantizer { | |||
| public: | |||
| explicit Quantizer(FuncGraphPtr graph); | |||
| explicit Quantizer(FuncGraphPtr graph) : funcGraph(std::move(graph)) {} | |||
| ~Quantizer() = default; | |||
| ~Quantizer() = default; | |||
| virtual STATUS RemoveFakeQuant(); | |||
| virtual STATUS RemoveFakeQuant(); | |||
| virtual STATUS GenerateQuantParam(); | |||
| virtual STATUS GenerateQuantParam(); | |||
| virtual STATUS DetermineNodeQuantType(); | |||
| virtual STATUS DetermineNodeQuantType(); | |||
| virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; | |||
| virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; | |||
| mindspore::lite::converter::Flags flags; | |||
| protected: | |||
| FuncGraphPtr funcGraph = nullptr; | |||
| FuncGraphPtr funcGraph = nullptr; | |||
| }; | |||
| } // namespace quant | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif | |||
| class FbQuantizer { | |||
| public: | |||
| explicit FbQuantizer(schema::MetaGraphT *graph) : graph(graph) {} | |||
| ~FbQuantizer() = default; | |||
| virtual STATUS RemoveFakeQuant(); | |||
| virtual STATUS GenerateQuantParam(); | |||
| virtual STATUS DetermineNodeQuantType(); | |||
| virtual STATUS DoQuantize() = 0; | |||
| protected: | |||
| std::shared_ptr<schema::MetaGraphT> graph = nullptr; | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||