|
|
@@ -54,6 +54,7 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; |
|
|
// g_RefMap, for CNode B input i is a RefKey[Parameter C], |
|
|
// g_RefMap, for CNode B input i is a RefKey[Parameter C], |
|
|
// it will be one item in map with key: C, and value: (B, i) |
|
|
// it will be one item in map with key: C, and value: (B, i) |
|
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap; |
|
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap; |
|
|
|
|
|
static void HandleNoUsedParameter(const FuncGraphPtr &root); |
|
|
|
|
|
|
|
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { |
|
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { |
|
|
if (new_node_input.empty()) { |
|
|
if (new_node_input.empty()) { |
|
|
@@ -3032,6 +3033,68 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter); |
|
|
|
|
|
auto manager = graph->manager(); |
|
|
|
|
|
auto node_users = manager->node_users()[parameter]; |
|
|
|
|
|
if (node_users.empty()) { |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
for (auto node_user : node_users) { |
|
|
|
|
|
auto use_node = node_user.first->cast<CNodePtr>(); |
|
|
|
|
|
if (IsValueNode<FuncGraph>(use_node->input(0))) { |
|
|
|
|
|
auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0)); |
|
|
|
|
|
auto parameters = graph_sub->parameters(); |
|
|
|
|
|
auto parameter_sub = parameters[node_user.second - 1]; |
|
|
|
|
|
return IsUsedParameter(graph_sub, parameter_sub); |
|
|
|
|
|
} |
|
|
|
|
|
if (use_node->input(0)->isa<CNode>()) { |
|
|
|
|
|
auto cnode = use_node->input(0)->cast<CNodePtr>(); |
|
|
|
|
|
if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1)); |
|
|
|
|
|
auto parameters = graph_sub->parameters(); |
|
|
|
|
|
auto parameter_sub = parameters[node_user.second - 1]; |
|
|
|
|
|
return IsUsedParameter(graph_sub, parameter_sub); |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static void HandleNoUsedParameter(const FuncGraphPtr &root) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch(); |
|
|
|
|
|
if (full_batch) { |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
|
|
|
auto dev_num = g_device_manager->GetDeviceListByStageId(0).size(); |
|
|
|
|
|
auto parameters = root->parameters(); |
|
|
|
|
|
for (auto ¶meter : parameters) { |
|
|
|
|
|
if (IsUsedParameter(root, parameter)) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
auto parameter_shape = GetNodeShape(parameter); |
|
|
|
|
|
if (parameter_shape.empty()) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
Shape slice_shape = parameter_shape[0]; |
|
|
|
|
|
if (slice_shape.empty()) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
slice_shape[0] = slice_shape[0] / dev_num; |
|
|
|
|
|
auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape); |
|
|
|
|
|
auto abstract = parameter->abstract(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract); |
|
|
|
|
|
auto abstract_cloned = abstract->Clone(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract_cloned); |
|
|
|
|
|
abstract_cloned->set_shape(slice_shape_ptr); |
|
|
|
|
|
parameter->set_abstract(abstract_cloned); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { |
|
|
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { |
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
MS_EXCEPTION_IF_NULL(optimizer); |
|
|
MS_EXCEPTION_IF_NULL(optimizer); |
|
|
@@ -3103,6 +3166,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) |
|
|
// cover Parallel shape |
|
|
// cover Parallel shape |
|
|
CoverSliceShape(root); |
|
|
CoverSliceShape(root); |
|
|
|
|
|
|
|
|
|
|
|
// handle root parameters that are not used |
|
|
|
|
|
HandleNoUsedParameter(root); |
|
|
|
|
|
|
|
|
// set the shape for optimizer's clone tensor |
|
|
// set the shape for optimizer's clone tensor |
|
|
SetClonedTensorShapeForOptimizer(root); |
|
|
SetClonedTensorShapeForOptimizer(root); |
|
|
|
|
|
|
|
|
|