/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/anf_transform.h"
#include <memory>
#include <string>
#include "src/common/log_adapter.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"

using std::string;

namespace mindspore {
namespace lite {
AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;

// Runs fusion, graph-rewrite and conversion passes over old_graph, then applies
// the quantizer selected by config. Returns the rewritten graph, or nullptr on
// failure (the failure code is recorded through ReturnCode).
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  MS_ASSERT(nullptr != old_graph);
  // config is dereferenced unconditionally below, so reject a null config up front.
  if (config == nullptr) {
    MS_LOG(ERROR) << "config is nullptr";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
    return nullptr;
  }

  // Build the optimizer and its pass managers: fusion and constant folding (pm),
  // graph-level weight-format rewrites (graph_pm), and node conversion (convert_pm).
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
  auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);

  // For now, fusion is not supported for training models.
  if (!config->trainModel) {
    // ONNX models may carry redundant Identity nodes; remove them before fusion.
    if (config->fmk == lite::converter::FmkType_ONNX) {
      auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>();
      remove_identity_pass->SetFmkType(config->fmk);
      pm->AddPass(remove_identity_pass);
    }
    pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
    pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
    pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
    pm->AddPass(std::make_shared<opt::LayerNormFusion>());
    pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
    pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
                                                            schema::ActivationType_RELU));
    pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
                                                            schema::ActivationType_RELU6));
    pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
      true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
    pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
      true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
  }
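
  // Weight-format rewrites are collected in graph_pm. The registration order
  // with the optimizer below is convert_pm, then pm, then graph_pm, so these
  // passes are intended to run after the fusions above.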
  auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
  weight_format_hardcode_pass->SetFmkType(config->fmk);
  weight_format_hardcode_pass->SetQuantType(config->quantType);
  graph_pm->AddPass(weight_format_hardcode_pass);
  auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
  weight_format_transform_pass->SetFmkType(config->fmk);
  weight_format_transform_pass->SetQuantType(config->quantType);
  graph_pm->AddPass(weight_format_transform_pass);
  if (config->fmk == lite::converter::FmkType_MS) {
    auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
    remove_unused_cast_pass->SetFmkType(config->fmk);
    pm->AddPass(remove_unused_cast_pass);
  }
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());

  optimizer->AddPassManager(convert_pm);
  optimizer->AddPassManager(pm);
  optimizer->AddPassManager(graph_pm);
  auto new_graph = optimizer->Optimize(old_graph);
  if (new_graph == nullptr) {
    MS_LOG(ERROR) << "graph optimization failed";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
    return nullptr;
  }

  // Quantization: build the quantizer that matches the requested quant type.
  if (config->quantType == schema::QuantType_PostTraining) {
    this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
    if (mQuantizer == nullptr) {
      MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return nullptr;
    }
  } else if (config->quantType == schema::QuantType_WeightQuant) {
    if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
      MS_LOG(ERROR) << "weight quant input param error";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
      return nullptr;
    }
    this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize,
                                                                config->quantWeightChannel, config->bitNum);
    if (mQuantizer == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return nullptr;
    }
  }
  if (mQuantizer != nullptr) {
    mQuantizer->flags = *config;
    auto status = mQuantizer->DoQuantize(new_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Quant failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return nullptr;
    }
  }
  return new_graph;
}
}  // namespace lite
}  // namespace mindspore