|
|
|
@@ -711,61 +711,6 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { |
|
|
|
return tuple_index_value->cast<Int32ImmPtr>()->value(); |
|
|
|
} |
|
|
|
|
|
|
|
// Judge whether the node is a loss, and if there are multiple outputs, |
|
|
|
// get which output is a grad according to the tuple getitem. |
|
|
|
// Currently, it is not supported that the sens is a tuple. |
|
|
|
LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(loss_node); |
|
|
|
FuncGraphPtr sub_graph = loss_node->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(sub_graph); |
|
|
|
CNodePtr return_node = sub_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(return_node); |
|
|
|
if (return_node->inputs().size() < 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; |
|
|
|
} |
|
|
|
AnfNodePtr pre_node = return_node->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_node); |
|
|
|
|
|
|
|
LossNodeInfo node_info; |
|
|
|
|
|
|
|
// return -> cast |
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode); |
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { |
|
|
|
pre_node = pre_cnode->input(1); |
|
|
|
} |
|
|
|
|
|
|
|
// return -> loss |
|
|
|
if (pre_node == loss_node) { |
|
|
|
node_info.has_tuple_getitem = false; |
|
|
|
node_info.dout_index = 0; |
|
|
|
return node_info; |
|
|
|
} |
|
|
|
|
|
|
|
// return -> tuple_getitem -> loss |
|
|
|
auto cnode = pre_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto current_value = cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(current_value); |
|
|
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(current_prim); |
|
|
|
// size of common cnode is larger than 1 |
|
|
|
if (cnode->inputs().size() < 2) { |
|
|
|
MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2"; |
|
|
|
} |
|
|
|
|
|
|
|
if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) { |
|
|
|
// size of tuple_getitem cnode is 3 |
|
|
|
auto tuple_index = GetTupleGetItemIndex(cnode); |
|
|
|
node_info.has_tuple_getitem = true; |
|
|
|
node_info.dout_index = tuple_index; |
|
|
|
return node_info; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid loss"; |
|
|
|
} |
|
|
|
|
|
|
|
void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
size_t node_size = node->inputs().size(); |
|
|
|
@@ -958,13 +903,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, |
|
|
|
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) { |
|
|
|
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) { |
|
|
|
MS_EXCEPTION_IF_NULL(distribute_operator); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
|
|
bool is_loss_cnode = |
|
|
|
std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(), |
|
|
|
[node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; }); |
|
|
|
[node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; }); |
|
|
|
|
|
|
|
MirrorOps mirror_ops = distribute_operator->mirror_ops(); |
|
|
|
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); |
|
|
|
@@ -1819,7 +1764,20 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
// Skip over any chain of Depend nodes (return->depend->...->loss) and return the first
// cnode whose primitive is not Depend. Raises an exception if a visited cnode has no
// primitive at input(0), or if the first real input of a Depend is not a cnode.
CNodePtr HandleDependLoss(const CNodePtr &cnode) {
  CNodePtr current = cnode;
  while (true) {
    auto prim = GetValueNode<PrimitivePtr>(current->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    if (prim->name() != DEPEND) {
      // Reached a non-Depend node: this is the candidate the caller keeps inspecting.
      return current;
    }
    // Follow the first real input of Depend and examine it on the next iteration.
    auto depend_before = current->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_before);
    current = depend_before;
  }
}
|
|
|
|
|
|
|
LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
LossNodeInfo loss_node_info; |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
CNodePtr return_node = func_graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(return_node); |
|
|
|
@@ -1831,9 +1789,9 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
|
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>(); |
|
|
|
if (pre_cnode == nullptr) { |
|
|
|
return nullptr; |
|
|
|
return loss_node_info; |
|
|
|
} |
|
|
|
|
|
|
|
pre_cnode = HandleDependLoss(pre_cnode); |
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
// return -> cast |
|
|
|
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { |
|
|
|
@@ -1845,7 +1803,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
// notice: the GetNext op has no input |
|
|
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { |
|
|
|
MS_LOG(INFO) << "The loss is: " << current_prim->name(); |
|
|
|
return pre_cnode; |
|
|
|
loss_node_info.loss_node = pre_cnode; |
|
|
|
return loss_node_info; |
|
|
|
} |
|
|
|
|
|
|
|
// size of common cnode is larger than 1 |
|
|
|
@@ -1855,36 +1814,34 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
|
|
|
|
// return -> tuple_getitem -> loss |
|
|
|
if (current_prim->name() == TUPLE_GETITEM) { |
|
|
|
auto tuple_index = GetTupleGetItemIndex(pre_cnode); |
|
|
|
AnfNodePtr pre_pre_node = pre_cnode->input(1); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_pre_node); |
|
|
|
|
|
|
|
auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>(); |
|
|
|
auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
PrimitivePtr prim = value->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
MS_LOG(DEBUG) << "The loss name is " << prim->name(); |
|
|
|
return pre_pre_cnode; |
|
|
|
loss_node_info.has_tuple_getitem = true; |
|
|
|
loss_node_info.dout_index = tuple_index; |
|
|
|
loss_node_info.loss_node = pre_pre_cnode; |
|
|
|
return loss_node_info; |
|
|
|
} |
|
|
|
|
|
|
|
// return -> make_tuple |
|
|
|
if (current_prim->name() == MAKE_TUPLE) { |
|
|
|
MS_LOG(WARNING) << "The loss have make_tuple, it is not supported"; |
|
|
|
return nullptr; |
|
|
|
return loss_node_info; |
|
|
|
} |
|
|
|
|
|
|
|
// return -> loss |
|
|
|
loss_node_info.loss_node = pre_cnode; |
|
|
|
MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); |
|
|
|
return pre_cnode; |
|
|
|
return loss_node_info; |
|
|
|
} |
|
|
|
|
|
|
|
TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { |
|
|
|
TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) { |
|
|
|
TensorLayouts ret; |
|
|
|
auto loss_cnode = node_info.loss_node; |
|
|
|
MS_EXCEPTION_IF_NULL(loss_cnode); |
|
|
|
AnfNodePtr node = loss_cnode->cast<AnfNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
|
|
LossNodeInfo node_info = GetLossNodeInfo(node); |
|
|
|
ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node); |
|
|
|
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
@@ -2086,9 +2043,9 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no |
|
|
|
return graph_set; |
|
|
|
} |
|
|
|
|
|
|
|
void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) { |
|
|
|
void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) { |
|
|
|
CNodePtr sens_node = sens_loss_pair.first; |
|
|
|
CNodePtr loss_node = sens_loss_pair.second; |
|
|
|
auto loss_node = sens_loss_pair.second; |
|
|
|
auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node); |
|
|
|
if (!loss_grad_layout.empty()) { |
|
|
|
SplitSens(sens_node, loss_grad_layout[0]); |
|
|
|
@@ -2096,9 +2053,9 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) { |
|
|
|
} |
|
|
|
|
|
|
|
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) |
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) { |
|
|
|
std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) { |
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs; |
|
|
|
std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs; |
|
|
|
for (auto &node : root->nodes()) { |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
continue; |
|
|
|
@@ -2140,12 +2097,12 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr & |
|
|
|
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; |
|
|
|
} |
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1)); |
|
|
|
auto loss_cnode = FindLossCNode(func_graph); |
|
|
|
if (loss_cnode == nullptr) { |
|
|
|
auto loss_node_info = FindLossCNode(func_graph); |
|
|
|
if (loss_node_info.loss_node == nullptr) { |
|
|
|
MS_LOG(WARNING) << "Can not find the loss cnode"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode); |
|
|
|
std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info); |
|
|
|
sens_loss_pairs.push_back(sens_loss_pair); |
|
|
|
} |
|
|
|
return sens_loss_pairs; |
|
|
|
@@ -2157,7 +2114,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
TensorRedistribution tensor_redistribution; |
|
|
|
|
|
|
|
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root); |
|
|
|
std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root); |
|
|
|
bool has_backward = !sens_loss_pairs.empty(); |
|
|
|
// sens must be split before inserting the operators. |
|
|
|
for (auto &pair : sens_loss_pairs) { |
|
|
|
@@ -2372,7 +2329,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) { |
|
|
|
std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> root_forward_nodes; |
|
|
|
auto loss_cnode = FindLossCNode(graph); |
|
|
|
auto loss_cnode = FindLossCNode(graph).loss_node; |
|
|
|
if (loss_cnode == nullptr) { |
|
|
|
MS_LOG(WARNING) << "Can not find the loss cnode"; |
|
|
|
return root_forward_nodes; |
|
|
|
|