|
|
|
@@ -62,8 +62,6 @@ namespace parallel {
 static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
 static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
-static const std::vector<std::pair<const std::string, int64_t>> REDUCE_SUM_MATCH_PATTERN = {
-  std::make_pair(MAKE_TUPLE, 1), std::make_pair(ADDN, 1), std::make_pair(SQRT, 1)};
 // g_RefMap: if input i of CNode B is a RefKey[Parameter C],
 // the map holds one entry with key C and value (B, i).
 std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
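
The two REDUCE_SUM_MATCH_PATTERN lines above are the declarative matcher this patch backs out: each (primitive, count) pair describes one hop along the tail of the global-norm subgraph (MakeTuple -> AddN -> Sqrt) downstream of the squared-gradient sum. As a framework-free illustration of what such a table-driven matcher does, here is a minimal sketch; the Node type and MatchChain helper are hypothetical stand-ins for the ANF node API, and reading the int64_t as the expected user count is an assumption:

  #include <cstddef>
  #include <cstdint>
  #include <string>
  #include <utility>
  #include <vector>

  // Toy stand-in for an ANF node: a primitive name plus user edges.
  struct Node {
    std::string prim;
    std::vector<Node *> users;
  };

  // Follow the first-user edge once per table entry, requiring the expected
  // primitive at every hop; returns the final node (the Sqrt) or nullptr.
  Node *MatchChain(Node *start, const std::vector<std::pair<std::string, int64_t>> &pattern) {
    Node *cur = start;
    for (const auto &step : pattern) {
      if (cur == nullptr || cur->users.size() != static_cast<std::size_t>(step.second)) return nullptr;
      cur = cur->users.front();
      if (cur->prim != step.first) return nullptr;
    }
    return cur;
  }

  int main() {
    Node sqrt_node{"Sqrt", {}};
    Node addn{"AddN", {&sqrt_node}};
    Node make_tuple{"MakeTuple", {&addn}};
    Node square_sum{"ReduceSum", {&make_tuple}};
    return MatchChain(&square_sum, {{"MakeTuple", 1}, {"AddN", 1}, {"Sqrt", 1}}) ? 0 : 1;
  }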
|
|
|
@@ -3179,9 +3177,22 @@ static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) {
     return;
   }
   auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
-  auto expand_dims_node = node_user_map.at(res_node).front().first;
-  auto sqrt_node = MatchPattern(expand_dims_node, node_user_map, REDUCE_SUM_MATCH_PATTERN);
-  if (!sqrt_node) return;
+  auto find_node = res_node;
+  uint32_t limits = 0;
+  while (!IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT) && limits < MAX_BFS_DEPTH) {
+    auto users = node_user_map.at(find_node);
+    if (users.empty()) return;
+    find_node = users.front().first;
+    ++limits;
+  }
+  if (!find_node || !IsSomePrimitive(find_node->cast<CNodePtr>(), SQRT)) {
+    return;
+  }
+  auto anf_node = find_node->cast<CNodePtr>();
+  if (anf_node->inputs().size() > 1 && IsSomePrimitive(anf_node->input(1)->cast<CNodePtr>(), ALL_REDUCE)) {
+    return;
+  }
+  auto sqrt_node = find_node;
   auto cur_stage_rank_list = g_device_manager->GetDeviceListInThisStage();
   Group cur_stage_device_list;
   if (g_device_manager->CreateGroup(cur_stage_rank_list, &cur_stage_device_list) != SUCCESS) {
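
The hunk above restores the imperative walk that the pattern table replaced: starting from the norm value, the loop follows the first user for at most MAX_BFS_DEPTH hops until it reaches SQRT, and returns early when the Sqrt already consumes an ALL_REDUCE, which keeps the insertion idempotent. A minimal sketch of the bounded walk, using the same toy Node type as above (illustrative only, not the MindSpore API):

  #include <string>
  #include <vector>

  struct Node {
    std::string prim;
    std::vector<Node *> users;
  };

  // Depth-limited walk along the first-user chain, mirroring the restored
  // while loop: stop at the target primitive or when the budget runs out.
  Node *FindUserByPrim(Node *start, const std::string &target, unsigned max_depth) {
    Node *cur = start;
    for (unsigned depth = 0; cur != nullptr && cur->prim != target; ++depth) {
      if (depth >= max_depth || cur->users.empty()) return nullptr;
      cur = cur->users.front();
    }
    return cur;
  }

As in the production loop, only users.front() is followed, so this is a bounded chain rather than a true breadth-first search; MAX_BFS_DEPTH merely caps the number of hops.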
|
|
|
@@ -3245,27 +3256,22 @@ AnfNodePtr FindExpanDimsWIthGradScale(const AnfNodePtr &node_ptr, const NodeUser
 
 static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, const AnfNodePtr &parameter,
                                          uint32_t dev_num) {
-  AnfNodePtr expand_dims_node = nullptr;
-  AnfNodePtr prefix_node = nullptr;
   auto params_user_set = node_user_map.at(parameter);
   for (auto &param_pair : params_user_set) {
-    expand_dims_node = nullptr;
     auto cnode = param_pair.first->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->in_forward_flag()) {
       continue;
     }
-    expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
-    if (!expand_dims_node) {
-      continue;
-    }
+    auto expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH);
+    if (!expand_dims_node) continue;
     auto value = GetAttrsFromAnfNode(expand_dims_node, GRAD_SCALE);
-    if (!value || !GetValue<bool>(value)) {
-      continue;
-    }
-    if (dev_num > 0) {
-      InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
-      MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter "
-                   << parameter->fullname_with_scope() << "succeed!";
-    }
+    if (!value || !GetValue<bool>(value)) continue;
+    InsertRealDivOpToNodeInput(expand_dims_node->cast<CNodePtr>(), dev_num, PARALLEL_GLOBALNORM_DIV);
+    MS_LOG(INFO) << "Insert the realdiv with " << dev_num << " for the parameter " << parameter->DebugString()
+                 << " succeed!";
     // If an allreduce was already inserted, the pattern will not match again, so no second allreduce is inserted.
     InsertAllReduceForNormValue(expand_dims_node);
   }
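
The RealDiv inserted ahead of the norm computation appears to compensate for mirrored gradients: every device in the mirror group holds an identical copy of the parameter's gradient, so an AllReduce(SUM) over the squared norms would count that parameter dev_num times. Dividing each squared contribution by dev_num first lets the sum reproduce the single-copy value, and the downstream Sqrt then yields the intended global norm. A self-contained numeric check of that identity (values are illustrative):

  #include <cmath>
  #include <cstdio>

  // Toy model of the div + allreduce step: each of dev_num devices holds the
  // same local sum of squared gradients, divides it by dev_num (the RealDiv),
  // and the AllReduce(SUM) adds up all contributions.
  int main() {
    const int dev_num = 4;      // devices mirroring this parameter
    const double sq_sum = 9.0;  // local sum of squared gradients, identical on each device
    double allreduced = 0.0;
    for (int d = 0; d < dev_num; ++d) allreduced += sq_sum / dev_num;  // RealDiv then AllReduce(SUM)
    std::printf("global norm = %.3f (expected %.3f)\n", std::sqrt(allreduced), std::sqrt(sq_sum));
    return 0;
  }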
|
|
|
@@ -3302,22 +3308,15 @@ static void HandlGlobalNormScale(const FuncGraphPtr &root, const std::vector<Anf
   auto parameters = root->parameters();
   auto node_user_map = manager->node_users();
   MS_LOG(INFO) << "Start to process the global norm";
 
   for (auto &parameter : parameters) {
+    int64_t dev_num = 0;
     if (!ParameterRequireGrad(parameter)) continue;
     auto mirror_node = GetMirrorOp(node_user_map, parameter);
     if (!mirror_node) continue;
     auto device_num_ptr = GetAttrsFromAnfNode(mirror_node, DEV_NUM);
-    if (!device_num_ptr) {
-      MS_LOG(ERROR) << "The mirror operator is excepted to have device number attribute, but found none. This "
-                       "will cause the global norm calculation with wrong precision.";
-      continue;
-    }
-    if (!device_num_ptr->isa<Int64Imm>()) {
-      MS_LOG(ERROR) << "The type of device number attribute of mirror operator is not int64.";
-      continue;
-    }
-    auto dev_num = device_num_ptr->cast<Int64ImmPtr>()->value();
+    if (device_num_ptr && device_num_ptr->isa<Int64Imm>()) {
+      dev_num = GetValue<int64_t>(device_num_ptr);
+    }
     if (dev_num == 0) continue;
     InsertDivAndAllReduceForNorm(node_user_map, parameter, dev_num);
   }
 }
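
Behavioral note on the restored HandlGlobalNormScale: a missing or non-Int64Imm DEV_NUM attribute no longer logs an error; dev_num simply keeps its default of 0 and the parameter is skipped by the dev_num == 0 guard. A compact sketch of that default-then-guard idiom, with a toy attribute map and hypothetical names:

  #include <cstdint>
  #include <cstdio>
  #include <map>
  #include <string>

  // dev_num defaults to 0 and is only overwritten by a well-typed attribute;
  // the caller treats 0 as "nothing to do" rather than as an error.
  int64_t ResolveDevNum(const std::map<std::string, int64_t> &mirror_attrs) {
    int64_t dev_num = 0;
    auto it = mirror_attrs.find("dev_num");
    if (it != mirror_attrs.end()) dev_num = it->second;
    return dev_num;
  }

  int main() {
    std::printf("%lld\n", static_cast<long long>(ResolveDevNum({{"dev_num", 8}})));  // 8
    std::printf("%lld\n", static_cast<long long>(ResolveDevNum({})));                // 0 -> parameter skipped
    return 0;
  }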
|
|
|
|