Browse Source

!2126 Refactoring the func graph manager module.

Merge pull request !2126 from ZhangQinghua/master
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
cb706951c1
4 changed files with 42 additions and 320 deletions
  1. +26
    -0
      mindspore/ccsrc/ir/func_graph.h
  2. +5
    -111
      mindspore/ccsrc/ir/manager.cc
  3. +10
    -143
      mindspore/ccsrc/ir/manager.h
  4. +1
    -66
      tests/ut/cpp/ir/manager_test.cc

+ 26
- 0
mindspore/ccsrc/ir/func_graph.h View File

@@ -38,6 +38,32 @@ namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;

struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair);
MS_EXCEPTION_IF_NULL(pair->first);
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
}
};

struct CNodeIndexEqual {
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
if (lhs == nullptr || rhs == nullptr) {
return false;
}
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
return false;
}
if (lhs->second != rhs->second) {
return false;
}
return true;
}
};

template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;


+ 5
- 111
mindspore/ccsrc/ir/manager.cc View File

@@ -633,103 +633,7 @@ void FuncGraphTransaction::Commit() {
manager_->CommitChanges(changes);
}

FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
: manager_(manager), include_func_graph_none_(false) {}

DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_);
}

void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }

void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }

template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
auto &d = count_nodes_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}

template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
MS_EXCEPTION_IF_NULL(func_graph);
auto &d = count_nodes_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}

template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}

bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}

bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}

bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}

DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
validate_ = false;
@@ -839,16 +743,15 @@ void FVTotalComputer::RealRecompute() {

for (auto &fg : manager->func_graphs()) {
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
}

for (auto &fg : manager->func_graphs()) {
// add all free variable nodes
AnfNodeCounterMap items = fg->free_variables();
for (auto &iter : items) {
auto curr = fg;
while (curr != nullptr) {
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
fv_total_analysis_[curr][iter.first] = iter.second;
curr = manager->parent(curr);
if (curr != nullptr) {
const AnfNodeSet &all_nodes = curr->nodes();
@@ -859,6 +762,7 @@ void FVTotalComputer::RealRecompute() {
}
}

// add all FGs of free variables
auto &used = fg->func_graphs_used();
for (auto &iter : used) {
auto p = manager->parent(iter.first);
@@ -867,21 +771,11 @@ void FVTotalComputer::RealRecompute() {
}
auto curr = fg;
while (curr != p) {
(void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
fv_total_analysis_[curr][iter.first] = iter.second;
curr = manager->parent(curr);
}
}
}
for (auto &fg : manager->func_graphs()) {
auto &fvp = count_nodes_map_[fg];
auto &fvg = count_func_graphs_map_[fg];
for (auto &item : fvp) {
fv_total_analysis_[fg][item.first] = item.second;
}
for (auto &item : fvg) {
fv_total_analysis_[fg][item.first] = item.second;
}
}
}

void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {


+ 10
- 143
mindspore/ccsrc/ir/manager.h View File

@@ -88,14 +88,6 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool ma
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);

struct Signals {
Signal<void(FuncGraphPtr)> AddFuncGraph;
Signal<void(FuncGraphPtr)> DropFuncGraph;
Signal<void(AnfNodePtr)> AddNode;
Signal<void(AnfNodePtr)> DropNode;
Signal<void(AnfNodePtr, int, AnfNodePtr)> AddEdge;
Signal<void(AnfNodePtr, int, AnfNodePtr)> DropEdge;
Signal<void(FuncGraphPtr, FuncGraphPtr)> MoveAllCNode;
Signal<void()> InvalidateCollector;
Signal<void()> InvalidateComputer;
};

@@ -103,136 +95,15 @@ enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };

using CNodeIndexPair = std::pair<AnfNodePtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;

using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>;
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>;

// analysis base class
class FuncGraphAnalysis {
public:
explicit FuncGraphAnalysis(const FuncGraphManager *const manager);

virtual ~FuncGraphAnalysis() { manager_ = nullptr; }

virtual size_t size() const { return 0; }

virtual void OnAddFuncGraph(FuncGraphPtr) {}

virtual void OnDropFuncGraph(FuncGraphPtr) {}

virtual void OnMoveAllCNode(FuncGraphPtr, FuncGraphPtr) {}

protected:
// subclass can reset their own member;
virtual void ExtraReset() {}

virtual void OnAddNode(AnfNodePtr n) {}

virtual void OnDropNode(AnfNodePtr n) {}

virtual void OnAddEdge(AnfNodePtr, int, AnfNodePtr) {}

virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}

const FuncGraphManager *manager_;
bool include_func_graph_none_;
};

using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;

struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair);
MS_EXCEPTION_IF_NULL(pair->first);
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
}
};

struct CNodeIndexEqual {
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
if (lhs == nullptr || rhs == nullptr) {
return false;
}
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
return false;
}
if (lhs->second != rhs->second) {
return false;
}
return true;
}
};

// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public:
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;

void Reset() { ExtraReset(); }
void OnInvalidateCollector() { Reset(); }

protected:
// inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
};

class CounterFuncGraphCollector : public DepCollector {
public:
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);

FuncGraphToFuncGraphCounterMap count_func_graphs_map_;

protected:
void ExtraReset() override { count_func_graphs_map_.clear(); }
};

template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector {
public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }

size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }

bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);

FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;

protected:
void ExtraReset() override { count_nodes_map_.clear(); }
};

using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;

// graphs analysis which need dynamic compute by DepCollector in each read
class DepComputer : public FuncGraphAnalysis {
// analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
class DepComputer {
public:
explicit DepComputer(const FuncGraphManager *manager);
~DepComputer() override = default;
virtual ~DepComputer() { manager_ = nullptr; }

virtual size_t size() const { return 0; }

void Reset() {
ExtraReset();
@@ -250,15 +121,14 @@ class DepComputer : public FuncGraphAnalysis {

bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }

void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }

void OnDropFuncGraph(FuncGraphPtr) final { Reset(); }

protected:
// subclass can reset their own member;
virtual void ExtraReset() {}
// subclass do the real compute
virtual void RealRecompute() {}
virtual void RealRecompute(FuncGraphPtr) {}

const FuncGraphManager *manager_;
bool validate_;
OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;

@@ -345,12 +215,9 @@ class ScopeComputer final : public DepComputer {

using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;

class FVTotalComputer final : public DepComputer,
public CounterAnfNodeCollector<AnfNodePtr>,
public CounterFuncGraphCollector {
class FVTotalComputer final : public DepComputer {
public:
explicit FVTotalComputer(const FuncGraphManager *m)
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
~FVTotalComputer() override = default;

FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }


+ 1
- 66
tests/ut/cpp/ir/manager_test.cc View File

@@ -104,7 +104,7 @@ class NestingSpecs {
return name;
}

void Check(std::shared_ptr<FuncGraphAnalysis> results) {
void Check(std::shared_ptr<DepComputer> results) {
if (expected_.empty() && expected_recursive_.empty()) {
return;
}
@@ -120,18 +120,6 @@ class NestingSpecs {
CheckRecursive(recursive);
return;
}

auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results);
if (counter_g != nullptr) {
CheckGraphCounter(counter_g);
return;
}

auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results);
if (counter_p != nullptr) {
CheckAnfNodeCounter(counter_p);
return;
}
}

private:
@@ -193,59 +181,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_);
}

// Add CheckNesting function
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);

std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}

if (!v.empty()) {
clean_results[k] = v;
}
}

ASSERT_EQ(clean_results, expected_);
}

void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);

std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}

if (!v.empty()) {
clean_results[k] = v;
}
}

ASSERT_EQ(clean_results, expected_);
}

void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
std::map<std::string, bool> clean_results;
for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {


Loading…
Cancel
Save