Browse Source

Refactoring FuncGraphManager module: Move all info. of nodes and edges from FuncGraphManager into FuncGraph.

tags/v0.3.0-alpha
Zhang Qinghua 5 years ago
parent
commit
3ae925115f
11 changed files with 364 additions and 177 deletions
  1. +1
    -0
      mindspore/ccsrc/ir/base.h
  2. +174
    -33
      mindspore/ccsrc/ir/func_graph.cc
  3. +59
    -8
      mindspore/ccsrc/ir/func_graph.h
  4. +6
    -6
      mindspore/ccsrc/ir/func_graph_cloner.cc
  5. +109
    -83
      mindspore/ccsrc/ir/manager.cc
  6. +5
    -37
      mindspore/ccsrc/ir/manager.h
  7. +1
    -1
      mindspore/ccsrc/optimizer/ad/dfunctor.cc
  8. +2
    -2
      mindspore/ccsrc/optimizer/irpass/branch_culling.cc
  9. +1
    -1
      mindspore/ccsrc/parallel/step_parallel.cc
  10. +1
    -1
      mindspore/ccsrc/pipeline/action.cc
  11. +5
    -5
      mindspore/ccsrc/vm/transform.cc

+ 1
- 0
mindspore/ccsrc/ir/base.h View File

@@ -29,6 +29,7 @@
#include "utils/visible.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"

namespace mindspore {
template <typename T>


+ 174
- 33
mindspore/ccsrc/ir/func_graph.cc View File

@@ -195,25 +195,88 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
return this->debug_info_;
}

const AnfNodeSet &FuncGraph::nodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &nodes = mng->nodes();
return nodes[shared_from_base<FuncGraph>()];
const AnfNodeSet &FuncGraph::nodes() { return nodes_; }

void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; }

void FuncGraph::ClearNodes() { nodes_.clear(); }

void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); }

void FuncGraph::DropNode(AnfNodePtr node) {
nodes_.erase(node);
auto graph = node->func_graph();
// Remove the node from order list.
if (graph) {
graph->EraseUnusedNodeInOrder(node);
}
}

const AnfNodeCounterMap &FuncGraph::value_nodes() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &cts = mng->valuenodes();
return cts[shared_from_base<FuncGraph>()];
const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }

void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; }

void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }

void FuncGraph::AddValueNode(AnfNodePtr node, int count) {
if (value_nodes_.count(node) == 0) {
value_nodes_[node] = count;
} else {
value_nodes_[node] += count;
}
}

const AnfNodeCounterMap &FuncGraph::free_variables_direct() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &fv_direct = mng->free_variables_direct();
return fv_direct[shared_from_base<FuncGraph>()];
void FuncGraph::DropValueNode(AnfNodePtr node) {
if (value_nodes_.count(node) != 0) {
if (value_nodes_[node] == 1) {
(void)value_nodes_.erase(node);
} else {
value_nodes_[node]--;
if (value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
}

const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }

void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) {
auto it = others.begin();
for (; it != others.end(); it++) {
if (it->first->func_graph().get() != this) {
(void)AddFreeVariable(it->first, it->second);
}
}
}

void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }

bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) {
if (free_variables_.count(node) == 0) {
free_variables_[node] = count;
return true;
} else {
free_variables_[node] += count;
return false;
}
}

bool FuncGraph::DropFreeVariable(AnfNodePtr node) {
if (free_variables_.count(node) != 0) {
if (free_variables_[node] == 1) {
(void)free_variables_.erase(node);
return true;
} else {
free_variables_[node]--;
if (free_variables_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of free variable '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
return false;
}

const BaseRefCounterMap &FuncGraph::free_variables_total() {
@@ -249,11 +312,36 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs;
}

const FuncGraphCounterMap &FuncGraph::func_graphs_used() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &used = mng->func_graphs_used();
return used[shared_from_base<FuncGraph>()];
const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; }

void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; }

void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); }

bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) {
if (func_graph_value_nodes_.count(node) == 0) {
func_graph_value_nodes_[node] = count;
return true;
} else {
func_graph_value_nodes_[node] += count;
return false;
}
}

bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) {
if (func_graph_value_nodes_.count(node) != 0) {
if (func_graph_value_nodes_[node] == 1) {
(void)func_graph_value_nodes_.erase(node);
return true;
} else {
func_graph_value_nodes_[node]--;
if (func_graph_value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
return false;
}

const FuncGraphSet &FuncGraph::func_graphs_used_total() {
@@ -263,15 +351,68 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
return used;
}

const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() {
auto mng = manager_.lock();
if (mng == nullptr) {
MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
<< " NodeInfo: " << trace::GetDebugInfo(debug_info());
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }

void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) {
auto it = others.begin();
for (; it != others.end(); it++) {
// Ignore the user graph who may own itself.
if (it->first->first->func_graph().get() != this) {
AddFuncGraphCNodeIndex(it->first, it->second);
}
}
}

void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }

void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) {
if (func_graph_cnodes_index_.count(pair) == 0) {
func_graph_cnodes_index_[pair] = count;
} else {
func_graph_cnodes_index_[pair] += count;
}
}

void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
if (func_graph_cnodes_index_.count(pair) != 0) {
if (func_graph_cnodes_index_[pair] == 1) {
(void)func_graph_cnodes_index_.erase(pair);
} else {
func_graph_cnodes_index_[pair]--;
if (func_graph_cnodes_index_[pair] < 0) {
MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
}

const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; }

void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; }

void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); }

void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) {
if (j_func_graph_value_nodes_.count(node) == 0) {
j_func_graph_value_nodes_[node] = count;
} else {
j_func_graph_value_nodes_[node] += count;
}
}

void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) {
if (j_func_graph_value_nodes_.count(node) != 0) {
if (j_func_graph_value_nodes_[node] == 1) {
(void)j_func_graph_value_nodes_.erase(node);
} else {
j_func_graph_value_nodes_[node]--;
if (j_func_graph_value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
}
MS_EXCEPTION_IF_NULL(mng);
auto &cnode = mng->func_graph_cnodes_index();
return cnode[shared_from_base<FuncGraph>()];
}

FuncGraphPtr FuncGraph::parent() {
@@ -662,10 +803,10 @@ void FuncGraph::EraseUnusedNodeInOrder() {
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
auto mng = manager_.lock();
if (mng) {
auto nodes = mng->nodes()[shared_from_base<FuncGraph>()];
auto &all_nodes = nodes();
// Erase unused cnode.
for (auto it = order_.begin(); it != order_.end();) {
if (nodes.count(*it)) {
if (all_nodes.count(*it)) {
(void)it++;
} else {
MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order.";
@@ -702,11 +843,11 @@ void FuncGraph::CheckOrder() {
}
auto mng = manager_.lock();
if (mng != nullptr) {
const auto &nodes = mng->nodes()[shared_from_base<FuncGraph>()];
if (nodes.size() != (order_.size() + parameters_.size())) {
const auto &all_nodes = nodes();
if (all_nodes.size() != (order_.size() + parameters_.size())) {
DumpCNodeList();
MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size "
<< nodes.size() - parameters_.size() << ".";
<< all_nodes.size() - parameters_.size() << ".";
}
}
MS_LOG(DEBUG) << "Check order okay.";


+ 59
- 8
mindspore/ccsrc/ir/func_graph.h View File

@@ -26,6 +26,7 @@
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <functional>

#include "ir/anf.h"
#include "ir/manager.h"
@@ -36,8 +37,13 @@
namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>;

template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;

using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;

const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
@@ -183,12 +189,24 @@ class FuncGraph : public FuncGraphBase {

// get all nodes belonging to this func graph
const AnfNodeSet &nodes();
void CopyNodes(const AnfNodeSet &other_nodes);
void ClearNodes();
void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node);

// get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes();

// get all vars directly pointed to in this func graph
const AnfNodeCounterMap &free_variables_direct();
void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes);
void ClearValueNodes();
void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node);

// get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const AnfNodeCounterMap &others);
void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node);

// get all vars required by this func graph
const BaseRefCounterMap &free_variables_total();
@@ -199,14 +217,29 @@ class FuncGraph : public FuncGraphBase {
// get all vars that are func graphs
std::vector<FuncGraphPtr> free_variables_func_graphs();

// get all func graphs directly used by this func graph
const FuncGraphCounterMap &func_graphs_used();
// get all value nodes of func graph directly used by this func graph
const AnfNodeCounterMap &func_graph_value_nodes();
void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others);
void ClearFuncGraphValueNodes();
bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1);
bool DropFuncGraphValueNode(AnfNodePtr node);

// get all value nodes of J func graph directly used by this func graph
const AnfNodeCounterMap &j_func_graph_value_nodes();
void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others);
void ClearJFuncGraphValueNodes();
void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1);
void DropJFuncGraphValueNode(AnfNodePtr node);

// get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total();

// get all user value nodes of this func graph
// get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes);
void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);

// Return the parent of this graph.
FuncGraphPtr parent();
@@ -270,6 +303,24 @@ class FuncGraph : public FuncGraphBase {
// graph is manipulated by manager and others
friend FuncGraphManager;

// all nodes of the function
AnfNodeSet nodes_;

// all value nodes of the function
AnfNodeCounterMap value_nodes_;

// all func graph value nodes of the function
AnfNodeCounterMap func_graph_value_nodes_;

// all free variables of the function
AnfNodeCounterMap free_variables_;

// all value nodes calling J in the function
AnfNodeCounterMap j_func_graph_value_nodes_;

// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;

// parameters of this function
std::vector<AnfNodePtr> parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_;


+ 6
- 6
mindspore/ccsrc/ir/func_graph_cloner.cc View File

@@ -123,7 +123,7 @@ void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
if (!clone_all_valuenodes_) {
return;
}
auto &value_nodes = manager_->valuenodes()[func_graph];
auto &value_nodes = func_graph->value_nodes();
for (auto &value_node : value_nodes) {
auto old_node = value_node.first;
MS_EXCEPTION_IF_NULL(old_node);
@@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) {
return;
}
auto &used_graphs = manager_->func_graphs_used()[func_graph];
for (auto &used_graph : used_graphs) {
todo_.push_back({used_graph.first, nullptr, {}});
auto &used = func_graph->func_graph_value_nodes();
for (auto &fg_value_node : used) {
todo_.push_back({GetValueNode<FuncGraphPtr>(fg_value_node.first), nullptr, {}});
}
}

@@ -185,7 +185,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
}
target_func_graph->set_return(return_node);

auto &cnodes = manager_->func_graph_cnodes_index()[func_graph];
auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second);
@@ -441,7 +441,7 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_);
const AnfNodeSet &nodes = manager_->nodes()[func_graph];
const AnfNodeSet &nodes = func_graph->nodes();
for (auto &node : nodes) {
CloneNode(node, target_func_graph);
}


+ 109
- 83
mindspore/ccsrc/ir/manager.cc View File

@@ -78,19 +78,6 @@ void FuncGraphManager::Reset() {
node_users_ = NodeUsersMap();

signals_ = std::make_shared<Signals>();
// FuncGraph --> AnfNode
nodes_ = std::make_shared<NodesCollector>(this);

// FuncGraph --> {AnfNode, Count}
valuenodes_ = std::make_shared<ValueNodesCollector>(this);
free_variables_direct_ = std::make_shared<FVDirectCollector>(this);
func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this);

// FuncGraph --> {FuncGraph, Count}
func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this);
func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this);
func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this);
func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this);

func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
func_graph_parent_ = std::make_shared<ParentComputer>(this);
@@ -210,7 +197,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
}
AddIntoManaged(func_graph);
MS_EXCEPTION_IF_NULL(signals_);
signals_->AddFuncGraph(func_graph);
std::vector<AnfNodePtr> para = func_graph->parameters();
AcquireNodes(para);
std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
@@ -224,7 +210,6 @@ void FuncGraphManager::Clear() {
node_users_.clear();
roots_.clear();

signals_->InvalidateCollector();
signals_->InvalidateComputer();
}

@@ -303,8 +288,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
continue;
}
MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_);
auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph];
auto &users_cnode_index = func_graph->func_graph_cnodes_index();
if (!users_cnode_index.empty() && !ignore_users) {
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
continue;
@@ -320,7 +304,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
MS_EXCEPTION_IF_NULL(signals_);
for (auto &fg : dropped) {
MS_EXCEPTION_IF_NULL(fg);
signals_->DropFuncGraph(fg);
all_nodes_.difference_update(fg->parameters());
(void)func_graphs_.erase(fg);
if (fg->manager().get() == this) {
@@ -339,7 +322,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
return;
}
(void)users_node.erase(make_pair(node, index));
signals_->DropEdge(node, index, inp);
DropEdge(node, index, inp);
} else {
MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString();
if (inp->func_graph() != nullptr) {
@@ -352,7 +335,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
auto &users_node = node_users_[inp];
users_node.add(make_pair(node, index));
MS_EXCEPTION_IF_NULL(signals_);
signals_->AddEdge(node, index, inp);
AddEdge(node, index, inp);
}
}

@@ -392,8 +375,8 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
FuncGraphPtr fg = node->func_graph();
if (fg != nullptr) {
AddFuncGraph(fg);
fg->AddNode(node);
}
signals_->AddNode(node);
ProcessInputs(node, kIncEdge);
}
}
@@ -424,7 +407,10 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &
}
ProcessInputs(node, kDecEdge);
(void)all_nodes_.erase(node);
signals_->DropNode(node);
if (node->func_graph() != nullptr) {
node->func_graph()->DropNode(node);
}

if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
nodes_ordered.update(cnode->inputs());
@@ -462,35 +448,21 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t

int index = 0;
(void)node_users_[source_prim].erase(make_pair(source_return, index));
signals_->DropEdge(source_return, index, source_prim);
DropEdge(source_return, index, source_prim);
index = 1;
(void)node_users_[source_output].erase(make_pair(source_return, index));
signals_->DropEdge(source_return, index, source_output);
DropEdge(source_return, index, source_output);
(void)all_nodes_.erase(source_return);
(void)node_users_.erase(source_return);
signals_->DropNode(source_return);
source->DropNode(source_return);
for (auto &node : source->nodes()) {
node->set_func_graph(target);
if (node->scope() == kDefaultScope) {
node->set_scope(scope);
}
}
for (auto &child : this->func_graph_child_direct()[source]) {
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
(void)this->func_graph_parents_direct()[child.first].erase(source);
}
for (auto &fv_count : this->free_variables_direct()[source]) {
auto fv_g = fv_count.first->func_graph();
auto &count_on_g = this->func_graph_child_direct()[fv_g];
auto pair = count_on_g.find(source);
if (fv_g != target && pair != count_on_g.end()) {
(void)func_graph_child_direct_->Inc(fv_g, target, pair->second);
}
(void)count_on_g.erase(source);
}
signals_->MoveAllCNode(source, target);
signals_->InvalidateComputer();
signals_->DropFuncGraph(source);

MoveAllNodes(source, target);
all_nodes_.difference_update(source->parameters());
(void)func_graphs_.erase(source);
if (source->manager().get() == this) {
@@ -498,6 +470,64 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
}
}

inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) {
auto fg = node->func_graph();
if (input->isa<ValueNode>()) {
fg->AddValueNode(input);
if (IsValueNode<FuncGraph>(input)) {
if (fg->AddFuncGraphValueNode(input)) {
signals_->InvalidateComputer();
}
auto used = GetValueNode<FuncGraphPtr>(input);
used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->AddJFuncGraphValueNode(input);
}
}
} else if (fg != nullptr && fg != input->func_graph()) {
if (fg->AddFreeVariable(input)) {
signals_->InvalidateComputer();
}
}
}

inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) {
auto fg = node->func_graph();
if (input->isa<ValueNode>()) {
fg->DropValueNode(input);
if (IsValueNode<FuncGraph>(input)) {
if (fg->DropFuncGraphValueNode(input)) {
signals_->InvalidateComputer();
}
auto used = GetValueNode<FuncGraphPtr>(input);
used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->DropJFuncGraphValueNode(input);
}
}
} else if (fg != nullptr && fg != input->func_graph()) {
if (fg->DropFreeVariable(input)) {
signals_->InvalidateComputer();
}
}
}

inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
target->CopyNodes(source->nodes());
target->CopyValueNodes(source->value_nodes());
target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index());
target->CopyFreeVariables(source->free_variables());
target->CopyFuncGraphValueNodes(source->func_graph_value_nodes());
target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes());
signals_->InvalidateComputer();
source->ClearNodes();
source->ClearValueNodes();
source->ClearFuncGraphCNodesIndex();
source->ClearFreeVariables();
source->ClearFuncGraphValueNodes();
source->ClearJFuncGraphValueNodes();
}

FuncGraphTransaction FuncGraphManager::Transact() {
auto tr = FuncGraphTransaction(this);
return tr;
@@ -630,7 +660,6 @@ void NodesCollector::OnAddNode(AnfNodePtr n) {
if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) {
nodes_analysis_[n->func_graph()] = AnfNodeSet();
}

nodes_analysis_[n->func_graph()].add(n);
}

@@ -910,17 +939,19 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
return std::make_shared<FuncGraphSet>();
}
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_;
for (auto &dep : deps[fg]) {
MS_EXCEPTION_IF_NULL(dep.first);
auto proxy = dep.first->transforms().find("proxy");
if (proxy != dep.first->transforms().end()) {
path->add(fg);
auto gt = proxy->second.func_graph();
parents->update(SeekParents(gt, path));
} else {
parents->add(dep.first);
}

// Append all the fvs in fg.
auto &fvs = fg->free_variables();
for (auto fv : fvs) {
parents->add(fv.first->func_graph());
}

// Search the fv in fg's child func graph.
auto &fg_value_nodes = fg->func_graph_value_nodes();
for (auto &fg_value_node : fg_value_nodes) {
path->add(fg);
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first);
parents->update(SeekParents(gt, path));
}
(void)parents->erase(fg);
return parents;
@@ -928,10 +959,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f

void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg);
all_parents_direct_ = &(manager_->func_graph_parents_direct());
MS_LOG(DEBUG) << fg->ToString() << " total func graph dep size:" << (*all_parents_direct_)[fg].size();
func_graph_parents_total_analysis_[fg].update(SeekParents(fg));
MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size();
}

bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
@@ -1001,28 +1029,30 @@ void FVTotalComputer::RealRecompute() {
}

for (auto &fg : manager->func_graphs()) {
AnfNodeCounterMap items = manager->free_variables_direct()[fg];
AnfNodeCounterMap items = fg->free_variables();
for (auto &iter : items) {
auto curr = fg;
while (curr) {
while (curr != nullptr) {
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
curr = manager->parent(curr);
const AnfNodeSet &nodes = manager->nodes()[curr];
if (nodes.contains(iter.first)) {
break;
if (curr != nullptr) {
const AnfNodeSet &all_nodes = curr->nodes();
if (all_nodes.contains(iter.first)) {
break;
}
}
}
}

auto items_fg = manager->func_graphs_used()[fg];
for (auto &iter : items_fg) {
auto p = manager->parent(iter.first);
auto &used = fg->func_graph_value_nodes();
for (auto &iter : used) {
auto p = manager->parent(GetValueNode<FuncGraphPtr>(iter.first));
if (p == nullptr) {
continue;
}
auto curr = fg;
while (curr != p) {
(void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
(void)CounterFuncGraphCollector::Mod(curr, GetValueNode<FuncGraphPtr>(iter.first), iter.second);
curr = manager->parent(curr);
}
}
@@ -1041,7 +1071,6 @@ void FVTotalComputer::RealRecompute() {

void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_);
auto &used = this->manager_->func_graphs_used();
std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new;

@@ -1049,8 +1078,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
while (!todo.empty()) {
todo_new.clear();
for (auto &gt : todo) {
for (auto &item : used[gt]) {
auto used_fg = item.first;
for (auto &item : gt->func_graph_value_nodes()) {
auto used_fg = GetValueNode<FuncGraphPtr>(item.first);
if (used_fg == fg) {
func_graph_used_total_analysis_[fg].add(used_fg);
continue;
@@ -1068,7 +1097,6 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {

bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(manager);
auto &used = manager->func_graphs_used();
std::vector<FuncGraphPtr> todo;
std::vector<FuncGraphPtr> todo_new;
todo.push_back(fg);
@@ -1076,8 +1104,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f
while (!todo.empty()) {
todo_new.clear();
for (auto &gt : todo) {
for (auto &item : used[gt]) {
auto used_g = item.first;
for (auto &item : gt->func_graph_value_nodes()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
if (used_g == fg) {
return true;
}
@@ -1108,9 +1136,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
}
} else {
trace->push_back(fg);
auto &used_fgs = manager_->func_graphs_used()[fg];
for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) {
CheckRecursiveGraphs(iter->first, trace);
auto &items = fg->func_graph_value_nodes();
for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
CheckRecursiveGraphs(GetValueNode<FuncGraphPtr>(iter->first), trace);
}
trace->pop_back();
if (!recursive_map_.count(fg)) {
@@ -1125,14 +1153,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false;
}
MS_EXCEPTION_IF_NULL(manager_);
auto &func_graph_counter_map = manager_->func_graph_j_direct();
if (!func_graph_counter_map[fg].empty()) {
auto &j_fg_value_nodes = fg->j_func_graph_value_nodes();
if (!j_fg_value_nodes.empty()) {
// check g1->J(fg)->g2->g cycle;
auto contains_j =
std::find_if(func_graph_counter_map[fg].begin(), func_graph_counter_map[fg].end(),
[path](const std::pair<FuncGraphPtr, int> iter) { return !path->contains(iter.first); });
if (contains_j != func_graph_counter_map[fg].end()) {
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(),
[path](const std::pair<AnfNodePtr, int> iter) { return !path->contains(GetValueNode<FuncGraphPtr>(iter.first)); });
if (contains_j != j_fg_value_nodes.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
return true;
}
@@ -1140,9 +1167,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
path->add(fg);

// check if func graphs used contains J(func_graph);
auto &used = this->manager_->func_graphs_used();
for (auto &item : used[fg]) {
auto used_g = item.first;
for (auto &item : fg->func_graph_value_nodes()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
if (SeekJ(used_g, path)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
return true;


+ 5
- 37
mindspore/ccsrc/ir/manager.h View File

@@ -367,8 +367,8 @@ class DepComputer : public FuncGraphAnalysis {
// graph g's all direct or proxy parents
class FuncGraphParentsTotalComputer final : public DepComputer {
public:
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {}
~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; }
explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FuncGraphParentsTotalComputer() override = default;

FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }

@@ -383,9 +383,6 @@ class FuncGraphParentsTotalComputer final : public DepComputer {

private:
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
// when SeekParents calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 1 parameter for SeekParents().
FuncGraphToFuncGraphCounterMap *all_parents_direct_;
};

using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
@@ -562,30 +559,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {

NodeUsersMap &node_users() { return node_users_; }

FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }

FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; }

FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const {
return free_variables_direct_->count_nodes_map_;
}

FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const {
return func_graph_cnodes_index_->count_nodes_map_;
}

FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }

FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_;
}

FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const {
return func_graph_parents_direct_->count_func_graphs_map_;
}

FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; }

FVTotalMap &free_variables_total() const;

FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
@@ -610,14 +583,6 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
// Static Analysis
NodeUsersMap node_users_;
AnfNodeSet all_nodes_; // managed nodes
std::shared_ptr<NodesCollector> nodes_;
std::shared_ptr<ValueNodesCollector> valuenodes_;
std::shared_ptr<FVDirectCollector> free_variables_direct_;
std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_;
std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_;
std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_;
std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_;
std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_;

// Dynamic Analysis
std::shared_ptr<ParentComputer> func_graph_parent_;
@@ -630,6 +595,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
FuncGraphSetPtr MaybeDropNodes(const std::vector<AnfNodePtr> &nodes);
void ParseChanges(const std::vector<Change> &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges,
Counter<AnfNodePtr> *adds, Counter<AnfNodePtr> *rms);
void AddEdge(AnfNodePtr node, int index, AnfNodePtr input);
void DropEdge(AnfNodePtr node, int index, AnfNodePtr input);
void MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target);

FuncGraphSet roots_; // managed roots
FuncGraphSet func_graphs_; // managed func graphs


+ 1
- 1
mindspore/ccsrc/optimizer/ad/dfunctor.cc View File

@@ -491,7 +491,7 @@ void DFunctor::MapParamObject() {
void DFunctor::MapValueObject() {
// Map ValueNode.
auto manager = resources_->manager();
auto &value_nodes = manager->valuenodes()[primal_graph_];
auto &value_nodes = primal_graph_->value_nodes();
for (const auto &value_pair : value_nodes) {
auto node = value_pair.first;
auto parent_adjoint = FindAdjoint(node);


+ 2
- 2
mindspore/ccsrc/optimizer/irpass/branch_culling.cc View File

@@ -119,7 +119,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node;
// record the node input to be replaced
NodeInputReplMap repl_node_inputs;
const AnfNodeSet &nodes = manager->nodes()[graph];
const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@@ -436,7 +436,7 @@ FuncGraphPtr TransformGraphDependNode(
ResetSharedOp();
std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node =
std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>(); // record the node to be replaced
const AnfNodeSet &nodes = manager->nodes()[graph];
const AnfNodeSet &nodes = graph->nodes();
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {


+ 1
- 1
mindspore/ccsrc/parallel/step_parallel.cc View File

@@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) {
SetForwardFlag(all_nodes);
} else {
for (auto &func_graph : graph_set) {
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size();
auto return_node = func_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);


+ 1
- 1
mindspore/ccsrc/pipeline/action.cc View File

@@ -389,7 +389,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
auto manager = res->manager();
// Remove duplicated value nodes, due to replace operation, can't use reference.
auto value_nodes = manager->valuenodes()[func_graph];
auto value_nodes = func_graph->value_nodes();
HashCache hash_cache;
HashValue hashes;
for (const auto &value_pair : value_nodes) {


+ 5
- 5
mindspore/ccsrc/vm/transform.cc View File

@@ -487,12 +487,12 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {

void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts,
const FuncGraphSet &fgs,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr);
for (const auto &ct_graphs : cts) {
for (const auto &ct_any : ct_graphs.second) {
for (const auto &fg : fgs) {
for (const auto &ct_any : fg->value_nodes()) {
AnfNodePtr const_primitive_node = ct_any.first;
if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
auto users = manager_ptr->node_users()[const_primitive_node];
@@ -552,8 +552,8 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
};

FuncGraphTransaction tr = manager_ptr->Transact();
auto &cts = manager_ptr->valuenodes();
TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph);
auto &fgs = manager_ptr->func_graphs();
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);

return graph;
}


Loading…
Cancel
Save