Browse Source

Expand J for innermost graph first when the graph also contains J primitive

tags/v1.2.0-rc1
yujianfeng 5 years ago
parent
commit
728fac6c9f
5 changed files with 104 additions and 48 deletions
  1. +7
    -5
      mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
  2. +18
    -18
      mindspore/core/ir/func_graph.cc
  3. +7
    -7
      mindspore/core/ir/func_graph.h
  4. +30
    -18
      mindspore/core/ir/manager.cc
  5. +42
    -0
      tests/ut/python/optimizer/test_auto_grad.py

+ 7
- 5
mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc View File

@@ -38,8 +38,9 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB
return nullptr;
}

bool CheckIfEmbedJFuncGraph(const FuncGraphPtr func_graph) {
// if func graph also contain J FuncGraph, then ignore this funcgraph. ExpandJ innermost graph first;
bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) {
// If the func graph also contains J(FuncGraph) or J(Primitive), skip this func graph for now:
// the innermost J must be expanded first.
auto func_graph_manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(func_graph_manager);
return func_graph_manager->func_graph_j_total(func_graph);
@@ -53,9 +54,10 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &r
MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString();

// high_order_grad begin;
// if graph also contain J Graph, then ignore this graph. ExpandJ innermost graph first;
if (CheckIfEmbedJFuncGraph(func_graph)) {
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J(funcgraph), will expandJ later";
// if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
// ExpandJ innermost graph or primitive first.
if (CheckIfEmbedJ(func_graph)) {
MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later";
return nullptr;
}
// high_order_grad end;


+ 18
- 18
mindspore/core/ir/func_graph.cc View File

@@ -357,33 +357,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
}
}

const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }
const std::unordered_map<AnfNodePtr, int> &FuncGraph::j_value_nodes() { return j_value_nodes_; }

void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
auto &others = source->j_func_graphs();
for (auto it = others.begin(); it != others.end(); it++) {
AddJFuncGraph(it->first, it->second);
void FuncGraph::CopyJValueNodes(const FuncGraphPtr &source) {
auto &others = source->j_value_nodes();
for (const auto &other : others) {
AddJValueNode(other.first, other.second);
}
}

void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }
void FuncGraph::ClearJValueNodes() { j_value_nodes_.clear(); }

void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
if (j_func_graphs_.count(fg) == 0) {
j_func_graphs_[fg] = count;
void FuncGraph::AddJValueNode(const AnfNodePtr &value_node, int count) {
if (j_value_nodes_.count(value_node) == 0) {
j_value_nodes_[value_node] = count;
} else {
j_func_graphs_[fg] += count;
j_value_nodes_[value_node] += count;
}
}

void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
if (j_func_graphs_.count(fg) != 0) {
if (j_func_graphs_[fg] == 1) {
(void)j_func_graphs_.erase(fg);
void FuncGraph::DropJValueNode(const AnfNodePtr &value_node) {
if (j_value_nodes_.count(value_node) != 0) {
if (j_value_nodes_[value_node] == 1) {
(void)j_value_nodes_.erase(value_node);
} else {
j_func_graphs_[fg]--;
if (j_func_graphs_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
j_value_nodes_[value_node]--;
if (j_value_nodes_[value_node] < 0) {
MS_LOG(EXCEPTION) << "Count of J ValueNode '" << value_node->DebugString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
@@ -431,7 +431,7 @@ void FuncGraph::ClearAllManagerInfo() {
ClearFuncGraphCNodesIndex();
ClearFreeVariables();
ClearFuncGraphsUsed();
ClearJFuncGraphs();
ClearJValueNodes();
}

AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {


+ 7
- 7
mindspore/core/ir/func_graph.h View File

@@ -275,12 +275,12 @@ class FuncGraph : public FuncGraphBase {
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphUsed(FuncGraphPtr fg);

// get all value nodes of J func graph directly used by this func graph
const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphs();
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraph(FuncGraphPtr fg);
// get all value nodes in the inputs of J directly used by this func graph
const std::unordered_map<AnfNodePtr, int> &j_value_nodes();
void CopyJValueNodes(const FuncGraphPtr &source);
void ClearJValueNodes();
void AddJValueNode(const AnfNodePtr &value_node, int count = 1);
void DropJValueNode(const AnfNodePtr &value_node);

// get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total();
@@ -375,7 +375,7 @@ class FuncGraph : public FuncGraphBase {
AnfNodeCounterMap free_variables_;

// all value nodes used as inputs of J in the function, mapped to their use counts
FuncGraphCounterMap j_func_graphs_;
std::unordered_map<AnfNodePtr, int> j_value_nodes_;

// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;


+ 30
- 18
mindspore/core/ir/manager.cc View File

@@ -486,9 +486,9 @@ void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) {
if (fg->AddFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->AddJFuncGraph(used);
}
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->AddJValueNode(input);
}
} else if (fg != nullptr && fg != input->func_graph()) {
if (fg->AddFreeVariable(input)) {
@@ -507,9 +507,9 @@ void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) {
if (fg->DropFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->DropJFuncGraph(used);
}
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->DropJValueNode(input);
}
} else if (fg != nullptr && fg != input->func_graph()) {
if (fg->DropFreeVariable(input)) {
@@ -524,7 +524,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
target->CopyFuncGraphCNodesIndex(source);
target->CopyFreeVariables(source);
target->CopyFuncGraphsUsed(source);
target->CopyJFuncGraphs(source);
target->CopyJValueNodes(source);
signals_->InvalidateComputer();
source->ClearAllManagerInfo();
}
@@ -880,32 +880,44 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
}

bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
MS_EXCEPTION_IF_NULL(fg);
if (fg->seen_ == seen_num) {
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false;
}
auto &j_fgs = fg->j_func_graphs();
if (!j_fgs.empty()) {
// check g1->J(fg)->g2->g cycle;
auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
return iter.first->seen_ != seen_num;
});
if (contains_j != j_fgs.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
const auto &j_values = fg->j_value_nodes();
if (!j_values.empty()) {
auto contains_j =
std::find_if(j_values.begin(), j_values.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
// check g1->J(fg)->g2->g cycle.
if (IsValueNode<FuncGraph>(iter.first)) {
auto func_graph = GetValueNode<FuncGraphPtr>(iter.first);
return func_graph->seen_ != seen_num;
}
if (IsValueNode<Primitive>(iter.first)) {
// exclude the primitive of J itself.
auto prim = GetValueNode<PrimitivePtr>(iter.first);
return prim->name() != prim::kPrimJ->name();
}
return false;
});
if (contains_j != j_values.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->DebugString() << ")";
return true;
}
}
fg->seen_ = seen_num;

// check if func graphs used contains J(func_graph);
// check if func graphs used contains J(func_graph) or J(Primitive)
for (auto &item : fg->func_graphs_used()) {
auto used_g = item.first;
if (SeekJ(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
<< " which contains J(func_graph) or J(Primitive)";
return true;
}
}
MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)";
MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph) or J(Primitive)";
return false;
}



+ 42
- 0
tests/ut/python/optimizer/test_auto_grad.py View File

@@ -15,6 +15,7 @@
import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
@@ -68,3 +69,44 @@ def test_user_defined_bprop():
grad_net = TestUserDefinedBpropGradNet(net)
x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32))
grad_net(x)


class SinNet(nn.Cell):
    """Cell whose forward computation is a single `ops.Sin` applied to the input."""

    def __init__(self):
        super().__init__()
        # The operator instance is created once and reused by construct().
        self.sin = ops.Sin()

    def construct(self, x):
        """Return `self.sin(x)`."""
        return self.sin(x)


class SinGrad(nn.Cell):
    """Cell that applies a default-constructed `ops.GradOperation` to a wrapped network."""

    def __init__(self, network):
        super().__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        """Build the gradient function of the wrapped network and call it with `x`."""
        grad_fn = self.grad(self.network)
        return grad_fn(x)


class SinGradSec(nn.Cell):
    """Cell that applies a default-constructed `ops.GradOperation` to a wrapped network.

    Structurally the same as `SinGrad`; the test in this file wraps a `SinGrad`
    instance with it so that a gradient is taken of a gradient.
    """

    def __init__(self, network):
        super().__init__()
        self.grad = ops.GradOperation()
        self.network = network

    def construct(self, x):
        """Build the gradient function of the wrapped network and call it with `x`."""
        grad_fn = self.grad(self.network)
        return grad_fn(x)


def test_second_grad_with_j_primitive():
    """Run a grad-of-grad network over `SinNet` once in graph mode.

    The test asserts no value; it only requires that the call completes.
    """
    context.set_context(mode=context.GRAPH_MODE)
    # Nest the cells innermost-first: SinNet -> SinGrad -> SinGradSec.
    second_grad = SinGradSec(SinGrad(SinNet()))
    inputs = Tensor(np.array([1.0], dtype=np.float32))
    second_grad(inputs)

Loading…
Cancel
Save