|
|
|
@@ -59,62 +59,8 @@ AnfTransform::AnfTransform() = default; |
|
|
|
|
|
|
|
// Destructor is explicitly defaulted: no custom teardown is written here, so the
// compiler-generated member-wise destruction is used.
AnfTransform::~AnfTransform() = default; |
|
|
|
|
|
|
|
FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { |
|
|
|
MS_ASSERT(nullptr != old_graph); |
|
|
|
if (config == nullptr) { |
|
|
|
MS_LOG(ERROR) << "config should be specified"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (old_graph->has_flag("HasTransformed")) { |
|
|
|
old_graph->set_flag("HasTransformed", false); |
|
|
|
return old_graph; |
|
|
|
} |
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) { |
|
|
|
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); |
|
|
|
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); |
|
|
|
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); |
|
|
|
|
|
|
|
if (config->fmk == converter::FmkType_MS) { |
|
|
|
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>(); |
|
|
|
mindir_adjust_pass->SetFmkType(config->fmk); |
|
|
|
mindir_adjust_pass->SetQuantType(config->quantType); |
|
|
|
if (!mindir_adjust_pass->Run(old_graph)) { |
|
|
|
MS_LOG(ERROR) << "mindir adjust failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto mindir_inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>(); |
|
|
|
if (!mindir_inputs_adjust_pass->Run(old_graph)) { |
|
|
|
MS_LOG(ERROR) << "mindir inputs adjust failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// onnx pre adjustment |
|
|
|
if (config->fmk == converter::FmkType_ONNX) { |
|
|
|
auto onnx_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>(); |
|
|
|
if (!onnx_adjust_pass->Run(old_graph)) { |
|
|
|
MS_LOG(ERROR) << "onnx adjust failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (config->fmk == lite::converter::FmkType_TF) { |
|
|
|
auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>(); |
|
|
|
if (!functionalize_control_op_pass->Run(old_graph)) { |
|
|
|
MS_LOG(ERROR) << "functionalize control op pass failed."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF || |
|
|
|
config->fmk == lite::converter::FmkType_ONNX) { |
|
|
|
graph_pm->AddPass(std::make_shared<opt::WhilePass>()); |
|
|
|
graph_pm->AddPass(std::make_shared<opt::IfPass>()); |
|
|
|
} |
|
|
|
|
|
|
|
// for now, training does not support fusion operations |
|
|
|
if (!config->trainModel) { |
|
|
|
@@ -137,26 +83,11 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap |
|
|
|
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>()); |
|
|
|
fusion_pm->AddPass(std::make_shared<opt::BiDirectionTfGruCellFusion>()); |
|
|
|
} |
|
|
|
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>(); |
|
|
|
weight_format_hardcode_pass->SetFmkType(config->fmk); |
|
|
|
weight_format_hardcode_pass->SetQuantType(config->quantType); |
|
|
|
graph_pm->AddPass(weight_format_hardcode_pass); |
|
|
|
auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>(); |
|
|
|
weight_format_transform_pass->SetFmkType(config->fmk); |
|
|
|
weight_format_transform_pass->SetQuantType(config->quantType); |
|
|
|
graph_pm->AddPass(weight_format_transform_pass); |
|
|
|
auto infershape_pass = std::make_shared<opt::InferShapePass>(); |
|
|
|
infershape_pass->SetFmkType(config->fmk); |
|
|
|
graph_pm->AddPass(infershape_pass); |
|
|
|
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>(); |
|
|
|
slice_prepose_pass->SetFmkType(config->fmk); |
|
|
|
graph_pm->AddPass(slice_prepose_pass); |
|
|
|
|
|
|
|
if (config->fmk == lite::converter::FmkType_MS) { |
|
|
|
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>(); |
|
|
|
if (remove_unused_cast_pass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified"; |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
remove_unused_cast_pass->SetFmkType(config->fmk); |
|
|
|
fusion_pm->AddPass(remove_unused_cast_pass); |
|
|
|
@@ -165,11 +96,55 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap |
|
|
|
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>(); |
|
|
|
if (remove_unused_transpose_pass == nullptr) { |
|
|
|
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified"; |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
remove_unused_transpose_pass->SetFmkType(config->fmk); |
|
|
|
fusion_pm->AddPass(remove_unused_transpose_pass); |
|
|
|
} |
|
|
|
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); |
|
|
|
optimizer->AddPassManager(fusion_pm); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// Assembles the "anf graph pass manager" and hands it to `optimizer`.
//
// Passes are queued in this order:
//   1. WhilePass and IfPass, only when the source framework is TFLITE, TF or ONNX;
//   2. WeightFormatHardCodePass      (configured with framework and quant type);
//   3. WeightFormatTransformPass     (configured with framework and quant type);
//   4. InferShapePass                (configured with framework);
//   5. SlicePreposePass              (configured with framework).
//
// `config` is dereferenced without a null check; the caller is expected to have validated it.
// Always returns RET_OK.
int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
  auto pass_manager = std::make_shared<opt::PassManager>("anf graph pass manager", true);

  // Control-flow passes are only queued for the frameworks listed below.
  switch (config->fmk) {
    case lite::converter::FmkType_TFLITE:
    case lite::converter::FmkType_TF:
    case lite::converter::FmkType_ONNX:
      pass_manager->AddPass(std::make_shared<opt::WhilePass>());
      pass_manager->AddPass(std::make_shared<opt::IfPass>());
      break;
    default:
      break;
  }

  auto hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
  hardcode_pass->SetFmkType(config->fmk);
  hardcode_pass->SetQuantType(config->quantType);
  pass_manager->AddPass(hardcode_pass);

  auto transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
  transform_pass->SetFmkType(config->fmk);
  transform_pass->SetQuantType(config->quantType);
  pass_manager->AddPass(transform_pass);

  auto shape_pass = std::make_shared<opt::InferShapePass>();
  shape_pass->SetFmkType(config->fmk);
  pass_manager->AddPass(shape_pass);

  auto prepose_pass = std::make_shared<opt::SlicePreposePass>();
  prepose_pass->SetFmkType(config->fmk);
  pass_manager->AddPass(prepose_pass);

  optimizer->AddPassManager(pass_manager);
  return RET_OK;
}
|
|
|
|
|
|
|
// Assembles the "anf graph convert pass manager" and hands it to `optimizer`.
//
// ClipConvertActivationPass is always queued. For TFLITE models,
// GroupDepthwiseOpConvertPass and TfliteInputsOrderExchangePass are queued after it.
//
// `config` is dereferenced without a null check; the caller is expected to have validated it.
// Always returns RET_OK.
int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
                                 const converter::Flags *config) {
  const bool from_tflite = (config->fmk == lite::converter::FmkType_TFLITE);

  auto pass_manager = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
  pass_manager->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
  if (from_tflite) {
    pass_manager->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
    pass_manager->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
  }

  optimizer->AddPassManager(pass_manager);
  return RET_OK;
}
|
|
|
|
|
|
|
int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, |
|
|
|
const converter::Flags *config) { |
|
|
|
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); |
|
|
|
if (!config->trainModel) { |
|
|
|
auto inne_context_ptr = std::make_shared<lite::InnerContext>(); |
|
|
|
@@ -179,47 +154,90 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap |
|
|
|
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>(); |
|
|
|
update_conv2d_param_pass->SetFmkType(config->fmk); |
|
|
|
const_fold_pm->AddPass(update_conv2d_param_pass); |
|
|
|
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>()); |
|
|
|
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>()); |
|
|
|
if (config->fmk == lite::converter::FmkType_TFLITE) { |
|
|
|
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>()); |
|
|
|
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>()); |
|
|
|
} |
|
|
|
optimizer->AddPassManager(const_fold_pm); |
|
|
|
optimizer->AddPassManager(convert_pm); |
|
|
|
optimizer->AddPassManager(fusion_pm); |
|
|
|
optimizer->AddPassManager(graph_pm); |
|
|
|
auto new_graph = optimizer->Optimize(old_graph); |
|
|
|
if (new_graph == nullptr) { |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); |
|
|
|
return nullptr; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// Dispatches to the framework-specific adjustment routine selected by `config->fmk`:
//   FmkType_MS   -> RunMindirAdjustPass
//   FmkType_ONNX -> RunOnnxAdjustPass
//   FmkType_TF   -> RunTFAdjustPass
// Any other framework needs no adjustment here and yields RET_OK.
//
// Returns whatever the selected routine returns.
int AnfTransform::RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  const auto source_fmk = config->fmk;
  if (source_fmk == converter::FmkType_MS) {
    return RunMindirAdjustPass(old_graph, config);
  }
  if (source_fmk == converter::FmkType_ONNX) {
    return RunOnnxAdjustPass(old_graph, config);
  }
  if (source_fmk == converter::FmkType_TF) {
    return RunTFAdjustPass(old_graph, config);
  }
  return RET_OK;
}
|
|
|
|
|
|
|
// Runs the two MindIR adjustment passes directly on `old_graph`, in order:
//   1. MindirAdjustPass (configured with the framework and quant type from `config`);
//   2. MindirInputAdjustOpPass.
//
// The second pass is only created and run when the first one succeeds. When either pass
// reports failure, an error is logged, the global return code is set to RET_ERROR and
// RET_ERROR is returned. Returns RET_OK when both passes succeed.
int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto adjust_pass = std::make_shared<opt::MindirAdjustPass>();
  adjust_pass->SetFmkType(config->fmk);
  adjust_pass->SetQuantType(config->quantType);
  if (!adjust_pass->Run(old_graph)) {
    MS_LOG(ERROR) << "mindir adjust failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    return RET_ERROR;
  }

  auto inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>();
  if (!inputs_adjust_pass->Run(old_graph)) {
    MS_LOG(ERROR) << "mindir inputs adjust failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    return RET_ERROR;
  }

  return RET_OK;
}
|
|
|
|
|
|
|
// ONNX pre-adjustment: runs OnnxInputAdjustOpPass directly on `old_graph`.
//
// `config` is not read by this routine. Returns RET_OK on success; on failure logs an error,
// sets the global return code to RET_ERROR and returns RET_ERROR.
int AnfTransform::RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto input_adjust_pass = std::make_shared<opt::OnnxInputAdjustOpPass>();
  if (input_adjust_pass->Run(old_graph)) {
    return RET_OK;
  }

  MS_LOG(ERROR) << "onnx adjust failed.";
  ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
  return RET_ERROR;
}
|
|
|
|
|
|
|
// TensorFlow pre-adjustment: runs FunctionalizeControlOpPass directly on `old_graph`.
//
// `config` is not read by this routine. Returns RET_OK on success; on failure logs an error,
// sets the global return code to RET_ERROR and returns RET_ERROR.
int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
  if (control_op_pass->Run(old_graph)) {
    return RET_OK;
  }

  MS_LOG(ERROR) << "functionalize control op pass failed.";
  ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
  return RET_ERROR;
}
|
|
|
|
|
|
|
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, |
|
|
|
const FuncGraphPtr &new_graph) { |
|
|
|
// quant |
|
|
|
if (config->quantType == schema::QuantType_PostTraining) { |
|
|
|
if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) { |
|
|
|
MS_LOG(ERROR) << "bitNum must be valid pos num."; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->mQuantizer = |
|
|
|
std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum)); |
|
|
|
if (mQuantizer == nullptr) { |
|
|
|
MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} else if (config->quantType == schema::QuantType_WeightQuant) { |
|
|
|
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "weight quant input param error"; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize, |
|
|
|
config->quantWeightChannel, config->bitNum); |
|
|
|
if (mQuantizer == nullptr) { |
|
|
|
MS_LOG(ERROR) << "New WeightQuantizer failed"; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
if (mQuantizer != nullptr) { |
|
|
|
@@ -228,9 +246,65 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Quant failed " << status; |
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); |
|
|
|
return nullptr; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// Transforms a single function graph through the converter pipeline.
//
// Steps, in order:
//   1. Validate `old_graph` and `config`.
//   2. If the graph carries the "HasTransformed" flag, clear it and return the graph unchanged.
//   3. Run the framework-specific adjustment passes on `old_graph` (RunAdjustPass).
//   4. Register the const-fold, convert, fusion and graph pass managers on a fresh optimizer.
//   5. Optimize the graph.
//   6. Quantize the optimized graph (DoQuantize).
//
// @param old_graph  graph to transform; must not be null.
// @param config     converter flags; must not be null.
// @return the optimized graph, `old_graph` itself when it was already flagged as transformed,
//         or nullptr when an argument is null or any step fails.
FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  MS_ASSERT(nullptr != old_graph);
  // The assert alone does not protect the dereferences below if MS_ASSERT is compiled out
  // (NOTE(review): presumably a no-op outside debug builds — confirm against the macro definition),
  // so reject a null graph explicitly.
  if (old_graph == nullptr) {
    MS_LOG(ERROR) << "old_graph should be specified";
    return nullptr;
  }
  if (config == nullptr) {
    MS_LOG(ERROR) << "config should be specified";
    return nullptr;
  }

  // A graph flagged as already transformed is passed through once; the flag is reset on the way out.
  if (old_graph->has_flag("HasTransformed")) {
    old_graph->set_flag("HasTransformed", false);
    return old_graph;
  }

  auto status = RunAdjustPass(old_graph, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run Adjust pass failed.";
    return nullptr;
  }

  auto optimizer = std::make_shared<opt::GraphOptimizer>();

  // Pass managers are registered in this order: const fold, convert, fusion, graph.
  status = AddConstFoldPass(optimizer, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Add const fold pass failed.";
    return nullptr;
  }

  status = AddConvertPass(optimizer, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Add convert pass failed.";
    return nullptr;
  }

  status = AddFusionPass(optimizer, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Add fusion pass failed.";
    return nullptr;
  }

  status = AddGraphPass(optimizer, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Add graph pass failed.";
    return nullptr;
  }

  auto new_graph = optimizer->Optimize(old_graph);
  if (new_graph == nullptr) {
    // Previously this path failed silently; log it like every other failure in this function.
    MS_LOG(ERROR) << "Optimize graph failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
    return nullptr;
  }

  status = DoQuantize(old_graph, config, new_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do Quantize failed.";
    return nullptr;
  }

  return new_graph;
}
|
|
|
|
|
|
|
|