Browse Source

!169 fix the method to calculate the children graph

Merge pull request !169 from xychow/fix-manager-children-issue
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
066f20e791
3 changed files with 18 additions and 61 deletions
  1. +6
    -32
      mindspore/ccsrc/ir/manager.cc
  2. +2
    -12
      mindspore/ccsrc/ir/manager.h
  3. +10
    -17
      tests/ut/python/pynative_mode/test_framstruct.py

+ 6
- 32
mindspore/ccsrc/ir/manager.cc View File

@@ -985,40 +985,14 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
}
}

// children include:
// A. func graphs which use variables in fg as free variables; (child_direct_)
// B. func graphs which call func func graph in A. (all_users_)
FuncGraphSetPtr ChildrenComputer::SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
if (path == nullptr || path->contains(fg)) {
return std::make_shared<FuncGraphSet>();
}
std::shared_ptr<FuncGraphSet> children = std::make_shared<FuncGraphSet>();
auto& deps = *child_direct_;
auto& users = *all_users_;
MS_LOG(DEBUG) << "" << fg->ToString() << " start func graph dep size:" << deps[fg].size();
for (auto& dep : deps[fg]) {
FuncGraphPtr child = dep.first;
children->add(child);
path->add(child);
MS_LOG(DEBUG) << "Child func graph:" << fg->ToString() << " child " << child->ToString();
for (auto& user : users[child]) {
auto user_func_graph = user.first;
MS_LOG(DEBUG) << "Func graph:" << fg->ToString() << " user " << user_func_graph->ToString();
children->add(user_func_graph);
path->add(user_func_graph);
}
children->update(SeekChildren(child, path));
}
(void)children->erase(fg);
MS_LOG(DEBUG) << "End in children: " << children->size();
return children;
}

void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_);
child_direct_ = &manager_->func_graph_child_direct();
all_users_ = &manager_->func_graph_users();
children_analysis_[fg].update(SeekChildren(fg));
auto used_fg_total = manager_->func_graphs_used_total(fg);
for (auto& used_fg : used_fg_total) {
if (manager_->parent(used_fg) == fg) {
children_analysis_[fg].add(used_fg);
}
}
}

void ScopeComputer::RealRecompute(FuncGraphPtr fg) {


+ 2
- 12
mindspore/ccsrc/ir/manager.h View File

@@ -398,11 +398,8 @@ class ParentComputer final : public DepComputer {
// graph's children graph except self
class ChildrenComputer final : public DepComputer {
public:
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m), all_users_(nullptr), child_direct_(nullptr) {}
~ChildrenComputer() override {
all_users_ = nullptr;
child_direct_ = nullptr;
}
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {}
~ChildrenComputer() override = default;

FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; }

@@ -414,13 +411,6 @@ class ChildrenComputer final : public DepComputer {
void ExtraReset() override { children_analysis_.clear(); }

void RealRecompute(FuncGraphPtr fg) override;

private:
FuncGraphSetPtr SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>());
// when SeekChildren calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 2 parameters for SeekChildren().
FuncGraphToFuncGraphCounterMap* all_users_;
FuncGraphToFuncGraphCounterMap* child_direct_;
};

// graph's children graph include self


+ 10
- 17
tests/ut/python/pynative_mode/test_framstruct.py View File

@@ -38,16 +38,6 @@ def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)


@ms_function
def refactor_fac(n):
""" grad_refactor_fac """
if n == 0:
return 1
return n * refactor_fac(n-1)
def test_refactor():
res = refactor_fac(3)
assert res == 6

@ms_function
def while_upper_bound(upper):
rval = 2
@@ -386,16 +376,19 @@ def test_grad_while():
assert grad_while(5) == (60,)

@ms_function
def fac(n):
""" fac """
def factorial(n):
""" factorial """
if n == 0:
return 1
return n * fac(n-1)
return n * factorial(n-1)

def test_factorial():
res = factorial(3)
assert res == 6

def test_fac():
""" test_fac """
res = fac(4)
assert res == 24
def test_grad_factorial():
res = C.grad(factorial)(3)
assert res == 11

def _for(x):
""" _for """


Loading…
Cancel
Save