@@ -1607,72 +1607,79 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-bool IsGradSensNode(const AnfNodePtr &node) {
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)
-  auto cnode = node->cast<CNodePtr>();
-  AnfNodePtr expect_tuple_getitem = cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
-  if (!expect_tuple_getitem->isa<CNode>()) {
-    return false;
-  }
-  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
-  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
-    return false;
-  }
-  ValueNodePtr expect_tuple_getitem_value_node = expect_tuple_getitem_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_value_node);
-  PrimitivePtr expect_tuple_getitem_prim = expect_tuple_getitem_value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_prim);
-  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
-    return false;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode
-  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(expect_anonymous);
-  if (!expect_anonymous->isa<CNode>()) {
-    return false;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
-  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_j);
-  if (!expect_j->isa<CNode>()) {
-    return false;
-  }
-  auto expect_j_cnode = expect_j->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_cnode);
-  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
-    return false;
-  }
-  ValueNodePtr expect_j_value_node = expect_j_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_value_node);
-  PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_prim);
-  return (expect_j_prim->name() == J);
-}
+CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  CNodePtr return_node = func_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (return_node->size() < 2) {
+    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
+  }
+  AnfNodePtr pre_node = return_node->input(1);
+  MS_EXCEPTION_IF_NULL(pre_node);
+
+  auto pre_cnode = pre_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(pre_cnode);
+  auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
+
+  // return -> cast
+  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
+    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(pre_cnode);
+    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
+  }
+
+  // notice: the GetNext op has not input
+  if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
+    MS_LOG(INFO) << "The loss is: " << current_prim->name();
+    return pre_cnode;
+  }
+
+  // size of common cnode is larger than 1
+  if (pre_cnode->size() < 2) {
+    MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
+  }
+
+  // return -> tuple_getitem -> loss
+  if (current_prim->name() == TUPLE_GETITEM) {
+    AnfNodePtr pre_pre_node = pre_cnode->input(1);
+    MS_EXCEPTION_IF_NULL(pre_pre_node);
+
+    auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
+    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value);
+    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(prim);
+    MS_LOG(DEBUG) << "The loss name is " << prim->name();
+    return pre_pre_cnode;
+  }
+
+  // return -> make_tuple
+  if (current_prim->name() == MAKE_TUPLE) {
+    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
+  }
+
+  // return -> loss
+  MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
+  return pre_cnode;
+}
 
-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
+TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  TensorLayouts ret;
+  if (!IsValueNode<FuncGraph>(cnode->input(1))) {
+    MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
+  }
+  auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+  auto loss_cnode = FindLossCNode(func_graph);
+  MS_EXCEPTION_IF_NULL(loss_cnode);
   AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
 
   LossNodeInfo node_info = GetLossNodeInfo(node);
 
   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(prim);
 
-  TensorLayouts ret;
   if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
     MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
     return ret;
@@ -1680,7 +1687,6 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
 
   OperatorInfoPtr operator_info = loss_cnode->operator_info();
   MS_EXCEPTION_IF_NULL(operator_info);
 
   TensorInfo loss_grad_tensor_info;
   size_t op_output_size = operator_info->outputs_tensor_info().size();
   MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is "
@@ -1805,6 +1811,100 @@ void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePt
   HandleDropoutNode(distribute_operator, cnode);
 }
 
+std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
+  // J->CNode->Graph
+  std::set<FuncGraphPtr> graph_set;
+  for (auto &node : root_all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (expect_j_prim->name() != J) {
+      continue;
+    }
+    if (IsValueNode<FuncGraph>(cnode->input(1))) {
+      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      MS_LOG(DEBUG) << "Find the forward graph success";
+      graph_set.insert(graph);
+    }
+  }
+  return graph_set;
+}
+
+// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+void StepSplitSens(const AnfNodePtr &node) {
+  if (!node->isa<CNode>()) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)
+  auto cnode = node->cast<CNodePtr>();
+  AnfNodePtr expect_tuple_getitem = cnode->input(0);
+  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
+  if (!expect_tuple_getitem->isa<CNode>()) {
+    return;
+  }
+  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
+  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
+    return;
+  }
+  auto expect_tuple_getitem_prim = GetValueNode<PrimitivePtr>(expect_tuple_getitem_cnode->input(0));
+  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)-->cnode
+  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
+  MS_EXCEPTION_IF_NULL(expect_anonymous);
+  if (!expect_anonymous->isa<CNode>()) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
+  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
+  MS_EXCEPTION_IF_NULL(expect_j);
+  if (!expect_j->isa<CNode>()) {
+    return;
+  }
+  auto expect_j_cnode = expect_j->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_j_cnode);
+  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
+    return;
+  }
+  auto expect_j_prim = GetValueNode<PrimitivePtr>(expect_j_cnode->input(0));
+  if (expect_j_prim->name() == J) {
+    auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode);
+    if (!loss_grad_layout.empty()) {
+      SplitSens(node, loss_grad_layout[0]);
+    }
+  }
+}
+
+std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  AnfNodePtr root_return_node = root->get_return();
+  MS_EXCEPTION_IF_NULL(root_return_node);
+  std::vector<CNodePtr> loss_node;
+  const auto &all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  if (graph_set.empty()) {
+    loss_node.push_back(FindLossCNode(root));
+  }
+  (void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node),
+                       [](const FuncGraphPtr &graph) { return FindLossCNode(graph); });
+  return loss_node;
+}
+
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(root);
@@ -1812,18 +1912,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
   TensorRedistribution tensor_redistribution;
-  AnfNodePtr grad_sens_node = nullptr;
 
-  CNodePtr loss_cnode = FindLossCNodeFromRoot(root);
-  MS_EXCEPTION_IF_NULL(loss_cnode);
-  // get output layout of loss must before inserting the operators below
-  TensorLayouts loss_layout = GetLossNodeGradOutputLayout(loss_cnode);
-
+  std::vector<CNodePtr> loss_cnode = FindLossCNodeFromRoot(root);
+  // split sens must before inserting the operators.
   for (auto &node : all_nodes) {
-    // find sens node
-    if ((grad_sens_node == nullptr) && IsGradSensNode(node)) {
-      grad_sens_node = node;
-      MS_LOG(INFO) << "Find the sens node success";
-    }
+    // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
+    // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
+    StepSplitSens(node);
+  }
 
+  for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
@@ -1837,7 +1934,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
       }
 
       bool is_loss_cnode = false;
-      if (cnode == loss_cnode) {
+      auto iter = std::find(loss_cnode.begin(), loss_cnode.end(), cnode);
+      if (iter != loss_cnode.end()) {
         is_loss_cnode = true;
       }
       // insert forward ops
@@ -1857,12 +1955,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
       StepSplitTensor(node, manager);
     }
   }
 
-  // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
-  // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
-  if (grad_sens_node && !loss_layout.empty()) {
-    SplitSens(grad_sens_node, loss_layout[0]);
-  }
 }
 
 namespace {
@@ -2003,134 +2095,57 @@ void SetForwardFlag(const AnfNodeSet &all_nodes) {
   }
 }
 
-CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  CNodePtr return_node = func_graph->get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
-  }
-  AnfNodePtr pre_node = return_node->input(1);
-  MS_EXCEPTION_IF_NULL(pre_node);
-
-  auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
-  auto current_value = pre_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(current_value);
-  PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(current_prim);
-
-  // return -> cast
-  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
-    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(pre_cnode);
-    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  }
-
-  // notice: the GetNext op has not input
-  if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
-    MS_LOG(INFO) << "The loss is: " << current_prim->name();
-    return pre_cnode;
-  }
-
-  // size of common cnode is larger than 1
-  if (pre_cnode->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
-  }
-
-  // return -> tuple_getitem -> loss
-  if (current_prim->name() == TUPLE_GETITEM) {
-    AnfNodePtr pre_pre_node = pre_cnode->input(1);
-    MS_EXCEPTION_IF_NULL(pre_pre_node);
-
-    auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
-    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value);
-    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(prim);
-    MS_LOG(INFO) << "The loss name is " << prim->name();
-    return pre_pre_cnode;
-  } else if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
-  }
-
-  // return -> loss
-  MS_LOG(INFO) << "The loss name is " << current_prim->name();
-  return pre_cnode;
-}
+std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  const auto &all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  return graph_set;
+}
 
-FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
-  for (auto &node : root_all_nodes) {
+std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto loss_cnode = FindLossCNode(graph);
+  MS_EXCEPTION_IF_NULL(loss_cnode);
+  auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
+  std::vector<AnfNodePtr> root_forward_nodes;
+  for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (!node->isa<CNode>()) {
       continue;
     }
 
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if ((cnode->inputs().size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
-    ValueNodePtr expect_j_value_node = cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(expect_j_value_node);
-    PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(expect_j_prim);
-    if (expect_j_prim->name() != J) {
-      continue;
-    }
-    MS_LOG(DEBUG) << "Find J prim: " << expect_j_value_node->DebugString() << ".";
-    if (IsValueNode<FuncGraph>(cnode->input(1))) {
-      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
-      MS_LOG(INFO) << "Find the forward graph success";
-      return graph;
+    auto root_node_id = node->UniqueIdThroughCopy();
+    if (loss_cnode_id == root_node_id) {
+      root_forward_nodes = DeepLinkedGraphSearch(cnode);
+      break;
     }
   }
-  return nullptr;
-}
-
-CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root) {
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  const auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
-  if (func_graph == nullptr) {
-    return FindLossCNode(root);
-  } else {
-    return FindLossCNode(func_graph);
-  }
-}
-
-FuncGraphPtr ForwardGraph(const FuncGraphPtr &root) {
-  FuncGraphPtr forward_graph = root;
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  const auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
-  if (func_graph != nullptr) {
-    forward_graph = func_graph;
-  }
-  return forward_graph;
+  return root_forward_nodes;
 }
 
 void MarkForwardCNode(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
+  auto all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
 
-  if (func_graph == nullptr) {
-    // Can not find the forward graph, so the ops in root graph are forward.
+  if (graph_set.empty()) {
+    MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
     SetForwardFlag(all_nodes);
   } else {
-    MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
-    AnfNodePtr return_node = func_graph->get_return();
-    MS_EXCEPTION_IF_NULL(return_node);
-    std::vector<AnfNodePtr> all_dfs_nodes = DeepLinkedGraphSearch(return_node);
-    SetForwardFlag(all_dfs_nodes);
+    for (auto &func_graph : graph_set) {
+      MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
+      auto return_node = func_graph->get_return();
+      MS_EXCEPTION_IF_NULL(return_node);
+      auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
+      SetForwardFlag(all_dfs_nodes);
+      auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes);
+      if (root_forward_nodes.empty()) {
+        continue;
+      }
+      // Mark forward flag for the nodes in root graph.
+      SetForwardFlag(root_forward_nodes);
+    }
   }
 }
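A note on the pattern used throughout the hunks above: the sens-node and J-node checks all do the same thing repeatedly: cast a node to a CNode, look at its input(0), and compare the primitive name against a constant such as TUPLE_GETITEM or J. The sketch below is only an illustration of that recurring check, not part of the patch; the helper name IsSpecificPrimCNode is hypothetical, while the calls it uses (isa, cast, input, IsValueNode, GetValueNode, Primitive::name) and the surrounding headers are assumed to be the same ones step_parallel.cc already relies on.

// Illustrative sketch only, not part of the patch above.
// Assumes the includes and type aliases of step_parallel.cc (AnfNodePtr, CNodePtr, PrimitivePtr, ...).
bool IsSpecificPrimCNode(const AnfNodePtr &node, const std::string &expect_name) {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  // input(0) of a CNode holds the callee; for primitive calls it is a ValueNode<Primitive>.
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  return (prim != nullptr) && (prim->name() == expect_name);
}

// Usage mirroring StepSplitSens: IsSpecificPrimCNode(expect_tuple_getitem, TUPLE_GETITEM)
// or FindForwardGraphByRootNodes: IsSpecificPrimCNode(node, J).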