Browse Source

!1359 Optimize the IR modules.

Merge pull request !1359 from ZhangQinghua/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
848d19207f
12 changed files with 438 additions and 587 deletions
  1. +1
    -0
      mindspore/ccsrc/ir/base.h
  2. +198
    -33
      mindspore/ccsrc/ir/func_graph.cc
  3. +61
    -7
      mindspore/ccsrc/ir/func_graph.h
  4. +6
    -6
      mindspore/ccsrc/ir/func_graph_cloner.cc
  5. +115
    -275
      mindspore/ccsrc/ir/manager.cc
  6. +33
    -164
      mindspore/ccsrc/ir/manager.h
  7. +1
    -1
      mindspore/ccsrc/optimizer/ad/dfunctor.cc
  8. +2
    -2
      mindspore/ccsrc/optimizer/irpass/branch_culling.cc
  9. +1
    -1
      mindspore/ccsrc/pipeline/action.cc
  10. +5
    -5
      mindspore/ccsrc/vm/transform.cc
  11. +12
    -89
      tests/ut/cpp/ir/manager_test.cc
  12. +3
    -4
      tests/ut/cpp/optimizer/cconv_test.cc

+ 1
- 0
mindspore/ccsrc/ir/base.h View File

@@ -29,6 +29,7 @@
#include "utils/visible.h" #include "utils/visible.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ordered_set.h" #include "utils/ordered_set.h"
#include "utils/ordered_map.h"


namespace mindspore { namespace mindspore {
template <typename T> template <typename T>


+ 198
- 33
mindspore/ccsrc/ir/func_graph.cc View File

@@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
: flags_(), : flags_(),
transforms_(), transforms_(),
parameter_default_value_(), parameter_default_value_(),
seen_(0),
parameters_(), parameters_(),
has_vararg_(false), has_vararg_(false),
has_kwarg_(false), has_kwarg_(false),
@@ -195,25 +196,93 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
return this->debug_info_; return this->debug_info_;
} }


const AnfNodeSet &FuncGraph::nodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &nodes = mng->nodes();
return nodes[shared_from_base<FuncGraph>()];
const AnfNodeSet &FuncGraph::nodes() { return nodes_; }

void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }

void FuncGraph::ClearNodes() { nodes_.clear(); }

void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); }

void FuncGraph::DropNode(AnfNodePtr node) {
nodes_.erase(node);
auto graph = node->func_graph();
// Remove the node from order list.
if (graph) {
graph->EraseUnusedNodeInOrder(node);
}
} }


const AnfNodeCounterMap &FuncGraph::value_nodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &cts = mng->valuenodes();
return cts[shared_from_base<FuncGraph>()];
const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }

void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
auto &others = source->value_nodes();
for (auto it = others.begin(); it != others.end(); it++) {
AddValueNode(it->first, it->second);
}
} }


const AnfNodeCounterMap &FuncGraph::free_variables_direct() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &fv_direct = mng->free_variables_direct();
return fv_direct[shared_from_base<FuncGraph>()];
void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }

void FuncGraph::AddValueNode(AnfNodePtr node, int count) {
if (value_nodes_.count(node) == 0) {
value_nodes_[node] = count;
} else {
value_nodes_[node] += count;
}
}

void FuncGraph::DropValueNode(AnfNodePtr node) {
if (value_nodes_.count(node) != 0) {
if (value_nodes_[node] == 1) {
(void)value_nodes_.erase(node);
} else {
value_nodes_[node]--;
if (value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
}

const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }

void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
auto &others = source->free_variables();
for (auto it = others.begin(); it != others.end(); it++) {
if (it->first->func_graph().get() != this) {
(void)AddFreeVariable(it->first, it->second);
}
}
}

void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }

bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) {
if (free_variables_.count(node) == 0) {
free_variables_[node] = count;
return true;
} else {
free_variables_[node] += count;
return false;
}
}

bool FuncGraph::DropFreeVariable(AnfNodePtr node) {
if (free_variables_.count(node) != 0) {
if (free_variables_[node] == 1) {
(void)free_variables_.erase(node);
return true;
} else {
free_variables_[node]--;
if (free_variables_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of free variable '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
return false;
} }


const BaseRefCounterMap &FuncGraph::free_variables_total() { const BaseRefCounterMap &FuncGraph::free_variables_total() {
@@ -249,11 +318,42 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs; return func_graphs;
} }


const FuncGraphCounterMap &FuncGraph::func_graphs_used() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &used = mng->func_graphs_used();
return used[shared_from_base<FuncGraph>()];
const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }

void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
auto &others = source->func_graphs_used();
for (auto it = others.begin(); it != others.end(); it++) {
(void)AddFuncGraphUsed(it->first, it->second);
}
func_graphs_used_.erase(source);
}

void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }

bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
if (func_graphs_used_.count(fg) == 0) {
func_graphs_used_[fg] = count;
return true;
} else {
func_graphs_used_[fg] += count;
return false;
}
}

bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
if (func_graphs_used_.count(fg) != 0) {
if (func_graphs_used_[fg] == 1) {
(void)func_graphs_used_.erase(fg);
return true;
} else {
func_graphs_used_[fg]--;
if (func_graphs_used_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
return false;
} }


const FuncGraphSet &FuncGraph::func_graphs_used_total() { const FuncGraphSet &FuncGraph::func_graphs_used_total() {
@@ -263,15 +363,75 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
return used; return used;
} }


const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() {
auto mng = manager_.lock();
if (mng == nullptr) {
MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
<< " NodeInfo: " << trace::GetDebugInfo(debug_info());
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }

void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
auto &others = source->func_graph_cnodes_index();
for (auto it = others.begin(); it != others.end(); it++) {
// Ignore the user graph who may own itself.
auto fg = it->first->first->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (fg.get() != this) {
AddFuncGraphCNodeIndex(it->first, it->second);
}
}
}

void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }

void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) {
if (func_graph_cnodes_index_.count(pair) == 0) {
func_graph_cnodes_index_[pair] = count;
} else {
func_graph_cnodes_index_[pair] += count;
}
}

void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
if (func_graph_cnodes_index_.count(pair) != 0) {
if (func_graph_cnodes_index_[pair] == 1) {
(void)func_graph_cnodes_index_.erase(pair);
} else {
func_graph_cnodes_index_[pair]--;
if (func_graph_cnodes_index_[pair] < 0) {
MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
}

const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }

void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
auto &others = source->j_func_graphs();
for (auto it = others.begin(); it != others.end(); it++) {
AddJFuncGraph(it->first, it->second);
}
}

void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }

void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
if (j_func_graphs_.count(fg) == 0) {
j_func_graphs_[fg] = count;
} else {
j_func_graphs_[fg] += count;
}
}

void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
if (j_func_graphs_.count(fg) != 0) {
if (j_func_graphs_[fg] == 1) {
(void)j_func_graphs_.erase(fg);
} else {
j_func_graphs_[fg]--;
if (j_func_graphs_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
} }
MS_EXCEPTION_IF_NULL(mng);
auto &cnode = mng->func_graph_cnodes_index();
return cnode[shared_from_base<FuncGraph>()];
} }


FuncGraphPtr FuncGraph::parent() { FuncGraphPtr FuncGraph::parent() {
@@ -662,10 +822,10 @@ void FuncGraph::EraseUnusedNodeInOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
auto mng = manager_.lock(); auto mng = manager_.lock();
if (mng) { if (mng) {
auto nodes = mng->nodes()[shared_from_base<FuncGraph>()];
auto &all_nodes = nodes();
// Erase unused cnode. // Erase unused cnode.
for (auto it = order_.begin(); it != order_.end();) { for (auto it = order_.begin(); it != order_.end();) {
if (nodes.count(*it)) {
if (all_nodes.count(*it)) {
(void)it++; (void)it++;
} else { } else {
MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order.";
@@ -702,11 +862,11 @@ void FuncGraph::CheckOrder() {
} }
auto mng = manager_.lock(); auto mng = manager_.lock();
if (mng != nullptr) { if (mng != nullptr) {
const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()];
if (nodes.size() != (order_.size() + parameters_.size())) {
const auto &all_nodes = nodes();
if (all_nodes.size() != (order_.size() + parameters_.size())) {
DumpCNodeList(); DumpCNodeList();
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
<< nodes.size() - parameters_.size() << ".";
<< all_nodes.size() - parameters_.size() << ".";
} }
} }
MS_LOG(DEBUG) << "Check order okay."; MS_LOG(DEBUG) << "Check order okay.";
@@ -840,6 +1000,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
} }
} }


size_t NewFgSeenGeneration() {
static size_t fg_seen_generation = 0;
return ++fg_seen_generation;
}

const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph"); const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
const char kFuncGraphFlagUndetermined[] = "Undeterminate"; const char kFuncGraphFlagUndetermined[] = "Undeterminate";
} // namespace mindspore } // namespace mindspore

+ 61
- 7
mindspore/ccsrc/ir/func_graph.h View File

@@ -26,6 +26,7 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <functional>


#include "ir/anf.h" #include "ir/anf.h"
#include "ir/manager.h" #include "ir/manager.h"
@@ -36,8 +37,13 @@
namespace mindspore { namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>;

template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;

using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;


const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
@@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {


// get all nodes belonging to this func graph // get all nodes belonging to this func graph
const AnfNodeSet &nodes(); const AnfNodeSet &nodes();
void CopyNodes(const FuncGraphPtr &source);
void ClearNodes();
void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node);


// get all value_nodes belonging to this func graph // get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes(); const AnfNodeCounterMap &value_nodes();

// get all vars directly pointed to in this func graph
const AnfNodeCounterMap &free_variables_direct();
void CopyValueNodes(const FuncGraphPtr &source);
void ClearValueNodes();
void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node);

// get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const FuncGraphPtr &source);
void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node);


// get all vars required by this func graph // get all vars required by this func graph
const BaseRefCounterMap &free_variables_total(); const BaseRefCounterMap &free_variables_total();
@@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
// get all vars that are func graphs // get all vars that are func graphs
std::vector<FuncGraphPtr> free_variables_func_graphs(); std::vector<FuncGraphPtr> free_variables_func_graphs();


// get all func graphs directly used by this func graph
// get all value nodes of func graph directly used by this func graph
const FuncGraphCounterMap &func_graphs_used(); const FuncGraphCounterMap &func_graphs_used();
void CopyFuncGraphsUsed(const FuncGraphPtr &source);
void ClearFuncGraphsUsed();
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphUsed(FuncGraphPtr fg);

// get all value nodes of J func graph directly used by this func graph
const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphs();
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraph(FuncGraphPtr fg);


// get all func graphs nested used by this func graph // get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total(); const FuncGraphSet &func_graphs_used_total();


// get all user value nodes of this func graph
// get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index(); const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);


// Return the parent of this graph. // Return the parent of this graph.
FuncGraphPtr parent(); FuncGraphPtr parent();
@@ -256,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
// parameter default value // parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_; std::map<std::string, AnfNodePtr> parameter_default_value_;
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_; std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
size_t seen_;


std::list<CNodePtr> GetOrderedCnodes(); std::list<CNodePtr> GetOrderedCnodes();
void EraseUnusedNodeInOrder(const AnfNodePtr &n); void EraseUnusedNodeInOrder(const AnfNodePtr &n);
@@ -270,6 +304,24 @@ class FuncGraph : public FuncGraphBase {
// graph is manipulated by manager and others // graph is manipulated by manager and others
friend FuncGraphManager; friend FuncGraphManager;


// all nodes of the function
AnfNodeSet nodes_;

// all value nodes of the function
AnfNodeCounterMap value_nodes_;

// all func graph value nodes of the function
FuncGraphCounterMap func_graphs_used_;

// all free variables of the function
AnfNodeCounterMap free_variables_;

// all value nodes calling J in the function
FuncGraphCounterMap j_func_graphs_;

// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;

// parameters of this function // parameters of this function
std::vector<AnfNodePtr> parameters_; std::vector<AnfNodePtr> parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_; std::vector<AnfNodePtr> paramter_obj_nodes_;
@@ -313,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
return fg->NewCNode(inputs); return fg->NewCNode(inputs);
} }


size_t NewFgSeenGeneration();

// Find the root cnodes of a segment of cnodes. // Find the root cnodes of a segment of cnodes.
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment); std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
// Find the leaf cnodes of a segment of cnodes. // Find the leaf cnodes of a segment of cnodes.


+ 6
- 6
mindspore/ccsrc/ir/func_graph_cloner.cc View File

@@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
if (!clone_all_valuenodes_) { if (!clone_all_valuenodes_) {
return; return;
} }
auto &value_nodes = manager_->valuenodes()[func_graph];
auto &value_nodes = func_graph->value_nodes();
for (auto &value_node : value_nodes) { for (auto &value_node : value_nodes) {
auto old_node = value_node.first; auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) { if (!clone_all_used_graphs_) {
return; return;
} }
auto &used_graphs = manager_->func_graphs_used()[func_graph];
for (auto &used_graph : used_graphs) {
todo_.push_back({used_graph.first, nullptr, {}});
auto &used = func_graph->func_graphs_used();
for (auto &fg : used) {
todo_.push_back({fg.first, nullptr, {}});
} }
} }


@@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
} }
target_func_graph->set_return(return_node); target_func_graph->set_return(return_node);


auto &cnodes = manager_->func_graph_cnodes_index()[func_graph];
auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>(); auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second); auto valuenode = parent->input(cnode.first->second);
@@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet &nodes = manager_->nodes()[func_graph];
const AnfNodeSet &nodes = func_graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
CloneNode(node, target_func_graph); CloneNode(node, target_func_graph);
} }


+ 115
- 275
mindspore/ccsrc/ir/manager.cc View File

@@ -78,19 +78,6 @@ void FuncGraphManager::Reset() {
node_users_ = NodeUsersMap(); node_users_ = NodeUsersMap();


signals_ = std::make_shared<Signals>(); signals_ = std::make_shared<Signals>();
// FuncGraph --> AnfNode
nodes_ = std::make_shared<NodesCollector>(this);

// FuncGraph --> {AnfNode, Count}
valuenodes_ = std::make_shared<ValueNodesCollector>(this);
free_variables_direct_ = std::make_shared<FVDirectCollector>(this);
func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this);

// FuncGraph --> {FuncGraph, Count}
func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this);
func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this);
func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this);
func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this);


func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this); func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
func_graph_parent_ = std::make_shared<ParentComputer>(this); func_graph_parent_ = std::make_shared<ParentComputer>(this);
@@ -209,8 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
return; return;
} }
AddIntoManaged(func_graph); AddIntoManaged(func_graph);
MS_EXCEPTION_IF_NULL(signals_);
signals_->AddFuncGraph(func_graph);
std::vector<AnfNodePtr> para = func_graph->parameters(); std::vector<AnfNodePtr> para = func_graph->parameters();
AcquireNodes(para); AcquireNodes(para);
std::vector<AnfNodePtr> return_vec({func_graph->get_return()}); std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
@@ -224,7 +209,6 @@ void FuncGraphManager::Clear() {
node_users_.clear(); node_users_.clear();
roots_.clear(); roots_.clear();


signals_->InvalidateCollector();
signals_->InvalidateComputer(); signals_->InvalidateComputer();
} }


@@ -303,8 +287,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
continue; continue;
} }
MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_);
auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph];
auto &users_cnode_index = func_graph->func_graph_cnodes_index();
if (!users_cnode_index.empty() && !ignore_users) { if (!users_cnode_index.empty() && !ignore_users) {
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
continue; continue;
@@ -317,10 +300,8 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
std::vector<AnfNodePtr> return_vec = {func_graph->get_return()}; std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
todo.update(MaybeDropNodes(return_vec)); todo.update(MaybeDropNodes(return_vec));
} }
MS_EXCEPTION_IF_NULL(signals_);
for (auto &fg : dropped) { for (auto &fg : dropped) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
signals_->DropFuncGraph(fg);
all_nodes_.difference_update(fg->parameters()); all_nodes_.difference_update(fg->parameters());
(void)func_graphs_.erase(fg); (void)func_graphs_.erase(fg);
if (fg->manager().get() == this) { if (fg->manager().get() == this) {
@@ -339,7 +320,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
return; return;
} }
(void)users_node.erase(make_pair(node, index)); (void)users_node.erase(make_pair(node, index));
signals_->DropEdge(node, index, inp);
DropEdge(node, index, inp);
} else { } else {
MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
if (inp->func_graph() != nullptr) { if (inp->func_graph() != nullptr) {
@@ -351,8 +332,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
} }
auto &users_node = node_users_[inp]; auto &users_node = node_users_[inp];
users_node.add(make_pair(node, index)); users_node.add(make_pair(node, index));
MS_EXCEPTION_IF_NULL(signals_);
signals_->AddEdge(node, index, inp);
AddEdge(node, index, inp);
} }
} }


@@ -392,8 +372,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
FuncGraphPtr fg = node->func_graph(); FuncGraphPtr fg = node->func_graph();
if (fg != nullptr) { if (fg != nullptr) {
AddFuncGraph(fg); AddFuncGraph(fg);
fg->AddNode(node);
} }
signals_->AddNode(node);
ProcessInputs(node, kIncEdge); ProcessInputs(node, kIncEdge);
} }
} }
@@ -401,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) { FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet nodes_ordered(nodes); AnfNodeSet nodes_ordered(nodes);
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
MS_EXCEPTION_IF_NULL(signals_);

while (!nodes_ordered.empty()) { while (!nodes_ordered.empty()) {
AnfNodePtr node = nodes_ordered.pop(); AnfNodePtr node = nodes_ordered.pop();
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@@ -424,7 +402,10 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
} }
ProcessInputs(node, kDecEdge); ProcessInputs(node, kDecEdge);
(void)all_nodes_.erase(node); (void)all_nodes_.erase(node);
signals_->DropNode(node);
if (node->func_graph() != nullptr) {
node->func_graph()->DropNode(node);
}

if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
nodes_ordered.update(cnode->inputs()); nodes_ordered.update(cnode->inputs());
@@ -462,35 +443,21 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t


int index = 0; int index = 0;
(void)node_users_[source_prim].erase(make_pair(source_return, index)); (void)node_users_[source_prim].erase(make_pair(source_return, index));
signals_->DropEdge(source_return, index, source_prim);
DropEdge(source_return, index, source_prim);
index = 1; index = 1;
(void)node_users_[source_output].erase(make_pair(source_return, index)); (void)node_users_[source_output].erase(make_pair(source_return, index));
signals_->DropEdge(source_return, index, source_output);
DropEdge(source_return, index, source_output);
(void)all_nodes_.erase(source_return); (void)all_nodes_.erase(source_return);
(void)node_users_.erase(source_return); (void)node_users_.erase(source_return);
signals_->DropNode(source_return);
source->DropNode(source_return);
for (auto &node : source->nodes()) { for (auto &node : source->nodes()) {
node->set_func_graph(target); node->set_func_graph(target);
if (node->scope() == kDefaultScope) { if (node->scope() == kDefaultScope) {
node->set_scope(scope); node->set_scope(scope);
} }
} }
for (auto &child : this->func_graph_child_direct()[source]) {
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
(void)this->func_graph_parents_direct()[child.first].erase(source);
}
for (auto &fv_count : this->free_variables_direct()[source]) {
auto fv_g = fv_count.first->func_graph();
auto &count_on_g = this->func_graph_child_direct()[fv_g];
auto pair = count_on_g.find(source);
if (fv_g != target && pair != count_on_g.end()) {
(void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
}
(void)count_on_g.erase(source);
}
signals_->MoveAllCNode(source, target);
signals_->InvalidateComputer();
signals_->DropFuncGraph(source);

MoveAllNodes(source, target);
all_nodes_.difference_update(source->parameters()); all_nodes_.difference_update(source->parameters());
(void)func_graphs_.erase(source); (void)func_graphs_.erase(source);
if (source->manager().get() == this) { if (source->manager().get() == this) {
@@ -498,6 +465,64 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
} }
} }


inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) {
auto fg = node->func_graph();
if (input->isa<ValueNode>()) {
fg->AddValueNode(input);
if (IsValueNode<FuncGraph>(input)) {
auto used = GetValueNode<FuncGraphPtr>(input);
used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (fg->AddFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->AddJFuncGraph(used);
}
}
} else if (fg != nullptr && fg != input->func_graph()) {
if (fg->AddFreeVariable(input)) {
signals_->InvalidateComputer();
}
}
}

inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) {
auto fg = node->func_graph();
if (input->isa<ValueNode>()) {
fg->DropValueNode(input);
if (IsValueNode<FuncGraph>(input)) {
auto used = GetValueNode<FuncGraphPtr>(input);
used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (fg->DropFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->DropJFuncGraph(used);
}
}
} else if (fg != nullptr && fg != input->func_graph()) {
if (fg->DropFreeVariable(input)) {
signals_->InvalidateComputer();
}
}
}

inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
target->CopyNodes(source);
target->CopyValueNodes(source);
target->CopyFuncGraphCNodesIndex(source);
target->CopyFreeVariables(source);
target->CopyFuncGraphsUsed(source);
target->CopyJFuncGraphs(source);
signals_->InvalidateComputer();
source->ClearNodes();
source->ClearValueNodes();
source->ClearFuncGraphCNodesIndex();
source->ClearFreeVariables();
source->ClearFuncGraphsUsed();
source->ClearJFuncGraphs();
}

FuncGraphTransaction FuncGraphManager::Transact() { FuncGraphTransaction FuncGraphManager::Transact() {
auto tr = FuncGraphTransaction(this); auto tr = FuncGraphTransaction(this);
return tr; return tr;
@@ -610,54 +635,14 @@ void FuncGraphTransaction::Commit() {
} }


FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
: manager_(manager), include_func_graph_none_(false) {
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge);
manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge);
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
}

NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
include_func_graph_none_ = true;
nodes_analysis_[nullptr] = AnfNodeSet();

manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode);
manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode);
}

void NodesCollector::OnAddNode(AnfNodePtr n) {
if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) {
nodes_analysis_[n->func_graph()] = AnfNodeSet();
}

nodes_analysis_[n->func_graph()].add(n);
}

void NodesCollector::OnDropNode(AnfNodePtr n) {
(void)nodes_analysis_[n->func_graph()].erase(n);
auto graph = n->func_graph();
// Remove the node from order list.
if (graph) {
graph->EraseUnusedNodeInOrder(n);
}
}

void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// change the owner of node except for the src's return node
for (auto &it : nodes_analysis_[src]) {
nodes_analysis_[dst].add(it);
}
(void)nodes_analysis_.erase(src);
}

void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
: manager_(manager), include_func_graph_none_(false) {}


DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
} }


void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }

void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }


template <typename ValueT, class CollectorHash, class CollectorEqual> template <typename ValueT, class CollectorHash, class CollectorEqual>
@@ -706,65 +691,6 @@ bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const F
} }
} }


void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (inp->isa<ValueNode>()) {
(void)Mod(node->func_graph(), inp, direction);
}
}

void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
}

void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp,
EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)),
direction);
}
}

void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
// Ignore the user graph who may own itself.
if (dst != it.first->first->func_graph()) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_nodes_map_.erase(src);
}

void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inp);
FuncGraphPtr fg1 = node->func_graph();
FuncGraphPtr fg2 = inp->func_graph();
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
(void)Mod(fg1, inp, direction);
}
}

void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
FuncGraphPtr fg2 = it.first->func_graph();
if (fg2 != dst) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_nodes_map_.erase(src);
}

static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
FuncGraphPtr gn = std::make_shared<FuncGraph>();
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
return gn;
}

bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph]; auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) { if (d.count(key) == 0) {
@@ -804,87 +730,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr
} }
} }


void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inp);
FuncGraphPtr fg1 = node->func_graph();
FuncGraphPtr fg2 = inp->func_graph();
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
(void)Mod(fg2, fg1, direction);
}
}

void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_func_graphs_map_[src]) {
FuncGraphPtr fg = it.first;
if (fg != dst) {
(void)Inc(dst, fg, it.second);
}
}
(void)count_func_graphs_map_.erase(src);
}

void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg1 = node->func_graph();
// possible child parent
if (IsValueNode<FuncGraph>(inp)) {
FuncGraphPtr fg2 = GetValueNode<FuncGraphPtr>(inp);
if (Mod(fg1, ParentProxy(fg2), direction)) {
manager_->signals()->InvalidateComputer();
}
}
// from fv
FuncGraphPtr fg2 = inp->func_graph();
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
// node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent
if (Mod(fg1, fg2, direction)) {
manager_->signals()->InvalidateComputer();
}
}
}

void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_func_graphs_map_[src]) {
if (it.first != dst) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_func_graphs_map_.erase(src);
}

void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
}
}

void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use
for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_func_graphs_map_[dst].erase(src);
(void)count_func_graphs_map_.erase(src);
}

void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph "
<< GetValueNode<FuncGraphPtr>(inp)->ToString() << " which contains J(func_graph), dir: " << direction;
}
}

void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use
for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_func_graphs_map_.erase(src);
}

DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
@@ -905,22 +750,24 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) {
} }
} }


FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
if (path == nullptr || path->contains(fg)) {
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) {
if (fg->seen_ == seen_num) {
return std::make_shared<FuncGraphSet>(); return std::make_shared<FuncGraphSet>();
} }
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_;
for (auto &dep : deps[fg]) {
MS_EXCEPTION_IF_NULL(dep.first);
auto proxy = dep.first->transforms().find("proxy");
if (proxy != dep.first->transforms().end()) {
path->add(fg);
auto gt = proxy->second.func_graph();
parents->update(SeekParents(gt, path));
} else {
parents->add(dep.first);
}

// Append all the fvs in fg.
auto &fvs = fg->free_variables();
for (auto fv : fvs) {
parents->add(fv.first->func_graph());
}

// Search the fv in fg's child func graph.
auto &fgs = fg->func_graphs_used();
for (auto &item : fgs) {
fg->seen_ = seen_num;
auto gt = item.first;
parents->update(SeekParents(gt, seen_num));
} }
(void)parents->erase(fg); (void)parents->erase(fg);
return parents; return parents;
@@ -928,10 +775,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f


void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
all_parents_direct_ = &(manager_->func_graph_parents_direct());
MS_LOG(DEBUG) << fg->ToString() << " total func graph dep size:" << (*all_parents_direct_)[fg].size();
func_graph_parents_total_analysis_[fg].update(SeekParents(fg));
MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration()));
} }


bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
@@ -1001,21 +845,23 @@ void FVTotalComputer::RealRecompute() {
} }


for (auto &fg : manager->func_graphs()) { for (auto &fg : manager->func_graphs()) {
AnfNodeCounterMap items = manager->free_variables_direct()[fg];
AnfNodeCounterMap items = fg->free_variables();
for (auto &iter : items) { for (auto &iter : items) {
auto curr = fg; auto curr = fg;
while (curr) {
while (curr != nullptr) {
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
curr = manager->parent(curr); curr = manager->parent(curr);
const AnfNodeSet &nodes = manager->nodes()[curr];
if (nodes.contains(iter.first)) {
break;
if (curr != nullptr) {
const AnfNodeSet &all_nodes = curr->nodes();
if (all_nodes.contains(iter.first)) {
break;
}
} }
} }
} }


auto items_fg = manager->func_graphs_used()[fg];
for (auto &iter : items_fg) {
auto &used = fg->func_graphs_used();
for (auto &iter : used) {
auto p = manager->parent(iter.first); auto p = manager->parent(iter.first);
if (p == nullptr) { if (p == nullptr) {
continue; continue;
@@ -1041,7 +887,6 @@ void FVTotalComputer::RealRecompute() {


void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
auto &used = this->manager_->func_graphs_used();
std::vector<FuncGraphPtr> todo; std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new; std::vector<FuncGraphPtr> todo_new;


@@ -1049,7 +894,7 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
while (!todo.empty()) { while (!todo.empty()) {
todo_new.clear(); todo_new.clear();
for (auto &gt : todo) { for (auto &gt : todo) {
for (auto &item : used[gt]) {
for (auto &item : gt->func_graphs_used()) {
auto used_fg = item.first; auto used_fg = item.first;
if (used_fg == fg) { if (used_fg == fg) {
func_graph_used_total_analysis_[fg].add(used_fg); func_graph_used_total_analysis_[fg].add(used_fg);
@@ -1068,7 +913,6 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {


bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
auto &used = manager->func_graphs_used();
std::vector<FuncGraphPtr> todo; std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new; std::vector<FuncGraphPtr> todo_new;
todo.push_back(fg); todo.push_back(fg);
@@ -1076,7 +920,7 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f
while (!todo.empty()) { while (!todo.empty()) {
todo_new.clear(); todo_new.clear();
for (auto &gt : todo) { for (auto &gt : todo) {
for (auto &item : used[gt]) {
for (auto &item : gt->func_graphs_used()) {
auto used_g = item.first; auto used_g = item.first;
if (used_g == fg) { if (used_g == fg) {
return true; return true;
@@ -1108,8 +952,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
} }
} else { } else {
trace->push_back(fg); trace->push_back(fg);
auto &used_fgs = manager_->func_graphs_used()[fg];
for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
auto &items = fg->func_graphs_used();
for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
CheckRecursiveGraphs(iter->first, trace); CheckRecursiveGraphs(iter->first, trace);
} }
trace->pop_back(); trace->pop_back();
@@ -1119,31 +963,28 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
} }
} }


bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
MS_EXCEPTION_IF_NULL(path);
if (path->contains(fg)) {
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
if (fg->seen_ == seen_num) {
MS_LOG(DEBUG) << fg->ToString() << " had been checked"; MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false; return false;
} }
MS_EXCEPTION_IF_NULL(manager_);
auto &func_graph_counter_map = manager_->func_graph_j_direct();
if (!func_graph_counter_map[fg].empty()) {
auto &j_fgs = fg->j_func_graphs();
if (!j_fgs.empty()) {
// check g1->J(fg)->g2->g cycle; // check g1->J(fg)->g2->g cycle;
auto contains_j =
std::find_if(func_graph_counter_map[fg].begin(), func_graph_counter_map[fg].end(),
[path](const std::pair<FuncGraphPtr, int> iter) { return !path->contains(iter.first); });
if (contains_j != func_graph_counter_map[fg].end()) {
auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
return iter.first->seen_ != seen_num;
});
if (contains_j != j_fgs.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
return true; return true;
} }
} }
path->add(fg);
fg->seen_ = seen_num;


// check if func graphs used contains J(func_graph); // check if func graphs used contains J(func_graph);
auto &used = this->manager_->func_graphs_used();
for (auto &item : used[fg]) {
for (auto &item : fg->func_graphs_used()) {
auto used_g = item.first; auto used_g = item.first;
if (SeekJ(used_g, path)) {
if (SeekJ(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
return true; return true;
} }
@@ -1153,7 +994,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
} }


void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
std::shared_ptr<FuncGraphSet> path = std::make_shared<FuncGraphSet>();
this->j_total_analysis_[fg] = SeekJ(fg, path);
this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
} }
} // namespace mindspore } // namespace mindspore

+ 33
- 164
mindspore/ccsrc/ir/manager.h View File

@@ -140,44 +140,6 @@ class FuncGraphAnalysis {


using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;


// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public:
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;

void Reset() { ExtraReset(); }
void OnInvalidateCollector() { Reset(); }

protected:
// inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
};

class NodesCollector final : public DepCollector {
public:
explicit NodesCollector(const FuncGraphManager *m);
~NodesCollector() override = default;

const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
size_t size() const override { return nodes_analysis_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }

void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); }

void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;

FuncGraphToAnfNodeMap nodes_analysis_;

protected:
void ExtraReset() override { nodes_analysis_.clear(); }
void OnAddNode(AnfNodePtr n) override;
void OnDropNode(AnfNodePtr n) override;
};

struct CNodeIndexHasher { struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const { std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair); MS_EXCEPTION_IF_NULL(pair);
@@ -204,59 +166,21 @@ struct CNodeIndexEqual {
} }
}; };


template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector {
public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }

size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }

bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);

FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;

protected:
void ExtraReset() override { count_nodes_map_.clear(); }
};

class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public:
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;

protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};

// Record the CNode and its input index, who points to the function graph.
class FuncGraphUsersCNodeIndexCollector final
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public: public:
explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphUsersCNodeIndexCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;

protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;


class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public:
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
void Reset() { ExtraReset(); }
void OnInvalidateCollector() { Reset(); }


protected: protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
// inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
}; };


class CounterFuncGraphCollector : public DepCollector { class CounterFuncGraphCollector : public DepCollector {
@@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector {
void ExtraReset() override { count_func_graphs_map_.clear(); } void ExtraReset() override { count_func_graphs_map_.clear(); }
}; };


class FuncGraphChildDirect final : public CounterFuncGraphCollector {
public:
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;

~FuncGraphChildDirect() override = default;

protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};

// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy
// parents:
// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector {
public: public:
explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
~FuncGraphParentsDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;

protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }


// graph's all used graphs: key is g, value is g used graph
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphsUsedCollector() override = default;
size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }


protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);


class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
~FuncGraphJDirectCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;


protected: protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
void ExtraReset() override { count_nodes_map_.clear(); }
}; };


using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
@@ -367,8 +268,8 @@ class DepComputer : public FuncGraphAnalysis {
// graph g's all direct or proxy parents // graph g's all direct or proxy parents
class FuncGraphParentsTotalComputer final : public DepComputer { class FuncGraphParentsTotalComputer final : public DepComputer {
public: public:
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {}
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; }
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphParentsTotalComputer() override = default;


FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }


@@ -382,10 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;


private: private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
// when SeekParents calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
FuncGraphToFuncGraphCounterMap *all_parents_direct_;
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num);
}; };


using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>; using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
@@ -525,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
void ExtraReset() override { j_total_analysis_.clear(); } void ExtraReset() override { j_total_analysis_.clear(); }


void RealRecompute(FuncGraphPtr fg) override; void RealRecompute(FuncGraphPtr fg) override;
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path);
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
}; };


class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
@@ -562,30 +460,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {


NodeUsersMap &node_users() { return node_users_; } NodeUsersMap &node_users() { return node_users_; }


FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }

FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; }

FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const {
return free_variables_direct_->count_nodes_map_;
}

FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const {
return func_graph_cnodes_index_->count_nodes_map_;
}

FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }

FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_;
}

FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const {
return func_graph_parents_direct_->count_func_graphs_map_;
}

FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }

FVTotalMap &free_variables_total() const; FVTotalMap &free_variables_total() const;


FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
@@ -610,14 +484,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
// Static Analysis // Static Analysis
NodeUsersMap node_users_; NodeUsersMap node_users_;
AnfNodeSet all_nodes_; // managed nodes AnfNodeSet all_nodes_; // managed nodes
std::shared_ptr<NodesCollector> nodes_;
std::shared_ptr<ValueNodesCollector> valuenodes_;
std::shared_ptr<FVDirectCollector> free_variables_direct_;
std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_;
std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_;
std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_;
std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_;
std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_;


// Dynamic Analysis // Dynamic Analysis
std::shared_ptr<ParentComputer> func_graph_parent_; std::shared_ptr<ParentComputer> func_graph_parent_;
@@ -630,6 +496,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes); FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes);
void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges,
Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms); Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms);
void AddEdge(AnfNodePtr node, int index, AnfNodePtr input);
void DropEdge(AnfNodePtr node, int index, AnfNodePtr input);
void MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target);


FuncGraphSet roots_; // managed roots FuncGraphSet roots_; // managed roots
FuncGraphSet func_graphs_; // managed func graphs FuncGraphSet func_graphs_; // managed func graphs


+ 1
- 1
mindspore/ccsrc/optimizer/ad/dfunctor.cc View File

@@ -492,7 +492,7 @@ void DFunctor::MapParamObject() {
void DFunctor::MapValueObject() { void DFunctor::MapValueObject() {
// Map ValueNode. // Map ValueNode.
auto manager = resources_->manager(); auto manager = resources_->manager();
auto &value_nodes = manager->valuenodes()[primal_graph_];
auto &value_nodes = primal_graph_->value_nodes();
for (const auto &value_pair : value_nodes) { for (const auto &value_pair : value_nodes) {
auto node = value_pair.first; auto node = value_pair.first;
auto parent_adjoint = FindAdjoint(node); auto parent_adjoint = FindAdjoint(node);


+ 2
- 2
mindspore/ccsrc/optimizer/irpass/branch_culling.cc View File

@@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node; std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node;
// record the node input to be replaced // record the node input to be replaced
NodeInputReplMap repl_node_inputs; NodeInputReplMap repl_node_inputs;
const AnfNodeSet &nodes = manager->nodes()[graph];
const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
@@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
ResetSharedOp(); ResetSharedOp();
std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node = std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node =
std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced
const AnfNodeSet &nodes = manager->nodes()[graph];
const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {


+ 1
- 1
mindspore/ccsrc/pipeline/action.cc View File

@@ -391,7 +391,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
auto manager = res->manager(); auto manager = res->manager();
// Remove duplicated value nodes, due to replace operation, can't use reference. // Remove duplicated value nodes, due to replace operation, can't use reference.
auto value_nodes = manager->valuenodes()[func_graph];
auto value_nodes = func_graph->value_nodes();
HashCache hash_cache; HashCache hash_cache;
HashValue hashes; HashValue hashes;
for (const auto &value_pair : value_nodes) { for (const auto &value_pair : value_nodes) {


+ 5
- 5
mindspore/ccsrc/vm/transform.cc View File

@@ -488,12 +488,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {


void TraverseGraphMap( void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts,
const FuncGraphSet &fgs,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr); MS_EXCEPTION_IF_NULL(tr);
for (const auto &ct_graphs : cts) {
for (const auto &ct_any : ct_graphs.second) {
for (const auto &fg : fgs) {
for (const auto &ct_any : fg->value_nodes()) {
AnfNodePtr const_primitive_node = ct_any.first; AnfNodePtr const_primitive_node = ct_any.first;
if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) { if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
auto users = manager_ptr->node_users()[const_primitive_node]; auto users = manager_ptr->node_users()[const_primitive_node];
@@ -553,8 +553,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
}; };


FuncGraphTransaction tr = manager_ptr->Transact(); FuncGraphTransaction tr = manager_ptr->Transact();
auto &cts = manager_ptr->valuenodes();
TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph);
auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);


return graph; return graph;
} }


+ 12
- 89
tests/ut/cpp/ir/manager_test.cc View File

@@ -132,18 +132,6 @@ class NestingSpecs {
CheckAnfNodeCounter(counter_p); CheckAnfNodeCounter(counter_p);
return; return;
} }

auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
if (counter_pair != nullptr) {
CheckCNodeIndexPairCounter(counter_pair);
return;
}

auto nodes = dynamic_pointer_cast<NodesCollector>(results);
if (nodes != nullptr) {
CheckNodes(nodes);
return;
}
} }


private: private:
@@ -205,33 +193,7 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }


void CheckNodes(std::shared_ptr<NodesCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->nodes_analysis()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);

std::set<std::string> v;
for (auto& node : value) {
if (!node->isa<CNode>() && !Name(node).empty()) {
v.insert(Name(node));
}
}

if (!v.empty()) {
clean_results[k] = v;
}
}

ASSERT_EQ(clean_results, expected_);
}

// Add CheckNesting function // Add CheckNesting function

void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) { for (auto& iter : results->count_nodes_map()) {
@@ -258,32 +220,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }


void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);

std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first->first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}

if (!v.empty()) {
clean_results[k] = v;
}
}

ASSERT_EQ(clean_results, expected_);
}

void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) { for (auto& iter : results->count_func_graphs_map()) {
@@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
} }


// Add TestManager::CheckManager function to checkout the result // Add TestManager::CheckManager function to checkout the result

void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
auto size = mng->func_graphs().size(); auto size = mng->func_graphs().size();


ASSERT_EQ(size + 1, mng->nodes().size());
ASSERT_EQ(size, mng->free_variables_total().size()); ASSERT_EQ(size, mng->free_variables_total().size());
ASSERT_EQ(size, mng->valuenodes().size());
ASSERT_EQ(size, mng->free_variables_direct().size());
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
ASSERT_EQ(size, mng->func_graphs_used().size());
} }


TEST_F(TestManager, test_scalar_add_manual) { TEST_F(TestManager, test_scalar_add_manual) {
@@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);


auto nodes = mng->nodes();
ASSERT_EQ(3, nodes[nullptr].size());
ASSERT_EQ(2, nodes[f].size());
ASSERT_EQ(1, nodes[g].size());
ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(1, g->nodes().size());


auto users = mng->node_users(); auto users = mng->node_users();
for (auto& iter : users) { for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size()); ASSERT_EQ(1, iter.second.size());
} }


auto graphs_used = mng->func_graphs_used();
ASSERT_EQ(1, graphs_used[f].size());
ASSERT_EQ(0, graphs_used[g].size());
ASSERT_EQ(1, f->func_graphs_used().size());
ASSERT_EQ(0, g->func_graphs_used().size());


auto fv_direct = mng->free_variables_direct();
ASSERT_EQ(0, fv_direct[f].size());
ASSERT_EQ(1, fv_direct[g].size());
ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(1, g->free_variables().size());


auto fv_total = mng->free_variables_total(); auto fv_total = mng->free_variables_total();
ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(0, fv_total[f].size());
ASSERT_EQ(1, fv_total[g].size()); ASSERT_EQ(1, fv_total[g].size());


auto cnodes = mng->func_graph_cnodes_index();
ASSERT_EQ(0, cnodes[f].size());
ASSERT_EQ(1, cnodes[g].size());
ASSERT_EQ(0, f->func_graph_cnodes_index().size());
ASSERT_EQ(1, g->func_graph_cnodes_index().size());
} }


TEST_F(TestManager, test_deep_nested2_manual) { TEST_F(TestManager, test_deep_nested2_manual) {
@@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {


ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(4, gfn->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
ASSERT_EQ(25, mng->node_users().size()); ASSERT_EQ(25, mng->node_users().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
@@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {


ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
} }
@@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
FuncGraphPtr fg = getPyFun("ir_get_fn"); FuncGraphPtr fg = getPyFun("ir_get_fn");


auto mng = Manage(fg); auto mng = Manage(fg);
const FuncGraphToAnfNodeMap& nodes = mng->nodes();
ASSERT_TRUE(nodes.find(fg) != nodes.end());
const auto &fgs = mng->func_graphs();
ASSERT_TRUE(fgs.contains(fg));
FuncGraphSet s; FuncGraphSet s;
s.add(fg); s.add(fg);
mng->MaybeDropFuncGraphs(s); mng->MaybeDropFuncGraphs(s);
ASSERT_TRUE(nodes.find(fg) != nodes.end());
ASSERT_TRUE(fgs.contains(fg));
} }


TEST_F(TestManager, test_keep_roots) { TEST_F(TestManager, test_keep_roots) {


+ 3
- 4
tests/ut/cpp/optimizer/cconv_test.cc View File

@@ -26,15 +26,14 @@
namespace mindspore { namespace mindspore {
void CheckNoFreeVariables(FuncGraphPtr root) { void CheckNoFreeVariables(FuncGraphPtr root) {
auto mng = Manage(root); auto mng = Manage(root);
for (auto &iter : mng->nodes()) {
auto g = iter.first;
auto nodes = iter.second;
for (auto &iter : mng->func_graphs()) {
auto g = iter;
if (g == nullptr) { if (g == nullptr) {
continue; continue;
} }

ASSERT_TRUE(g->parent() == nullptr); ASSERT_TRUE(g->parent() == nullptr);


auto nodes = g->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
ASSERT_EQ(node->func_graph(), g); ASSERT_EQ(node->func_graph(), g);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();


Loading…
Cancel
Save