| @@ -38,8 +38,9 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB | |||
| return nullptr; | |||
| } | |||
| bool CheckIfEmbedJFuncGraph(const FuncGraphPtr func_graph) { | |||
| // if func graph also contain J FuncGraph, then ignore this funcgraph. ExpandJ innermost graph first; | |||
| bool CheckIfEmbedJ(const FuncGraphPtr &func_graph) { | |||
| // if func graph also contain J(FuncGraph) or J(Primitive), then ignore this funcgraph. | |||
| // ExpandJ innermost graph first. | |||
| auto func_graph_manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(func_graph_manager); | |||
| return func_graph_manager->func_graph_j_total(func_graph); | |||
| @@ -53,9 +54,10 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &r | |||
| MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); | |||
| // high_order_grad begin; | |||
| // if graph also contain J Graph, then ignore this graph. ExpandJ innermost graph first; | |||
| if (CheckIfEmbedJFuncGraph(func_graph)) { | |||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J(funcgraph), will expandJ later"; | |||
| // if graph also contains J(FuncGraph) or J(Primitive), then ignore this graph. | |||
| // ExpandJ innermost graph or primitive first. | |||
| if (CheckIfEmbedJ(func_graph)) { | |||
| MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J, will expandJ later"; | |||
| return nullptr; | |||
| } | |||
| // high_order_grad end; | |||
| @@ -357,33 +357,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { | |||
| } | |||
| } | |||
| const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } | |||
| const std::unordered_map<AnfNodePtr, int> &FuncGraph::j_value_nodes() { return j_value_nodes_; } | |||
| void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { | |||
| auto &others = source->j_func_graphs(); | |||
| for (auto it = others.begin(); it != others.end(); it++) { | |||
| AddJFuncGraph(it->first, it->second); | |||
| void FuncGraph::CopyJValueNodes(const FuncGraphPtr &source) { | |||
| auto &others = source->j_value_nodes(); | |||
| for (const auto &other : others) { | |||
| AddJValueNode(other.first, other.second); | |||
| } | |||
| } | |||
| void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } | |||
| void FuncGraph::ClearJValueNodes() { j_value_nodes_.clear(); } | |||
| void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { | |||
| if (j_func_graphs_.count(fg) == 0) { | |||
| j_func_graphs_[fg] = count; | |||
| void FuncGraph::AddJValueNode(const AnfNodePtr &value_node, int count) { | |||
| if (j_value_nodes_.count(value_node) == 0) { | |||
| j_value_nodes_[value_node] = count; | |||
| } else { | |||
| j_func_graphs_[fg] += count; | |||
| j_value_nodes_[value_node] += count; | |||
| } | |||
| } | |||
| void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { | |||
| if (j_func_graphs_.count(fg) != 0) { | |||
| if (j_func_graphs_[fg] == 1) { | |||
| (void)j_func_graphs_.erase(fg); | |||
| void FuncGraph::DropJValueNode(const AnfNodePtr &value_node) { | |||
| if (j_value_nodes_.count(value_node) != 0) { | |||
| if (j_value_nodes_[value_node] == 1) { | |||
| (void)j_value_nodes_.erase(value_node); | |||
| } else { | |||
| j_func_graphs_[fg]--; | |||
| if (j_func_graphs_[fg] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg | |||
| j_value_nodes_[value_node]--; | |||
| if (j_value_nodes_[value_node] < 0) { | |||
| MS_LOG(EXCEPTION) << "Count of J ValueNode '" << value_node->DebugString() | |||
| << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); | |||
| } | |||
| } | |||
| @@ -431,7 +431,7 @@ void FuncGraph::ClearAllManagerInfo() { | |||
| ClearFuncGraphCNodesIndex(); | |||
| ClearFreeVariables(); | |||
| ClearFuncGraphsUsed(); | |||
| ClearJFuncGraphs(); | |||
| ClearJValueNodes(); | |||
| } | |||
| AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { | |||
| @@ -275,12 +275,12 @@ class FuncGraph : public FuncGraphBase { | |||
| bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1); | |||
| bool DropFuncGraphUsed(FuncGraphPtr fg); | |||
| // get all value nodes of J func graph directly used by this func graph | |||
| const FuncGraphCounterMap &j_func_graphs(); | |||
| void CopyJFuncGraphs(const FuncGraphPtr &source); | |||
| void ClearJFuncGraphs(); | |||
| void AddJFuncGraph(FuncGraphPtr fg, int count = 1); | |||
| void DropJFuncGraph(FuncGraphPtr fg); | |||
| // get all value nodes in the inputs of J directly used by this func graph | |||
| const std::unordered_map<AnfNodePtr, int> &j_value_nodes(); | |||
| void CopyJValueNodes(const FuncGraphPtr &source); | |||
| void ClearJValueNodes(); | |||
| void AddJValueNode(const AnfNodePtr &value_node, int count = 1); | |||
| void DropJValueNode(const AnfNodePtr &value_node); | |||
| // get all func graphs nested used by this func graph | |||
| const FuncGraphSet &func_graphs_used_total(); | |||
| @@ -375,7 +375,7 @@ class FuncGraph : public FuncGraphBase { | |||
| AnfNodeCounterMap free_variables_; | |||
| // all value nodes calling J in the function | |||
| FuncGraphCounterMap j_func_graphs_; | |||
| std::unordered_map<AnfNodePtr, int> j_value_nodes_; | |||
| // all user value nodes of this func graph, recording by CNode and its input's index | |||
| CNodeIndexCounterMap func_graph_cnodes_index_; | |||
| @@ -486,9 +486,9 @@ void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||
| if (fg->AddFuncGraphUsed(used)) { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| fg->AddJFuncGraph(used); | |||
| } | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| fg->AddJValueNode(input); | |||
| } | |||
| } else if (fg != nullptr && fg != input->func_graph()) { | |||
| if (fg->AddFreeVariable(input)) { | |||
| @@ -507,9 +507,9 @@ void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||
| if (fg->DropFuncGraphUsed(used)) { | |||
| signals_->InvalidateComputer(); | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| fg->DropJFuncGraph(used); | |||
| } | |||
| } | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| fg->DropJValueNode(input); | |||
| } | |||
| } else if (fg != nullptr && fg != input->func_graph()) { | |||
| if (fg->DropFreeVariable(input)) { | |||
| @@ -524,7 +524,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||
| target->CopyFuncGraphCNodesIndex(source); | |||
| target->CopyFreeVariables(source); | |||
| target->CopyFuncGraphsUsed(source); | |||
| target->CopyJFuncGraphs(source); | |||
| target->CopyJValueNodes(source); | |||
| signals_->InvalidateComputer(); | |||
| source->ClearAllManagerInfo(); | |||
| } | |||
| @@ -880,32 +880,44 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F | |||
| } | |||
| bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| if (fg->seen_ == seen_num) { | |||
| MS_LOG(DEBUG) << fg->ToString() << " had been checked"; | |||
| return false; | |||
| } | |||
| auto &j_fgs = fg->j_func_graphs(); | |||
| if (!j_fgs.empty()) { | |||
| // check g1->J(fg)->g2->g cycle; | |||
| auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) { | |||
| return iter.first->seen_ != seen_num; | |||
| }); | |||
| if (contains_j != j_fgs.end()) { | |||
| MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; | |||
| const auto &j_values = fg->j_value_nodes(); | |||
| if (!j_values.empty()) { | |||
| auto contains_j = | |||
| std::find_if(j_values.begin(), j_values.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) { | |||
| // check g1->J(fg)->g2->g cycle. | |||
| if (IsValueNode<FuncGraph>(iter.first)) { | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(iter.first); | |||
| return func_graph->seen_ != seen_num; | |||
| } | |||
| if (IsValueNode<Primitive>(iter.first)) { | |||
| // exclude the primitive of J itself. | |||
| auto prim = GetValueNode<PrimitivePtr>(iter.first); | |||
| return prim->name() != prim::kPrimJ->name(); | |||
| } | |||
| return false; | |||
| }); | |||
| if (contains_j != j_values.end()) { | |||
| MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->DebugString() << ")"; | |||
| return true; | |||
| } | |||
| } | |||
| fg->seen_ = seen_num; | |||
| // check if func graphs used contains J(func_graph); | |||
| // check if func graphs used contains J(func_graph) or J(Primitive) | |||
| for (auto &item : fg->func_graphs_used()) { | |||
| auto used_g = item.first; | |||
| if (SeekJ(used_g, seen_num)) { | |||
| MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; | |||
| MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() | |||
| << " which contains J(func_graph) or J(Primitive)"; | |||
| return true; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)"; | |||
| MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph) or J(Primitive)"; | |||
| return false; | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| from mindspore import context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| @@ -68,3 +69,44 @@ def test_user_defined_bprop(): | |||
| grad_net = TestUserDefinedBpropGradNet(net) | |||
| x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32)) | |||
| grad_net(x) | |||
| class SinNet(nn.Cell): | |||
| def __init__(self): | |||
| super(SinNet, self).__init__() | |||
| self.sin = ops.Sin() | |||
| def construct(self, x): | |||
| out = self.sin(x) | |||
| return out | |||
| class SinGrad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(SinGrad, self).__init__() | |||
| self.grad = ops.GradOperation() | |||
| self.network = network | |||
| def construct(self, x): | |||
| gout = self.grad(self.network)(x) | |||
| return gout | |||
| class SinGradSec(nn.Cell): | |||
| def __init__(self, network): | |||
| super(SinGradSec, self).__init__() | |||
| self.grad = ops.GradOperation() | |||
| self.network = network | |||
| def construct(self, x): | |||
| gout = self.grad(self.network)(x) | |||
| return gout | |||
| def test_second_grad_with_j_primitive(): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = SinNet() | |||
| first_grad = SinGrad(net) | |||
| second_grad = SinGradSec(first_grad) | |||
| x = Tensor(np.array([1.0], dtype=np.float32)) | |||
| second_grad(x) | |||