Merge pull request !2126 from ZhangQinghua/mastertags/v0.5.0-beta
| @@ -38,6 +38,32 @@ namespace mindspore { | |||
| using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; | |||
| using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; | |||
| struct CNodeIndexHasher { | |||
| std::size_t operator()(const CNodeIndexPairPtr pair) const { | |||
| MS_EXCEPTION_IF_NULL(pair); | |||
| MS_EXCEPTION_IF_NULL(pair->first); | |||
| return hash_combine(pair->first->hash(), std::hash<int>()(pair->second)); | |||
| } | |||
| }; | |||
| struct CNodeIndexEqual { | |||
| bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { | |||
| if (lhs == nullptr || rhs == nullptr) { | |||
| return false; | |||
| } | |||
| if (lhs == rhs) { | |||
| return true; | |||
| } | |||
| if (lhs->first != rhs->first) { | |||
| return false; | |||
| } | |||
| if (lhs->second != rhs->second) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| }; | |||
| template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>> | |||
| using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>; | |||
| using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>; | |||
| @@ -633,103 +633,7 @@ void FuncGraphTransaction::Commit() { | |||
| manager_->CommitChanges(changes); | |||
| } | |||
| FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) | |||
| : manager_(manager), include_func_graph_none_(false) {} | |||
| DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| } | |||
| void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } | |||
| void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } | |||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph, | |||
| const ValueT &key, int count) { | |||
| auto &d = count_nodes_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| return true; | |||
| } else { | |||
| d[key] += count; | |||
| } | |||
| return false; | |||
| } | |||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph, | |||
| const ValueT &key, int count) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto &d = count_nodes_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| if (d[key] == count) { | |||
| (void)d.erase(key); | |||
| return true; | |||
| } else { | |||
| d[key] -= count; | |||
| if (d[key] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| template <typename ValueT, class CollectorHash, class CollectorEqual> | |||
| bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph, | |||
| const ValueT &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| return Dec(func_graph, key, -count); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key | |||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) == 0) { | |||
| d[key] = count; | |||
| return true; | |||
| } else { | |||
| d[key] += count; | |||
| } | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { | |||
| auto &d = count_func_graphs_map_[func_graph]; | |||
| if (d.count(key) != 0) { | |||
| if (d[key] == count) { | |||
| (void)d.erase(key); | |||
| return true; | |||
| } else { | |||
| d[key] -= count; | |||
| if (d[key] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { | |||
| if (count > 0) { | |||
| return Inc(func_graph, key, count); | |||
| } else if (count < 0) { | |||
| return Dec(func_graph, key, -count); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() | |||
| << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); | |||
| } | |||
| } | |||
| DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { | |||
| DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) { | |||
| MS_EXCEPTION_IF_NULL(manager_); | |||
| manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); | |||
| validate_ = false; | |||
| @@ -839,16 +743,15 @@ void FVTotalComputer::RealRecompute() { | |||
| for (auto &fg : manager->func_graphs()) { | |||
| fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>(); | |||
| count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); | |||
| count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); | |||
| } | |||
| for (auto &fg : manager->func_graphs()) { | |||
| // add all free variable nodes | |||
| AnfNodeCounterMap items = fg->free_variables(); | |||
| for (auto &iter : items) { | |||
| auto curr = fg; | |||
| while (curr != nullptr) { | |||
| (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); | |||
| fv_total_analysis_[curr][iter.first] = iter.second; | |||
| curr = manager->parent(curr); | |||
| if (curr != nullptr) { | |||
| const AnfNodeSet &all_nodes = curr->nodes(); | |||
| @@ -859,6 +762,7 @@ void FVTotalComputer::RealRecompute() { | |||
| } | |||
| } | |||
| // add all FGs of free variables | |||
| auto &used = fg->func_graphs_used(); | |||
| for (auto &iter : used) { | |||
| auto p = manager->parent(iter.first); | |||
| @@ -867,21 +771,11 @@ void FVTotalComputer::RealRecompute() { | |||
| } | |||
| auto curr = fg; | |||
| while (curr != p) { | |||
| (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); | |||
| fv_total_analysis_[curr][iter.first] = iter.second; | |||
| curr = manager->parent(curr); | |||
| } | |||
| } | |||
| } | |||
| for (auto &fg : manager->func_graphs()) { | |||
| auto &fvp = count_nodes_map_[fg]; | |||
| auto &fvg = count_func_graphs_map_[fg]; | |||
| for (auto &item : fvp) { | |||
| fv_total_analysis_[fg][item.first] = item.second; | |||
| } | |||
| for (auto &item : fvg) { | |||
| fv_total_analysis_[fg][item.first] = item.second; | |||
| } | |||
| } | |||
| } | |||
| void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { | |||
| @@ -88,14 +88,6 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool ma | |||
| FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true); | |||
| struct Signals { | |||
| Signal<void(FuncGraphPtr)> AddFuncGraph; | |||
| Signal<void(FuncGraphPtr)> DropFuncGraph; | |||
| Signal<void(AnfNodePtr)> AddNode; | |||
| Signal<void(AnfNodePtr)> DropNode; | |||
| Signal<void(AnfNodePtr, int, AnfNodePtr)> AddEdge; | |||
| Signal<void(AnfNodePtr, int, AnfNodePtr)> DropEdge; | |||
| Signal<void(FuncGraphPtr, FuncGraphPtr)> MoveAllCNode; | |||
| Signal<void()> InvalidateCollector; | |||
| Signal<void()> InvalidateComputer; | |||
| }; | |||
| @@ -103,136 +95,15 @@ enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; | |||
| using CNodeIndexPair = std::pair<AnfNodePtr, int>; | |||
| using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>; | |||
| using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>; | |||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||
| using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>; | |||
| // analysis base class | |||
| class FuncGraphAnalysis { | |||
| public: | |||
| explicit FuncGraphAnalysis(const FuncGraphManager *const manager); | |||
| virtual ~FuncGraphAnalysis() { manager_ = nullptr; } | |||
| virtual size_t size() const { return 0; } | |||
| virtual void OnAddFuncGraph(FuncGraphPtr) {} | |||
| virtual void OnDropFuncGraph(FuncGraphPtr) {} | |||
| virtual void OnMoveAllCNode(FuncGraphPtr, FuncGraphPtr) {} | |||
| protected: | |||
| // subclass can reset their own member; | |||
| virtual void ExtraReset() {} | |||
| virtual void OnAddNode(AnfNodePtr n) {} | |||
| virtual void OnDropNode(AnfNodePtr n) {} | |||
| virtual void OnAddEdge(AnfNodePtr, int, AnfNodePtr) {} | |||
| virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {} | |||
| const FuncGraphManager *manager_; | |||
| bool include_func_graph_none_; | |||
| }; | |||
| using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; | |||
| struct CNodeIndexHasher { | |||
| std::size_t operator()(const CNodeIndexPairPtr pair) const { | |||
| MS_EXCEPTION_IF_NULL(pair); | |||
| MS_EXCEPTION_IF_NULL(pair->first); | |||
| return hash_combine(pair->first->hash(), std::hash<int>()(pair->second)); | |||
| } | |||
| }; | |||
| struct CNodeIndexEqual { | |||
| bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { | |||
| if (lhs == nullptr || rhs == nullptr) { | |||
| return false; | |||
| } | |||
| if (lhs == rhs) { | |||
| return true; | |||
| } | |||
| if (lhs->first != rhs->first) { | |||
| return false; | |||
| } | |||
| if (lhs->second != rhs->second) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| }; | |||
| // graphs analysis which compute in write, read needn't recompute | |||
| class DepCollector : public FuncGraphAnalysis { | |||
| public: | |||
| explicit DepCollector(const FuncGraphManager *manager); | |||
| ~DepCollector() override = default; | |||
| void Reset() { ExtraReset(); } | |||
| void OnInvalidateCollector() { Reset(); } | |||
| protected: | |||
| // inherit from FuncGraphAnalysis | |||
| void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; | |||
| void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override; | |||
| // subclass can override; | |||
| virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {} | |||
| }; | |||
| class CounterFuncGraphCollector : public DepCollector { | |||
| public: | |||
| explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterFuncGraphCollector() override = default; | |||
| FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } | |||
| // inherit from FuncGraphAnalysis | |||
| size_t size() const override { return count_func_graphs_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); | |||
| FuncGraphToFuncGraphCounterMap count_func_graphs_map_; | |||
| protected: | |||
| void ExtraReset() override { count_func_graphs_map_.clear(); } | |||
| }; | |||
| template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> | |||
| class CounterAnfNodeCollector : public DepCollector { | |||
| public: | |||
| explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} | |||
| ~CounterAnfNodeCollector() override = default; | |||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; } | |||
| size_t size() const override { return count_nodes_map_.size(); } | |||
| void OnAddFuncGraph(FuncGraphPtr fg) final { | |||
| count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>(); | |||
| } | |||
| void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } | |||
| bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||
| bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||
| bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count); | |||
| FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_; | |||
| protected: | |||
| void ExtraReset() override { count_nodes_map_.clear(); } | |||
| }; | |||
| using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; | |||
| // graphs analysis which need dynamic compute by DepCollector in each read | |||
| class DepComputer : public FuncGraphAnalysis { | |||
| // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read | |||
| class DepComputer { | |||
| public: | |||
| explicit DepComputer(const FuncGraphManager *manager); | |||
| ~DepComputer() override = default; | |||
| virtual ~DepComputer() { manager_ = nullptr; } | |||
| virtual size_t size() const { return 0; } | |||
| void Reset() { | |||
| ExtraReset(); | |||
| @@ -250,15 +121,14 @@ class DepComputer : public FuncGraphAnalysis { | |||
| bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } | |||
| void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } | |||
| void OnDropFuncGraph(FuncGraphPtr) final { Reset(); } | |||
| protected: | |||
| // subclass can reset their own member; | |||
| virtual void ExtraReset() {} | |||
| // subclass do the real compute | |||
| virtual void RealRecompute() {} | |||
| virtual void RealRecompute(FuncGraphPtr) {} | |||
| const FuncGraphManager *manager_; | |||
| bool validate_; | |||
| OrderedMap<FuncGraphPtr, bool> func_graphs_validate_; | |||
| @@ -345,12 +215,9 @@ class ScopeComputer final : public DepComputer { | |||
| using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>; | |||
| class FVTotalComputer final : public DepComputer, | |||
| public CounterAnfNodeCollector<AnfNodePtr>, | |||
| public CounterFuncGraphCollector { | |||
| class FVTotalComputer final : public DepComputer { | |||
| public: | |||
| explicit FVTotalComputer(const FuncGraphManager *m) | |||
| : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} | |||
| explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} | |||
| ~FVTotalComputer() override = default; | |||
| FVTotalMap &fv_total_analysis() { return fv_total_analysis_; } | |||
| @@ -104,7 +104,7 @@ class NestingSpecs { | |||
| return name; | |||
| } | |||
| void Check(std::shared_ptr<FuncGraphAnalysis> results) { | |||
| void Check(std::shared_ptr<DepComputer> results) { | |||
| if (expected_.empty() && expected_recursive_.empty()) { | |||
| return; | |||
| } | |||
| @@ -120,18 +120,6 @@ class NestingSpecs { | |||
| CheckRecursive(recursive); | |||
| return; | |||
| } | |||
| auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results); | |||
| if (counter_g != nullptr) { | |||
| CheckGraphCounter(counter_g); | |||
| return; | |||
| } | |||
| auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results); | |||
| if (counter_p != nullptr) { | |||
| CheckAnfNodeCounter(counter_p); | |||
| return; | |||
| } | |||
| } | |||
| private: | |||
| @@ -193,59 +181,6 @@ class NestingSpecs { | |||
| ASSERT_EQ(clean_results, expected_); | |||
| } | |||
| // Add CheckNesting function | |||
| void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { | |||
| std::map<std::string, std::set<std::string>> clean_results; | |||
| for (auto& iter : results->count_nodes_map()) { | |||
| auto key = iter.first; | |||
| auto value = iter.second; | |||
| if (key == nullptr) { | |||
| continue; | |||
| } | |||
| std::string k = Name(key); | |||
| std::set<std::string> v; | |||
| for (auto& node : value) { | |||
| auto fg = node.first; | |||
| if (!Name(fg).empty()) { | |||
| v.insert(Name(fg)); | |||
| } | |||
| } | |||
| if (!v.empty()) { | |||
| clean_results[k] = v; | |||
| } | |||
| } | |||
| ASSERT_EQ(clean_results, expected_); | |||
| } | |||
| void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { | |||
| std::map<std::string, std::set<std::string>> clean_results; | |||
| for (auto& iter : results->count_func_graphs_map()) { | |||
| auto key = iter.first; | |||
| auto value = iter.second; | |||
| if (key == nullptr) { | |||
| continue; | |||
| } | |||
| std::string k = Name(key); | |||
| std::set<std::string> v; | |||
| for (auto& node : value) { | |||
| auto fg = node.first; | |||
| if (!Name(fg).empty()) { | |||
| v.insert(Name(fg)); | |||
| } | |||
| } | |||
| if (!v.empty()) { | |||
| clean_results[k] = v; | |||
| } | |||
| } | |||
| ASSERT_EQ(clean_results, expected_); | |||
| } | |||
| void CheckRecursive(std::shared_ptr<RecursiveComputer> results) { | |||
| std::map<std::string, bool> clean_results; | |||
| for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) { | |||