|
|
|
@@ -208,20 +208,12 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An |
|
|
|
const FuncGraphPtr &root = nullptr) { |
|
|
|
// insert new node before the node |
|
|
|
FuncGraphManagerPtr manager = func_graph->manager(); |
|
|
|
auto node_user_map = manager->node_users(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
ScopePtr scope = node->scope(); |
|
|
|
MS_EXCEPTION_IF_NULL(scope); |
|
|
|
std::vector<AnfNodePtr> node_input; |
|
|
|
AnfNodePtr pre_node_ = pre_node; |
|
|
|
if (root && !param_name.empty()) { |
|
|
|
TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(node, node_user_map); |
|
|
|
if (next_node_dtype) { |
|
|
|
MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving" |
|
|
|
<< " communication."; |
|
|
|
pre_node_ = CreateFP16Cast(node, pre_node, next_node_dtype); |
|
|
|
} |
|
|
|
node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name); |
|
|
|
node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name); |
|
|
|
} else { |
|
|
|
node_input = CreateInput(op, pre_node, instance_name); |
|
|
|
} |
|
|
|
@@ -1572,7 +1564,16 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group |
|
|
|
allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root); |
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name; |
|
|
|
} else { |
|
|
|
InsertNode(op, cnode, IntToSize(res.second), node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, |
|
|
|
auto pre_node = node; |
|
|
|
AnfNodePtr pre_node_ = node; |
|
|
|
auto node_user_map = manager->node_users(); |
|
|
|
TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(cnode, node_user_map); |
|
|
|
if (next_node_dtype) { |
|
|
|
MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving" |
|
|
|
<< " communication."; |
|
|
|
pre_node_ = CreateFP16Cast(cnode, pre_node, next_node_dtype); |
|
|
|
} |
|
|
|
InsertNode(op, cnode, IntToSize(res.second), pre_node_, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, |
|
|
|
root); |
|
|
|
allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>(); |
|
|
|
MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name; |
|
|
|
@@ -3188,7 +3189,7 @@ static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr FindPrimitiveWithAtrribute(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map, uint32_t limits) { |
|
|
|
AnfNodePtr FindExpanDimsWIthGradScale(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map, uint32_t limits) { |
|
|
|
std::queue<AnfNodePtr> visited; |
|
|
|
AnfNodePtr queue_node = nullptr; |
|
|
|
CNodePtr cnode = nullptr; |
|
|
|
@@ -3239,7 +3240,7 @@ static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, cons |
|
|
|
if (cnode->in_forward_flag()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
expand_dims_node = FindPrimitiveWithAtrribute(cnode, node_user_map, MAX_BFS_DEPTH); |
|
|
|
expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH); |
|
|
|
if (!expand_dims_node) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|