From 1cd9445087375b3f321c35aab7ae9f37cb254a17 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Wed, 9 Sep 2020 16:20:06 +0800 Subject: [PATCH] Aware training support patial quant --- mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/tools/common/node_util.cc | 2 +- mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 5 ++ .../tools/converter/graphdef_transform.cc | 9 ++- .../graph/dtype_trans_pass.cc | 2 +- .../parser/tflite/tflite_dequantize_parser.cc | 2 +- .../parser/tflite/tflite_model_parser.cc | 3 +- .../converter/quantizer/aware_quantizer.cc | 53 ++++++--------- .../converter/quantizer/calc_quant_param.cc | 67 ++++++++++--------- .../lite/tools/optimizer/common/gllo_utils.cc | 8 +++ .../lite/tools/optimizer/common/gllo_utils.h | 2 + .../fusion/quant_dtype_cast_fusion.cc | 47 +++++++++++++ .../fusion/quant_dtype_cast_fusion.h | 35 ++++++++++ 14 files changed, 165 insertions(+), 72 deletions(-) create mode 100644 mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc create mode 100644 mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.h diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 890b000b8f..78c38f7261 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -203,6 +203,7 @@ if(BUILD_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc ) endif() ### train diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index f750b6eda4..b1d8d970ce 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -75,7 +75,7 @@ static const std::vector int8OpList = { schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze, - schema::PrimitiveType_MatMul}; + schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad}; static const std::vector needInsertOpList = { schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index e4d45e421c..1de0a3896a 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -61,6 +61,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/conv_scale_fusion.cc ../optimizer/fusion/conv_bn_fusion.cc ../optimizer/fusion/constant_folding_fusion.cc + ../optimizer/fusion/quant_dtype_cast_fusion.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index cc3298e8af..3f47e71115 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -24,6 +24,7 @@ #include "tools/optimizer/fusion/conv_scale_fusion.h" #include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" +#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -43,6 +44,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver // for now - trainning is not supporting fuse operations if (config != nullptr && config->trainModel == false) { + // remove quantdtype when awaretraining + if (config->quantType == QuantType_AwareTraining) { + pm->AddPass(std::make_shared()); + } pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 280a46df03..82188fff45 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -102,7 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // generate and infer quant parameters { - if (mQuantizer != nullptr) { + if (fbQuantizer != nullptr) { Optimizer topologicalOptimizer; topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); status = topologicalOptimizer.Run(graphDefT); @@ -110,14 +110,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; return status; } - if (!(this->graphDefT->fmkType == converter::FmkType_TF && - this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) { - status = mQuantizer->GenerateQuantParam(); + if (ctx.quantType == QuantType_AwareTraining) { + status = fbQuantizer->GenerateQuantParam(); if (status != RET_OK) { MS_LOG(ERROR) << "GenerateQuantParam failed"; return status; } - status = mQuantizer->DetermineNodeQuantType(); + status = fbQuantizer->DetermineNodeQuantType(); if (status != RET_OK) { MS_LOG(ERROR) << "DetermineNodeQuant failed"; return status; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index d5b740b46e..e52cc50457 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -151,7 +151,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); auto &graphInIdxes = graph->inputIndex; - if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { + if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { continue; } iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index 2865c086eb..d6f340d1d4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -46,7 +46,7 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr &t MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; } - if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) { + if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8) { std::unique_ptr attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 84324c3650..564fa4f138 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -77,8 +77,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptrtype) == TypeId::kNumberTypeUInt8 - && tensor->dataType == TypeId::kNumberTypeInt8) { + if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { quant_param->zeroPoint = quant_param->zeroPoint - 128; } diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 9846087828..8e64886cb8 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -115,11 +115,6 @@ STATUS AwareQuantizer::GenerateQuantParam() { return status; } } - auto status = GenerateDefaultQuantParam(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "GenerateDefaultQuantParam failed"; - return status; - } auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { @@ -135,7 +130,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; node->quantType = static_cast(QuantType_QUANT_NONE); } else { - status = quantParamCalcer->Calc(graph, *node); + auto status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; @@ -167,17 +162,23 @@ STATUS AwareQuantizer::DoQuantize() { return RET_ERROR; } // quant weight - status = QuantConvWeight(graph, node.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantConvWeight failed!"; - return RET_ERROR; + auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1)); + if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) { + status = QuantConvWeight(graph, node.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantConvWeight failed!"; + return RET_ERROR; + } } // quant bias if (inputIndexes.size() == 3) { - status = QuantConvBias(graph, node.get()); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantConvBias failed!"; - return RET_ERROR; + auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2)); + if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) { + status = QuantConvBias(graph, node.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantConvBias failed!"; + return RET_ERROR; + } } } } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { @@ -376,29 +377,17 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { for (auto &node : graph->nodes) { MS_ASSERT(node != nullptr); bool canQuant = true; - for (auto &inTensorIdx : node->inputIndex) { - MS_ASSERT(graph->allTensors.size() > inTensorIdx); - auto &inTensor = graph->allTensors.at(inTensorIdx); - MS_ASSERT(inTensor != nullptr); - if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || - !inTensor->quantParams.front()->inited) { + for (auto &outTensorIdx : node->outputIndex) { + MS_ASSERT(graph->allTensors.size() > outTensorIdx); + auto &outTensor = graph->allTensors.at(outTensorIdx); + MS_ASSERT(outTensor != nullptr); + if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || + !outTensor->quantParams.front()->inited) { canQuant = false; break; } } - if (canQuant) { - for (auto &outTensorIdx : node->outputIndex) { - MS_ASSERT(graph->allTensors.size() > outTensorIdx); - auto &outTensor = graph->allTensors.at(outTensorIdx); - MS_ASSERT(outTensor != nullptr); - if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || - !outTensor->quantParams.front()->inited) { - canQuant = false; - break; - } - } - } if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { node->quantType = schema::QuantType_AwareTraining; } else { diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index da3a9c84ef..d04a7fb2b7 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -70,6 +70,9 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto &tensor = graph->allTensors.at(node.inputIndex.at(i)); MS_ASSERT(tensor != nullptr); auto quantParam = GetTensorQuantParam(tensor); + if (quantParam == nullptr) { + continue; + } if (quantParam->inited) { // inited inputParamDone++; continue; @@ -77,8 +80,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); MS_ASSERT(tensor != nullptr); - if (tensor->refCount == schema::NodeType::NodeType_ValueNode && - !IsContain(graph->inputIndex, node.inputIndex.at(i))) { + if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) { auto status = ComputeConstQuantParam((*tensor), quantParam.get()); if (status != RET_OK) { MS_LOG(WARNING) << "ComputeConstQuantParam failed: " << status; @@ -95,13 +97,12 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto &tensor = graph->allTensors.at(i); MS_ASSERT(tensor != nullptr); auto quantParam = GetTensorQuantParam(tensor); - MS_ASSERT(quantParam != nullptr); - if (quantParam->inited) { // inited + if (quantParam != nullptr && quantParam->inited) { // inited outputParamDone++; continue; } - if (tensor->refCount == 999) { + if (!tensor->data.empty()) { MS_ASSERT(false); } } @@ -146,10 +147,10 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto &inTensor = graph->allTensors.at(i); MS_ASSERT(inTensor != nullptr); auto inQuantParam = GetTensorQuantParam(inTensor); - if (inQuantParam->inited) { + if (inQuantParam == nullptr || inQuantParam->inited) { continue; } - inTensor->quantParams.front() = std::move(inQuantParam); + inTensor->quantParams.front() = std::move(outputQuantParam); } } if (outputParamDone != node.outputIndex.size()) { @@ -157,7 +158,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto &inTensor = graph->allTensors.at(node.inputIndex.at(0)); MS_ASSERT(inTensor != nullptr); auto inQuantParam = GetTensorQuantParam(inTensor); - if (!inQuantParam->inited) { + if (inQuantParam == nullptr || !inQuantParam->inited) { MS_LOG(WARNING) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name; return RET_ERROR; } @@ -166,10 +167,10 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); MS_ASSERT(outTensor != nullptr); auto outQuantParam = GetTensorQuantParam(outTensor); - if (outQuantParam->inited) { + if (outQuantParam == nullptr || outQuantParam->inited) { continue; } - outTensor->quantParams.front() = std::move(outQuantParam); + outTensor->quantParams.front() = std::move(inQuantParam); } } return RET_OK; @@ -225,13 +226,14 @@ class CalcConcat : public QuantParamCalcer { MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); auto &outTensor = graph->allTensors.at(node.outputIndex.front()); MS_ASSERT(outTensor != nullptr); - auto outQuantParam = GetTensorQuantParam(outTensor); + auto outQuantParam = std::make_unique(); status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits); if (status != RET_OK) { MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } + outTensor->quantParams.front() = std::move(outQuantParam); outputParamDone++; } @@ -261,7 +263,7 @@ class CalcAdd : public QuantParamCalcer { MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); auto &outTensor = graph->allTensors.at(node.outputIndex.front()); MS_ASSERT(outTensor != nullptr); - auto outQuantParam = GetTensorQuantParam(outTensor); + auto outQuantParam = std::make_unique(); MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0)); @@ -271,10 +273,10 @@ class CalcAdd : public QuantParamCalcer { MS_ASSERT(tensor1 != nullptr); auto biasTensor = &tensor0; auto paramTensor = &tensor1; - if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) { + if (!tensor0->data.empty() && (tensor0->dims.empty() || tensor0->dims.size() == 1)) { biasTensor = &tensor0; paramTensor = &tensor1; - } else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { + } else if (!tensor1->data.empty() && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { biasTensor = &tensor1; paramTensor = &tensor0; } else { @@ -310,6 +312,7 @@ class CalcAdd : public QuantParamCalcer { return RET_ERROR; } } + outTensor->quantParams.front() = std::move(outQuantParam); } return RET_OK; } @@ -337,13 +340,13 @@ class CalcRealDiv : public QuantParamCalcer { MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); auto &outTensor = graph->allTensors.at(node.outputIndex.front()); MS_ASSERT(outTensor != nullptr); - auto outQuantParam = GetTensorQuantParam(outTensor); + auto outQuantParam = std::make_unique(); MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1)); auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1)); MS_ASSERT(tensor1 != nullptr); - if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { + if (!tensor1->data.empty() && (tensor1->dims.empty() || tensor1->dims.size() == 1)) { auto quantParam = GetTensorQuantParam(tensor1); auto min = quantParam->min; auto max = quantParam->max; @@ -371,6 +374,7 @@ class CalcRealDiv : public QuantParamCalcer { MS_LOG(WARNING) << "Unsupported tensor dataType: " << tensor1->dataType; return RET_ERROR; } + outTensor->quantParams.front() = std::move(outQuantParam); } } else { MS_LOG(WARNING) << "Can not determine realDiv outputTensor quantParam, node " << node.name; @@ -399,21 +403,24 @@ class CalcToSet : public QuantParamCalcer { return RET_ERROR; } // output - std::unique_ptr quantParam(new (std::nothrow) QuantParamT()); - if (quantParam == nullptr) { - MS_LOG(WARNING) << "new QuantParamT failed"; - return RET_ERROR; + if (outputParamDone != node.outputIndex.size()) { + std::unique_ptr quantParam = std::make_unique(); + if (quantParam == nullptr) { + MS_LOG(WARNING) << "new QuantParamT failed"; + return RET_ERROR; + } + quantParam->scale = (max - min) / 256; + MS_ASSERT(quantParam->scale != 0); + quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale)); + quantParam->min = min; + quantParam->max = max; + quantParam->inited = true; + MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); + auto &outTensor = graph->allTensors.at(node.outputIndex.front()); + MS_ASSERT(outTensor != nullptr); + outTensor->quantParams.front() = std::move(quantParam); + outputParamDone++; } - quantParam->scale = (max - min) / 256; - MS_ASSERT(quantParam->scale != 0); - quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale)); - quantParam->min = min; - quantParam->max = max; - quantParam->inited = true; - MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); - auto &outTensor = graph->allTensors.at(node.outputIndex.front()); - MS_ASSERT(outTensor != nullptr); - outTensor->quantParams.front() = std::move(quantParam); return RET_OK; } diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 706fb9ebb2..750152fdd8 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -357,6 +357,14 @@ bool IsPoolingNode(const BaseRef &n) { return false; } +bool IsQuantNode(const BaseRef &n) { + if (utils::isa(n) || utils::isa(n)) { + auto type = opt::GetCNodeType(n); + return type == schema::PrimitiveType_QuantDTypeCast; + } + return false; +} + bool CheckIsAllInputsParam(const AnfNodePtr &node) { if (utils::isa(node)) { auto cnode = node->cast(); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 4554857ccd..9733986f01 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -58,6 +58,8 @@ bool IsConvNode(const BaseRef &n); bool IsPoolingNode(const BaseRef &n); +bool IsQuantNode(const BaseRef &n); + bool CheckIsAllInputsParam(const AnfNodePtr &node); size_t GetOutputTensorNum(const AnfNodePtr &node); diff --git a/mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc b/mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc new file mode 100644 index 0000000000..c304cc8faf --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" +#include +#include "src/ops/primitive_c.h" +#include "src/ops/conv2d.h" +#include "src/ops/depthwise_conv2d.h" +#include "src/ops/activation.h" +#include "schema/inner/model_generated.h" +#include "tools/optimizer/common/gllo_utils.h" +namespace mindspore::opt { +namespace { +constexpr size_t kActivationInputsLength = 2; +} +const BaseRef QuantDtypeCastFusion::DefinePattern() const { + auto quant_var = std::make_shared(IsQuantNode); + auto input_var = std::make_shared(); + return VectorRef({quant_var, input_var}); +} + +const AnfNodePtr QuantDtypeCastFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_LOG(DEBUG) << "quant dtype cast fusion pass process"; + CheckIfFuncGraphIsNull(func_graph); + + CheckIfAnfNodeIsNull(node); + auto act_node = node->cast(); + CheckIfCNodeIsNull(act_node); + CheckInputSize(act_node, kActivationInputsLength); + AnfNodePtr pre_node = act_node->input(1); + CheckIfAnfNodeIsNull(pre_node); + return pre_node; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.h b/mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.h new file mode 100644 index 0000000000..28a6294839 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/quant_dtype_cast_fusion.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LITE_QUANT_DTYPE_CAST_FUSION_H +#define LITE_QUANT_DTYPE_CAST_FUSION_H + +#include +#include "backend/optimizer/common/optimizer.h" +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace opt { +class QuantDtypeCastFusion : public PatternProcessPass { + public: + explicit QuantDtypeCastFusion(bool multigraph = true, const std::string &name = "quant_dtype_cast_fusion") + : PatternProcessPass(name, multigraph) {} + ~QuantDtypeCastFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // LITE_QUANT_DTYPE_CAST_FUSION_H