Merge pull request !6200 from cjh9368/weight_quant (tag: v1.0.0)
| @@ -203,6 +203,7 @@ if(BUILD_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| @@ -75,7 +75,7 @@ static const std::vector<schema::PrimitiveType> int8OpList = { | |||
| schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, | |||
| schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, | |||
| schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze, | |||
| schema::PrimitiveType_MatMul}; | |||
| schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad}; | |||
| static const std::vector<schema::PrimitiveType> needInsertOpList = { | |||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||
| @@ -61,6 +61,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/conv_scale_fusion.cc | |||
| ../optimizer/fusion/conv_bn_fusion.cc | |||
| ../optimizer/fusion/constant_folding_fusion.cc | |||
| ../optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -24,6 +24,7 @@ | |||
| #include "tools/optimizer/fusion/conv_scale_fusion.h" | |||
| #include "tools/optimizer/fusion/conv_bn_fusion.h" | |||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | |||
| #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -43,6 +44,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| // for now - trainning is not supporting fuse operations | |||
| if (config != nullptr && config->trainModel == false) { | |||
| // remove quantdtype when awaretraining | |||
| if (config->quantType == QuantType_AwareTraining) { | |||
| pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>()); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>()); | |||
| pm->AddPass(std::make_shared<opt::ConvScaleFusion>()); | |||
| @@ -102,7 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| // generate and infer quant parameters | |||
| { | |||
| if (mQuantizer != nullptr) { | |||
| if (fbQuantizer != nullptr) { | |||
| Optimizer topologicalOptimizer; | |||
| topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| status = topologicalOptimizer.Run(graphDefT); | |||
| @@ -110,14 +110,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| if (!(this->graphDefT->fmkType == converter::FmkType_TF && | |||
| this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) { | |||
| status = mQuantizer->GenerateQuantParam(); | |||
| if (ctx.quantType == QuantType_AwareTraining) { | |||
| status = fbQuantizer->GenerateQuantParam(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenerateQuantParam failed"; | |||
| return status; | |||
| } | |||
| status = mQuantizer->DetermineNodeQuantType(); | |||
| status = fbQuantizer->DetermineNodeQuantType(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DetermineNodeQuant failed"; | |||
| return status; | |||
| @@ -151,7 +151,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | |||
| auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | |||
| if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | |||
| continue; | |||
| } | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | |||
| @@ -46,7 +46,7 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||
| MS_LOG(ERROR) << "output tensor is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) { | |||
| if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8) { | |||
| std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| @@ -77,8 +77,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor | |||
| } | |||
| // change quant param min to 0 to fit ms-lite ops | |||
| if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 | |||
| && tensor->dataType == TypeId::kNumberTypeInt8) { | |||
| if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { | |||
| quant_param->zeroPoint = quant_param->zeroPoint - 128; | |||
| } | |||
| @@ -115,11 +115,6 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||
| return status; | |||
| } | |||
| } | |||
| auto status = GenerateDefaultQuantParam(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenerateDefaultQuantParam failed"; | |||
| return status; | |||
| } | |||
| auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| @@ -135,7 +130,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||
| << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; | |||
| node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE); | |||
| } else { | |||
| status = quantParamCalcer->Calc(graph, *node); | |||
| auto status = quantParamCalcer->Calc(graph, *node); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); | |||
| node->quantType = schema::QuantType_QUANT_NONE; | |||
| @@ -167,17 +162,23 @@ STATUS AwareQuantizer::DoQuantize() { | |||
| return RET_ERROR; | |||
| } | |||
| // quant weight | |||
| status = QuantConvWeight(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvWeight failed!"; | |||
| return RET_ERROR; | |||
| auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1)); | |||
| if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { | |||
| status = QuantConvWeight(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvWeight failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| // quant bias | |||
| if (inputIndexes.size() == 3) { | |||
| status = QuantConvBias(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvBias failed!"; | |||
| return RET_ERROR; | |||
| auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2)); | |||
| if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { | |||
| status = QuantConvBias(graph, node.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantConvBias failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { | |||
| @@ -376,29 +377,17 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| bool canQuant = true; | |||
| for (auto &inTensorIdx : node->inputIndex) { | |||
| MS_ASSERT(graph->allTensors.size() > inTensorIdx); | |||
| auto &inTensor = graph->allTensors.at(inTensorIdx); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || | |||
| !inTensor->quantParams.front()->inited) { | |||
| for (auto &outTensorIdx : node->outputIndex) { | |||
| MS_ASSERT(graph->allTensors.size() > outTensorIdx); | |||
| auto &outTensor = graph->allTensors.at(outTensorIdx); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || | |||
| !outTensor->quantParams.front()->inited) { | |||
| canQuant = false; | |||
| break; | |||
| } | |||
| } | |||
| if (canQuant) { | |||
| for (auto &outTensorIdx : node->outputIndex) { | |||
| MS_ASSERT(graph->allTensors.size() > outTensorIdx); | |||
| auto &outTensor = graph->allTensors.at(outTensorIdx); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || | |||
| !outTensor->quantParams.front()->inited) { | |||
| canQuant = false; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||
| node->quantType = schema::QuantType_AwareTraining; | |||
| } else { | |||
| @@ -70,6 +70,9 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto &tensor = graph->allTensors.at(node.inputIndex.at(i)); | |||
| MS_ASSERT(tensor != nullptr); | |||
| auto quantParam = GetTensorQuantParam(tensor); | |||
| if (quantParam == nullptr) { | |||
| continue; | |||
| } | |||
| if (quantParam->inited) { // inited | |||
| inputParamDone++; | |||
| continue; | |||
| @@ -77,8 +80,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (tensor->refCount == schema::NodeType::NodeType_ValueNode && | |||
| !IsContain(graph->inputIndex, node.inputIndex.at(i))) { | |||
| if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) { | |||
| auto status = ComputeConstQuantParam((*tensor), quantParam.get()); | |||
| if (status != RET_OK) { | |||
| MS_LOG(WARNING) << "ComputeConstQuantParam failed: " << status; | |||
| @@ -95,13 +97,12 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto &tensor = graph->allTensors.at(i); | |||
| MS_ASSERT(tensor != nullptr); | |||
| auto quantParam = GetTensorQuantParam(tensor); | |||
| MS_ASSERT(quantParam != nullptr); | |||
| if (quantParam->inited) { // inited | |||
| if (quantParam != nullptr && quantParam->inited) { // inited | |||
| outputParamDone++; | |||
| continue; | |||
| } | |||
| if (tensor->refCount == 999) { | |||
| if (!tensor->data.empty()) { | |||
| MS_ASSERT(false); | |||
| } | |||
| } | |||
| @@ -146,10 +147,10 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto &inTensor = graph->allTensors.at(i); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| auto inQuantParam = GetTensorQuantParam(inTensor); | |||
| if (inQuantParam->inited) { | |||
| if (inQuantParam == nullptr || inQuantParam->inited) { | |||
| continue; | |||
| } | |||
| inTensor->quantParams.front() = std::move(inQuantParam); | |||
| inTensor->quantParams.front() = std::move(outputQuantParam); | |||
| } | |||
| } | |||
| if (outputParamDone != node.outputIndex.size()) { | |||
| @@ -157,7 +158,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto &inTensor = graph->allTensors.at(node.inputIndex.at(0)); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| auto inQuantParam = GetTensorQuantParam(inTensor); | |||
| if (!inQuantParam->inited) { | |||
| if (inQuantParam == nullptr || !inQuantParam->inited) { | |||
| MS_LOG(WARNING) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -166,10 +167,10 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| if (outQuantParam->inited) { | |||
| if (outQuantParam == nullptr || outQuantParam->inited) { | |||
| continue; | |||
| } | |||
| outTensor->quantParams.front() = std::move(outQuantParam); | |||
| outTensor->quantParams.front() = std::move(inQuantParam); | |||
| } | |||
| } | |||
| return RET_OK; | |||
| @@ -225,13 +226,14 @@ class CalcConcat : public QuantParamCalcer { | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| auto outQuantParam = std::make_unique<QuantParamT>(); | |||
| status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits); | |||
| if (status != RET_OK) { | |||
| MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| outTensor->quantParams.front() = std::move(outQuantParam); | |||
| outputParamDone++; | |||
| } | |||
| @@ -261,7 +263,7 @@ class CalcAdd : public QuantParamCalcer { | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| auto outQuantParam = std::make_unique<QuantParamT>(); | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); | |||
| auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0)); | |||
| @@ -271,10 +273,10 @@ class CalcAdd : public QuantParamCalcer { | |||
| MS_ASSERT(tensor1 != nullptr); | |||
| auto biasTensor = &tensor0; | |||
| auto paramTensor = &tensor1; | |||
| if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) { | |||
| if (!tensor0->data.empty() && (tensor0->dims.empty() || tensor0->dims.size() == 1)) { | |||
| biasTensor = &tensor0; | |||
| paramTensor = &tensor1; | |||
| } else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { | |||
| } else if (!tensor1->data.empty() && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { | |||
| biasTensor = &tensor1; | |||
| paramTensor = &tensor0; | |||
| } else { | |||
| @@ -310,6 +312,7 @@ class CalcAdd : public QuantParamCalcer { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| outTensor->quantParams.front() = std::move(outQuantParam); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -337,13 +340,13 @@ class CalcRealDiv : public QuantParamCalcer { | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outQuantParam = GetTensorQuantParam(outTensor); | |||
| auto outQuantParam = std::make_unique<QuantParamT>(); | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); | |||
| MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1)); | |||
| auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1)); | |||
| MS_ASSERT(tensor1 != nullptr); | |||
| if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { | |||
| if (!tensor1->data.empty() && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { | |||
| auto quantParam = GetTensorQuantParam(tensor1); | |||
| auto min = quantParam->min; | |||
| auto max = quantParam->max; | |||
| @@ -371,6 +374,7 @@ class CalcRealDiv : public QuantParamCalcer { | |||
| MS_LOG(WARNING) << "Unsupported tensor dataType: " << tensor1->dataType; | |||
| return RET_ERROR; | |||
| } | |||
| outTensor->quantParams.front() = std::move(outQuantParam); | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Can not determine realDiv outputTensor quantParam, node " << node.name; | |||
| @@ -399,21 +403,24 @@ class CalcToSet : public QuantParamCalcer { | |||
| return RET_ERROR; | |||
| } | |||
| // output | |||
| std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT()); | |||
| if (quantParam == nullptr) { | |||
| MS_LOG(WARNING) << "new QuantParamT failed"; | |||
| return RET_ERROR; | |||
| if (outputParamDone != node.outputIndex.size()) { | |||
| std::unique_ptr<QuantParamT> quantParam = std::make_unique<QuantParamT>(); | |||
| if (quantParam == nullptr) { | |||
| MS_LOG(WARNING) << "new QuantParamT failed"; | |||
| return RET_ERROR; | |||
| } | |||
| quantParam->scale = (max - min) / 256; | |||
| MS_ASSERT(quantParam->scale != 0); | |||
| quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale)); | |||
| quantParam->min = min; | |||
| quantParam->max = max; | |||
| quantParam->inited = true; | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| outTensor->quantParams.front() = std::move(quantParam); | |||
| outputParamDone++; | |||
| } | |||
| quantParam->scale = (max - min) / 256; | |||
| MS_ASSERT(quantParam->scale != 0); | |||
| quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale)); | |||
| quantParam->min = min; | |||
| quantParam->max = max; | |||
| quantParam->inited = true; | |||
| MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); | |||
| auto &outTensor = graph->allTensors.at(node.outputIndex.front()); | |||
| MS_ASSERT(outTensor != nullptr); | |||
| outTensor->quantParams.front() = std::move(quantParam); | |||
| return RET_OK; | |||
| } | |||
| @@ -357,6 +357,14 @@ bool IsPoolingNode(const BaseRef &n) { | |||
| return false; | |||
| } | |||
| bool IsQuantNode(const BaseRef &n) { | |||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | |||
| auto type = opt::GetCNodeType(n); | |||
| return type == schema::PrimitiveType_QuantDTypeCast; | |||
| } | |||
| return false; | |||
| } | |||
| bool CheckIsAllInputsParam(const AnfNodePtr &node) { | |||
| if (utils::isa<CNode>(node)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -58,6 +58,8 @@ bool IsConvNode(const BaseRef &n); | |||
| bool IsPoolingNode(const BaseRef &n); | |||
| bool IsQuantNode(const BaseRef &n); | |||
| bool CheckIsAllInputsParam(const AnfNodePtr &node); | |||
| size_t GetOutputTensorNum(const AnfNodePtr &node); | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/ops/conv2d.h" | |||
| #include "src/ops/depthwise_conv2d.h" | |||
| #include "src/ops/activation.h" | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kActivationInputsLength = 2; | |||
| } | |||
| const BaseRef QuantDtypeCastFusion::DefinePattern() const { | |||
| auto quant_var = std::make_shared<CondVar>(IsQuantNode); | |||
| auto input_var = std::make_shared<Var>(); | |||
| return VectorRef({quant_var, input_var}); | |||
| } | |||
| const AnfNodePtr QuantDtypeCastFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_LOG(DEBUG) << "quant dtype cast fusion pass process"; | |||
| CheckIfFuncGraphIsNull(func_graph); | |||
| CheckIfAnfNodeIsNull(node); | |||
| auto act_node = node->cast<CNodePtr>(); | |||
| CheckIfCNodeIsNull(act_node); | |||
| CheckInputSize(act_node, kActivationInputsLength); | |||
| AnfNodePtr pre_node = act_node->input(1); | |||
| CheckIfAnfNodeIsNull(pre_node); | |||
| return pre_node; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_QUANT_DTYPE_CAST_FUSION_H | |||
| #define LITE_QUANT_DTYPE_CAST_FUSION_H | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
// Graph-optimization pass that deletes QuantDTypeCast nodes by fusing each
// one away (its single data input replaces it). Registered by AnfTransform
// when quantType is AwareTraining, where the explicit dtype casts inserted
// during aware training are no longer needed.
class QuantDtypeCastFusion : public PatternProcessPass {
 public:
  // multigraph: whether the pass is applied across sub-graphs as well
  // (defaults mirror the other fusion passes in this directory).
  explicit QuantDtypeCastFusion(bool multigraph = true, const std::string &name = "quant_dtype_cast_fusion")
      : PatternProcessPass(name, multigraph) {}
  ~QuantDtypeCastFusion() override = default;
  // Pattern: a QuantDTypeCast CNode with one arbitrary input (see .cc).
  const BaseRef DefinePattern() const override;
  // Returns the matched node's data input, effectively removing the cast.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // LITE_QUANT_DTYPE_CAST_FUSION_H | |||