|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "common/common_test.h"
- #include "common/py_func_graph_fetcher.h"
- #include "ir/dtype.h"
- #include "ir/manager.h"
- #include "ir/func_graph_cloner.h"
- #include "pipeline/jit/parse/parse.h"
- #include "frontend/operator/ops.h"
- #include "utils/log_adapter.h"
- #include "include/common/debug/draw.h"
- #include "utils/label.h"
-
- namespace mindspore {
-
- namespace {
/// Splits `str` on every occurrence of `pattern`.
///
/// Empty fields are kept: "a,,b," split on "," yields {"a", "", "b", ""}, and an empty input
/// yields a single empty field. Delimiters are matched left to right without overlap, so
/// self-overlapping patterns split as usual ("aaa" on "aa" yields {"", "a"}). An empty pattern
/// cannot delimit anything, so the whole input is returned as one field instead of looping
/// forever.
///
/// @param str      Text to split.
/// @param pattern  Delimiter, matched literally.
/// @return The fields between delimiters, in order.
std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
  std::vector<std::string> result;
  if (pattern.empty()) {
    result.push_back(str);
    return result;
  }

  std::string::size_type start = 0;
  for (;;) {
    const std::string::size_type pos = str.find(pattern, start);
    if (pos == std::string::npos) {
      // Everything after the last delimiter (possibly nothing) is the final field.
      result.push_back(str.substr(start));
      return result;
    }
    result.push_back(str.substr(start, pos - start));
    start = pos + pattern.size();
  }
}
- } // namespace
- using std::dynamic_pointer_cast;
-
- using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
- using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
-
- class NestingSpecs;
-
- class Stage {
- public:
- explicit Stage(std::vector<std::string> specs) {
- for (auto arg : specs) {
- auto spec = SplitString(arg, "=");
- if (spec.size() <= 1) {
- continue;
- }
- std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
- specs_[ToFullString(spec[0])] = nesting;
- }
- }
-
- ~Stage() {}
-
- std::map<std::string, std::string> &subs() { return subs_; }
-
- void set_subs(const std::map<std::string, std::string> &subs) { subs_ = subs; }
-
- private:
- std::string ToFullString(std::string s) {
- if (s.find("fv") != std::string::npos) {
- s = s.replace(s.find("fv"), 2, "free_variable");
- }
-
- if (s.find("deps") != std::string::npos) {
- s = s.replace(s.find("deps"), 4, "dependencies");
- }
-
- return s;
- }
-
- std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
- std::map<std::string, std::string> subs_;
- };
-
- class NestingSpecs {
- public:
- NestingSpecs(Stage *stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
-
- ~NestingSpecs() {}
-
- std::string Name(Any node) {
- std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
- if (stage_->subs().find(name) != stage_->subs().end()) {
- return stage_->subs()[name];
- }
-
- return name;
- }
-
- void Check(std::shared_ptr<DepComputer> results) {
- if (expected_.empty() && expected_recursive_.empty()) {
- return;
- }
-
- auto parent = dynamic_pointer_cast<ParentComputer>(results);
- if (parent != nullptr) {
- CheckParent(parent);
- return;
- }
-
- auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
- if (recursive != nullptr) {
- CheckRecursive(recursive);
- return;
- }
- }
-
- private:
- void ParseSpecs(std::string specs) {
- if (specs.empty()) {
- return;
- }
-
- std::vector<std::string> str_list = SplitString(specs, ";");
- for (auto spec : str_list) {
- spec.erase(0, spec.find_first_not_of(" "));
- spec.erase(spec.find_last_not_of(" ") + 1);
- if (spec.empty()) {
- continue;
- }
- if (spec.find("->") != std::string::npos) {
- auto substr = SplitString(spec, "->");
- ASSERT_GT(substr.size(), 1);
- auto key = substr[0];
- auto value = substr[1];
- if (!value.empty()) {
- expected_[key] = {value};
- }
- } else if (spec.find(":") != std::string::npos) {
- auto substr = SplitString(spec, ":");
- ASSERT_GT(substr.size(), 1);
- auto key = substr[0];
- auto values = SplitString(substr[1], ",");
- std::set<std::string> values_set(values.begin(), values.end());
- if (!values_set.empty()) {
- expected_[key] = values_set;
- }
- } else {
- expected_recursive_[spec] = true;
- }
- }
- }
-
- void CheckParent(std::shared_ptr<ParentComputer> results) {
- std::map<std::string, std::set<std::string>> clean_results;
- for (auto &iter : results->parent_analysis()) {
- auto key = iter.first;
- auto value = iter.second;
- if (key == nullptr) {
- continue;
- }
- std::string k = Name(key);
-
- std::set<std::string> v;
- if (value != nullptr && !Name(value).empty()) {
- v.insert(Name(value));
- }
-
- if (!v.empty()) {
- clean_results[k] = v;
- }
- }
-
- ASSERT_EQ(clean_results, expected_);
- }
-
- void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
- std::map<std::string, bool> clean_results;
- for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
- auto key = iter->first;
- auto value = iter->second;
- if (key == nullptr) {
- continue;
- }
- std::string k = Name(key);
-
- clean_results[k] = value;
- }
-
- ASSERT_EQ(clean_results, expected_recursive_);
- }
-
- private:
- Stage *stage_;
- std::map<std::string, std::set<std::string>> expected_;
- std::map<std::string, bool> expected_recursive_;
- };
-
- bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
- for (auto node : manager->all_nodes()) {
- if (node->isa<CNode>()) {
- auto &inputs = node->cast<CNodePtr>()->inputs();
- for (size_t i = 0; i < inputs.size(); ++i) {
- auto inp = inputs[i];
- if (!manager->all_nodes().contains(inp)) {
- return false;
- }
-
- if (manager->node_users().find(inp) != manager->node_users().end()) {
- auto users = manager->node_users()[inp];
- if (!users.contains(make_pair(node, i))) {
- return false;
- }
- }
- }
- }
-
- if (manager->node_users().find(node) != manager->node_users().end()) {
- auto users = manager->node_users()[node];
- for (auto iter = users.begin(); iter != users.end(); ++iter) {
- auto node2 = iter->first;
- auto key = iter->second;
- if (!manager->all_nodes().contains(node2)) {
- return false;
- }
- if (node2->cast<CNodePtr>()->input(key) != node) {
- return false;
- }
- }
- }
- }
-
- return true;
- }
-
- class TestManager : public UT::Common {
- public:
- TestManager() : getPyFun("gtest_input.ir.manager_test") {}
-
- void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
-
- public:
- std::vector<PrimitivePtr> swaps;
- UT::PyFuncGraphFetcher getPyFun;
- };
-
- FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
- FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
- ParameterPtr x = func_graph->add_parameter();
- ParameterPtr y = func_graph->add_parameter();
- std::vector<AnfNodePtr> inputs;
- inputs.push_back(NewValueNode(prim));
- inputs.push_back(x);
- inputs.push_back(y);
- CNodePtr cnode_add = func_graph->NewCNode(inputs);
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- inputs.push_back(cnode_add);
- CNodePtr cnode_return = func_graph->NewCNode(inputs);
- func_graph->set_return(cnode_return);
- return func_graph;
- }
-
- std::vector<FuncGraphPtr> MakeNestedGraph() {
- /*
- *def f(x):
- * def g():
- * return x
- * return g
- */
- FuncGraphPtr f = std::make_shared<FuncGraph>();
- FuncGraphPtr fg = std::make_shared<FuncGraph>();
-
- ParameterPtr x = f->add_parameter();
-
- std::vector<AnfNodePtr> inputs;
- inputs.push_back(NewValueNode(fg));
- inputs.push_back(NewValueNode(prim::kPrimReturn));
-
- CNodePtr cnode_f = f->NewCNode(inputs);
- f->set_return(cnode_f);
-
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- inputs.push_back(x);
- CNodePtr cnode_g = fg->NewCNode(inputs);
- fg->set_return(cnode_g);
-
- std::vector<FuncGraphPtr> result = {f, fg};
- return result;
- }
-
- std::vector<FuncGraphPtr> MakeNestedGraph2() {
- /* build a closure func_graph */
- /*
- *def foo(x, y):
- * def bar(x1):
- * return x1 + y
- * return bar(x)
- */
- FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
- ParameterPtr x = graph_foo->add_parameter();
- ParameterPtr y = graph_foo->add_parameter();
-
- std::vector<AnfNodePtr> inputs;
-
- // build func_graph bar
- FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
- ParameterPtr x1 = graph_bar->add_parameter();
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
- inputs.push_back(x1);
- inputs.push_back(y);
- CNodePtr cnode_add = graph_bar->NewCNode(inputs);
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- inputs.push_back(cnode_add);
- CNodePtr cnode_return = graph_bar->NewCNode(inputs);
- graph_bar->set_return(cnode_return);
-
- // build func_graph foo
- inputs.clear();
- inputs.push_back(NewValueNode(graph_bar));
- inputs.push_back(x);
- CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
-
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- inputs.push_back(cnode_graph_bar);
- cnode_return = graph_foo->NewCNode(inputs);
- graph_foo->set_return(cnode_return);
-
- std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
- return result;
- }
-
- // Add TestManager::CheckManager function to checkout the result
- void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
- auto size = mng->func_graphs().size();
-
- ASSERT_EQ(size, mng->free_variables_total().size());
- }
-
- TEST_F(TestManager, test_scalar_add_manual) {
- auto prim_scalar_add = prim::kPrimScalarAdd;
- FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
- auto mng = Manage(func_graph);
- }
-
- TEST_F(TestManager, test_scalar_replace) {
- auto prim_scalar_add = prim::kPrimScalarAdd;
-
- FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
- ParameterPtr x = func_graph->add_parameter();
- ParameterPtr y = func_graph->add_parameter();
- std::vector<AnfNodePtr> inputs;
- inputs.push_back(NewValueNode(prim_scalar_add));
- inputs.push_back(x);
- inputs.push_back(y);
- CNodePtr cnode_add = func_graph->NewCNode(inputs);
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- inputs.push_back(cnode_add);
- CNodePtr cnode_return = func_graph->NewCNode(inputs);
- func_graph->set_return(cnode_return);
-
- auto mng = Manage(func_graph);
- std::cout << "start " << x->ToString() << std::endl;
- mng->Replace(cnode_add, x);
- }
-
- TEST_F(TestManager, test_nested_manual) {
- auto graphs = MakeNestedGraph();
- auto f = graphs[0];
- auto g = graphs[1];
-
- auto mng = Manage(f);
-
- ASSERT_EQ(6, mng->all_nodes().size());
- ASSERT_EQ(2, mng->func_graphs().size());
- ASSERT_EQ(4, mng->node_users().size());
- ASSERT_EQ(1, mng->roots().size());
- CheckAnalysisSize(mng);
-
- ASSERT_EQ(2, f->nodes().size());
- ASSERT_EQ(1, g->nodes().size());
-
- auto &users = mng->node_users();
- for (auto &iter : users) {
- ASSERT_EQ(1, iter.second.size());
- }
-
- ASSERT_EQ(1, f->func_graphs_used().size());
- ASSERT_EQ(0, g->func_graphs_used().size());
-
- ASSERT_EQ(0, f->free_variables().size());
- ASSERT_EQ(1, g->free_variables().size());
-
- auto fv_total = mng->free_variables_total();
- ASSERT_EQ(0, fv_total[f].size());
- ASSERT_EQ(1, fv_total[g].size());
-
- ASSERT_EQ(0, f->func_graph_cnodes_index().size());
- ASSERT_EQ(1, g->func_graph_cnodes_index().size());
- }
-
- TEST_F(TestManager, test_deep_nested2_manual) {
- // create parser
- FuncGraphPtr func_graph = getPyFun("test_custom");
- return;
-
- // parse ast to func graph
- FuncGraphPtr gfn = BasicClone(func_graph);
- if (gfn == nullptr) {
- return;
- }
-
- auto mng = Manage(gfn);
-
- ASSERT_EQ(3, mng->func_graphs().size());
- ASSERT_EQ(1, mng->roots().size());
- ASSERT_EQ(4, gfn->nodes().size());
- ASSERT_EQ(20, mng->all_nodes().size());
- ASSERT_EQ(25, mng->node_users().size());
- CheckAnalysisSize(mng);
- }
-
- TEST_F(TestManager, test_deep_nested_manual) {
- FuncGraphPtr f = std::make_shared<FuncGraph>();
- FuncGraphPtr fg = std::make_shared<FuncGraph>();
- FuncGraphPtr h = std::make_shared<FuncGraph>();
-
- ParameterPtr x = f->add_parameter();
- ParameterPtr y = f->add_parameter();
- ParameterPtr z = f->add_parameter();
-
- std::vector<AnfNodePtr> inputs;
- inputs.push_back(NewValueNode(fg));
- inputs.push_back(x);
- inputs.push_back(y);
- CNodePtr cnode_1 = f->NewCNode(inputs);
-
- inputs.clear();
- inputs.push_back(cnode_1);
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- CNodePtr cnode_0 = f->NewCNode(inputs);
- f->set_return(cnode_0);
-
- ParameterPtr x1 = fg->add_parameter();
- ParameterPtr y1 = fg->add_parameter();
- inputs.clear();
- inputs.push_back(NewValueNode(h));
- inputs.push_back(x1);
- CNodePtr cnode_3 = fg->NewCNode(inputs);
-
- inputs.clear();
- inputs.push_back(cnode_3);
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- CNodePtr cnode_2 = fg->NewCNode(inputs);
- fg->set_return(cnode_2);
-
- ParameterPtr x2 = h->add_parameter();
-
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
- inputs.push_back(x2);
- inputs.push_back(y1);
- CNodePtr cnode_6 = h->NewCNode(inputs);
-
- inputs.clear();
- inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
- inputs.push_back(z);
- inputs.push_back(cnode_6);
- CNodePtr cnode_5 = h->NewCNode(inputs);
-
- inputs.clear();
- inputs.push_back(cnode_5);
- inputs.push_back(NewValueNode(prim::kPrimReturn));
- CNodePtr cnode_4 = h->NewCNode(inputs);
- h->set_return(cnode_4);
-
- auto mng = Manage(f);
-
- ASSERT_EQ(3, mng->func_graphs().size());
- ASSERT_EQ(1, mng->roots().size());
- ASSERT_EQ(20, mng->all_nodes().size());
- CheckAnalysisSize(mng);
- }
-
- TEST_F(TestManager, test_parent1_manual) {
- FuncGraphPtr fg = std::make_shared<FuncGraph>();
-
- Parameter param(fg);
- std::vector<AnfNodePtr> params;
- CNodePtr app = std::make_shared<CNode>(params, fg);
- fg->set_return(app);
- fg->set_parameters(params);
-
- std::shared_ptr<FuncGraphManager> manager = MakeManager();
- manager->AddFuncGraph(fg, true);
- FuncGraphPtr p = fg->parent();
- assert(p == nullptr);
- }
-
- TEST_F(TestManager, test_parent_manual) {
- auto prim_scalar_add = prim::kPrimScalarAdd;
- FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
-
- std::shared_ptr<FuncGraphManager> manager = MakeManager();
- manager->AddFuncGraph(fg);
- FuncGraphPtr p = fg->parent();
- assert(p == nullptr);
- }
-
- TEST_F(TestManager, test_flat) {
- std::vector<std::shared_ptr<Stage>> stages;
- std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
- std::map<std::string, int> size_list;
- size_list["nodes"] = 2;
- }
-
- TEST_F(TestManager, test_nested) {
- std::vector<std::shared_ptr<Stage>> stages;
- std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
- std::map<std::string, int> size_list;
- return;
- }
-
- TEST_F(TestManager, test_calls) {
- std::vector<std::shared_ptr<Stage>> stages;
- std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
- "fvs_direct=h:a", "fvs_total=h:a; g:h"};
- std::map<std::string, int> size_list;
- return;
- }
-
- TEST_F(TestManager, test_unused_param) {
- std::vector<std::shared_ptr<Stage>> stages;
- std::vector<std::string> specs = {"nodes=X:x,y"};
- std::map<std::string, int> size_list;
- }
-
- TEST_F(TestManager, test_cannot_replace_return) {
- FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
- ASSERT_NE(fg, nullptr);
-
- auto mng = Manage(fg);
- ASSERT_EQ(fg->manager(), mng);
-
- ASSERT_NE(mng, nullptr);
- ASSERT_GT(fg->parameters().size(), 0);
- ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
- }
-
- TEST_F(TestManager, test_weak_manager) {
- FuncGraphPtr fg = getPyFun("ir_get_fn");
-
- auto mng1 = MakeManager({fg}, false);
- ASSERT_EQ(fg->manager(), nullptr);
- auto mng2 = MakeManager({fg}, true);
- ASSERT_EQ(fg->manager(), mng2);
- auto mng3 = MakeManager({fg}, false);
- ASSERT_EQ(fg->manager(), mng2);
- }
-
- TEST_F(TestManager, test_drop_root) {
- FuncGraphPtr fg = getPyFun("ir_get_fn");
-
- auto mng = Manage(fg);
- const auto &fgs = mng->func_graphs();
- ASSERT_TRUE(fgs.contains(fg));
- FuncGraphSet s;
- s.add(fg);
- mng->MaybeDropFuncGraphs(s);
- ASSERT_TRUE(fgs.contains(fg));
- }
-
- TEST_F(TestManager, test_keep_roots) {
- FuncGraphPtr fg1 = getPyFun("ir_get_fn");
- FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
-
- auto mng = Manage(fg1);
- ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
- ASSERT_TRUE(mng->func_graphs().contains(fg1));
-
- mng->AddFuncGraph(fg2);
- ASSERT_EQ(mng->func_graphs().size(), 2);
- ASSERT_TRUE(mng->func_graphs().contains(fg2));
-
- mng->KeepRoots();
- ASSERT_EQ(mng->func_graphs().size(), 1);
- ASSERT_TRUE(mng->func_graphs().contains(fg1));
-
- mng->KeepRoots({fg2});
- ASSERT_EQ(mng->func_graphs().size(), 1);
- ASSERT_TRUE(mng->func_graphs().contains(fg2));
- }
-
- TEST_F(TestManager, test_keep_roots_recursion) {
- return;
-
- FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
- ASSERT_NE(fg, nullptr);
- auto mng = Manage(fg);
- parse::ResolveAll(mng);
-
- ASSERT_NE(mng, nullptr);
- ASSERT_EQ(mng->func_graphs().size(), 4);
-
- ASSERT_GT(fg->parameters().size(), 0);
- mng->Replace(fg->output(), fg->parameters()[0]);
- ASSERT_EQ(mng->func_graphs().size(), 3);
-
- mng->KeepRoots();
- ASSERT_EQ(mng->func_graphs().size(), 1);
- }
-
- TEST_F(TestManager, test_add_edge_replace) {
- // fg(x, y, u):
- // x1 = load(x, u)
- // a = add(x1, y)
- // u1 = update_state(u, x1);
- // out = depend(a, u1)
- // return out
- FuncGraphPtr fg = std::make_shared<FuncGraph>();
- auto x = fg->add_parameter();
- auto y = fg->add_parameter();
- auto u = fg->add_parameter();
- auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
- auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
- auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
- auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
- fg->set_output(out);
-
- // Create manager.
- auto mgr = Manage(fg);
- ASSERT_NE(mgr, nullptr);
-
- // Before AddEdge.
- // a = add(x1, y)
- // u1 = update_state(u, x1);
- // out = depend(a, u1)
- auto a_users = mgr->node_users()[a];
- ASSERT_EQ(a_users.size(), 1);
-
- mgr->AddEdge(u1, a);
-
- // After AddEdge.
- // a = add(x1, y)
- // u1 = update_state(u, x1, a);
- // out = depend(a, u1)
- a_users = mgr->node_users()[a];
- ASSERT_EQ(a_users.size(), 2);
-
- // Remove edge by replace update_state.
- auto u2 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
- mgr->Replace(u1, u2);
-
- // After replace update_state.
- // a = add(x1, y)
- // u2 = update_state(u, x1);
- // out = depend(a, u2)
- a_users = mgr->node_users()[a];
- ASSERT_EQ(a_users.size(), 1);
-
- mgr->AddEdge(u2, a);
-
- // After AddEdge to u2.
- // a = add(x1, y)
- // u2 = update_state(u, x1, a);
- // out = depend(a, u2)
- a_users = mgr->node_users()[a];
- ASSERT_EQ(a_users.size(), 2);
- }
-
- TEST_F(TestManager, test_add_edge_replace_new) {
- // fg(x, y, u):
- // x1 = load(x, u)
- // a = add(x1, y)
- // u1 = update_state(u, x1);
- // out = depend(a, u1)
- // return out
- FuncGraphPtr fg = std::make_shared<FuncGraph>();
- auto x = fg->add_parameter();
- auto y = fg->add_parameter();
- auto u = fg->add_parameter();
- auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
- auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
- auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
- auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
- fg->set_output(out);
-
- // Create manager.
- auto mgr = Manage(fg);
- ASSERT_NE(mgr, nullptr);
-
- auto new_add = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
- mgr->AddEdge(u1, new_add);
-
- // x1 = load(x, u)
- // a = add(x1, y)
- // new_add = add(x1, y)
- // u1 = update_state(u, x1, new_add);
- // out = depend(a, u1)
- // return out
- ASSERT_EQ(mgr->node_users()[x1].size(), 3);
- ASSERT_EQ(mgr->node_users()[y].size(), 2);
- ASSERT_EQ(mgr->node_users()[new_add].size(), 1);
-
- auto new_add1 = fg->NewCNode({NewValueNode(prim::kPrimAdd), y, y});
- mgr->Replace(new_add, new_add1);
-
- // x1 = load(x, u)
- // a = add(x1, y)
- // new_add1 = add(y, y)
- // u1 = update_state(u, x1, new_add1);
- // out = depend(a, u1)
- // return out
- ASSERT_EQ(mgr->node_users()[x1].size(), 2);
- ASSERT_EQ(mgr->node_users()[y].size(), 3);
- ASSERT_EQ(mgr->node_users()[new_add].size(), 0);
- ASSERT_EQ(mgr->node_users()[new_add1].size(), 1);
- }
-
- TEST_F(TestManager, test_set_edge) {
- // fg(x, y, u):
- // t = make_tuple(x, y)
- // d = depend(t, u);
- // get_item = tuple_get_item(d, 0)
- // return get_item
- FuncGraphPtr fg = std::make_shared<FuncGraph>();
- auto x = fg->add_parameter();
- auto y = fg->add_parameter();
- auto u = fg->add_parameter();
- auto t = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), x, y});
- auto d = fg->NewCNode({NewValueNode(prim::kPrimDepend), t, u});
- auto get_item = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), d, NewValueNode(0)});
- fg->set_output(get_item);
-
- // Create manager.
- auto mgr = Manage(fg);
- ASSERT_NE(mgr, nullptr);
-
- // Before SetEdge.
- ASSERT_EQ(mgr->node_users()[t].size(), 1);
- ASSERT_EQ(mgr->node_users()[d].size(), 1);
-
- auto depend = get_item->input(1)->cast<CNodePtr>();
- mgr->SetEdge(get_item, 1, depend->input(1));
-
- // After SetEdge.
- ASSERT_EQ(get_item->input(1), t);
- ASSERT_EQ(depend->input(1), t);
- ASSERT_EQ(mgr->node_users()[d].size(), 0);
- ASSERT_EQ(mgr->node_users()[t].size(), 1); // depend removed.
- ASSERT_EQ(mgr->node_users()[t].front().first, get_item);
- }
-
- } // namespace mindspore
|