|
|
|
@@ -1683,7 +1683,10 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(pre_node); |
|
|
|
|
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode); |
|
|
|
if (pre_cnode == nullptr) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
// return -> cast |
|
|
|
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { |
|
|
|
@@ -1907,21 +1910,6 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) { |
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
AnfNodePtr root_return_node = root->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(root_return_node); |
|
|
|
std::vector<CNodePtr> loss_node; |
|
|
|
const auto &all_nodes = root->nodes(); |
|
|
|
std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes); |
|
|
|
if (graph_set.empty()) { |
|
|
|
loss_node.push_back(FindLossCNode(root)); |
|
|
|
} |
|
|
|
(void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node), |
|
|
|
[](const FuncGraphPtr &graph) { return FindLossCNode(graph); }); |
|
|
|
return loss_node; |
|
|
|
} |
|
|
|
|
|
|
|
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) |
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) { |
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
@@ -1968,6 +1956,10 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr & |
|
|
|
} |
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1)); |
|
|
|
auto loss_cnode = FindLossCNode(func_graph); |
|
|
|
if (loss_cnode == nullptr) { |
|
|
|
MS_LOG(WARNING) << "Can not find the loss cnode"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode); |
|
|
|
sens_loss_pairs.push_back(sens_loss_pair); |
|
|
|
} |
|
|
|
@@ -2158,10 +2150,14 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) { |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> root_forward_nodes; |
|
|
|
auto loss_cnode = FindLossCNode(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(loss_cnode); |
|
|
|
if (loss_cnode == nullptr) { |
|
|
|
MS_LOG(WARNING) << "Can not find the loss cnode"; |
|
|
|
return root_forward_nodes; |
|
|
|
} |
|
|
|
|
|
|
|
auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); |
|
|
|
std::vector<AnfNodePtr> root_forward_nodes; |
|
|
|
for (auto &node : all_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
|