Browse Source

optimize anf transform

pull/14283/head
mengyuanli 4 years ago
parent
commit
d734528f76
2 changed files with 36 additions and 36 deletions
  1. +31
    -34
      mindspore/lite/tools/converter/anf_transform.cc
  2. +5
    -2
      mindspore/lite/tools/converter/anf_transform.h

+ 31
- 34
mindspore/lite/tools/converter/anf_transform.cc View File

@@ -298,10 +298,6 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
MS_LOG(ERROR) << "config should be specified"; MS_LOG(ERROR) << "config should be specified";
return nullptr; return nullptr;
} }
if (old_graph->has_flag("HasTransformed")) {
old_graph->set_flag("HasTransformed", false);
return old_graph;
}


auto status = RunPrecedingPass(old_graph, *config); auto status = RunPrecedingPass(old_graph, *config);
if (status != RET_OK) { if (status != RET_OK) {
@@ -356,43 +352,44 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
return new_graph; return new_graph;
} }


STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphVector *subgraphs,
std::vector<ValueNodePtr> *vnodes) {
auto nodes = TopoSort(main_graph->get_return());
void AnfTransform::GetAllFuncGraph(const FuncGraphPtr &func_graph) {
if (func_graphs_.find(func_graph) == func_graphs_.end()) {
func_graphs_.insert(func_graph);
} else {
return;
}

auto nodes = func_graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg) {
vnodes->push_back(utils::cast<ValueNodePtr>(node));
subgraphs->push_back(fg);
if (IsValueNode<FuncGraph>(node)) {
auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
GetAllFuncGraph(new_fg);
}
if (utils::isa<CNodePtr>(node)) {
auto cnode = node->cast<CNodePtr>();
for (auto &input : cnode->inputs()) {
if (input->isa<ValueNode>()) {
if (IsValueNode<FuncGraph>(input)) {
auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
GetAllFuncGraph(new_fg);
}
}
}
} }
} }
return RET_OK;
} }


FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) { FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
// transform main_graph
auto new_main_graph = TransformSingleFuncGraph(main_graph, config);
if (new_main_graph == nullptr) {
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
GetAllFuncGraph(main_graph);


// transform sub_graph
FuncGraphVector subgraphs{};
std::vector<ValueNodePtr> vnodes{};
int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
for (size_t i = 0; i < subgraphs.size(); i++) {
auto new_graph = Transform(subgraphs.at(i), config);
new_graph->set_flag("HasTransformed", true);
vnodes.at(i)->set_value(new_graph);
for (auto &fg : func_graphs_) {
auto new_main_graph = TransformSingleFuncGraph(fg, config);
if (new_main_graph == nullptr) {
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
} }

return new_main_graph;
return main_graph;
} }
} // namespace mindspore::lite } // namespace mindspore::lite

+ 5
- 2
mindspore/lite/tools/converter/anf_transform.h View File

@@ -19,6 +19,7 @@


#include <memory> #include <memory>
#include <vector> #include <vector>
#include <set>
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "tools/common/storage.h" #include "tools/common/storage.h"
@@ -38,8 +39,6 @@ class AnfTransform {
private: private:
std::unique_ptr<quant::Quantizer> m_quantizer_ = nullptr; std::unique_ptr<quant::Quantizer> m_quantizer_ = nullptr;


STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphVector *subgraphs, std::vector<ValueNodePtr> *vnodes);

FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr); FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);


static int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config); static int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
@@ -61,6 +60,10 @@ class AnfTransform {
static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);


int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);

void GetAllFuncGraph(const FuncGraphPtr &func_graph);

std::set<FuncGraphPtr> func_graphs_{};
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore


Loading…
Cancel
Save