
anf_transform.cc

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/anf_transform.h"
#include <memory>
#include <string>
#include "src/common/log_adapter.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/fusion/layer_norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/graph/identity_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
using std::string;

namespace mindspore {
namespace lite {
AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;
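
// Transform applies graph-level optimization passes to the imported ANF graph
// and then, depending on the converter flags, post-training or weight quantization.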
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  MS_ASSERT(nullptr != old_graph);
  // fusion and constant folding
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
  auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
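  // The second PassManager argument appears to be run_only_once: the fusion
  // manager (false) re-runs its passes until the graph stops changing, while
  // graph_pm and convert_pm (true) run their passes a single time.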
  // for now, fusions are not applied when converting a training model
  if (config != nullptr && !config->trainModel) {
    // drop redundant Identity nodes that ONNX graphs may carry
    if (config->fmk == lite::converter::FmkType_ONNX) {
      auto remove_identity_pass = std::make_shared<opt::RemoveIdentityOpPass>();
      remove_identity_pass->SetFmkType(config->fmk);
      pm->AddPass(remove_identity_pass);
    }
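    // fold BiasAdd, BatchNorm, and Scale into the preceding convolution, fuse
    // LayerNorm and BatchMatMul patterns, then merge ReLU/ReLU6 activations
    // (plain or tuple-output) into the convolution they follow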
    pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
    pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
    pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
    pm->AddPass(std::make_shared<opt::LayerNormFusion>());
    pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
    pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
                                                            schema::ActivationType_RELU));
    pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
                                                            schema::ActivationType_RELU6));
    pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
      true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
    pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
      true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
  }
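
  // weight layout normalization; note that config is dereferenced
  // unconditionally from here on, so a non-null config is assumed past this point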
  auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
  weight_format_hardcode_pass->SetFmkType(config->fmk);
  weight_format_hardcode_pass->SetQuantType(config->quantType);
  graph_pm->AddPass(weight_format_hardcode_pass);
  auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
  weight_format_transform_pass->SetFmkType(config->fmk);
  weight_format_transform_pass->SetQuantType(config->quantType);
  graph_pm->AddPass(weight_format_transform_pass);
  if (config->fmk == lite::converter::FmkType_MS) {
    auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
    remove_unused_cast_pass->SetFmkType(config->fmk);
    pm->AddPass(remove_unused_cast_pass);
  }
  pm->AddPass(std::make_shared<opt::ConstFoldPass>());
  convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
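
  // pass managers execute in registration order: conversion first, then fusion,
  // then graph-level passes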
  optimizer->AddPassManager(convert_pm);
  optimizer->AddPassManager(pm);
  optimizer->AddPassManager(graph_pm);
  auto new_graph = optimizer->Optimize(old_graph);
  if (new_graph == nullptr) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
    return nullptr;
  }

  // quantization: pick a quantizer according to the configured quant type
  if (config->quantType == schema::QuantType_PostTraining) {
    this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
    if (mQuantizer == nullptr) {
      MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return nullptr;
    }
  } else if (config->quantType == schema::QuantType_WeightQuant) {
    if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
      MS_LOG(ERROR) << "weight quant input param error";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
      return nullptr;
    }
    this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize,
                                                                config->quantWeightChannel, config->bitNum);
    if (mQuantizer == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return nullptr;
    }
  }
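  // the selected quantizer receives a copy of all converter flags before running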
  if (mQuantizer != nullptr) {
    mQuantizer->flags = *config;
    auto status = mQuantizer->DoQuantize(new_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Quant failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return nullptr;
    }
  }
  return new_graph;
}
}  // namespace lite
}  // namespace mindspore