From 7c7006f347d3583c43a524e8f8100b64f03b06e1 Mon Sep 17 00:00:00 2001
From: lichenever
Date: Wed, 4 Nov 2020 17:18:29 +0800
Subject: [PATCH] fix bug if input not used

---
 .../ccsrc/frontend/parallel/step_parallel.cc | 66 +++++++++++++++++++
 .../ccsrc/frontend/parallel/step_parallel.h  |  2 +
 2 files changed, 68 insertions(+)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 33f403616c..6a5dec9d52 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -54,6 +54,7 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
 // it will be one item in map with key: C, and value: (B, i)
 static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
+static void HandleNoUsedParameter(const FuncGraphPtr &root);
 
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
@@ -3032,6 +3033,68 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
+bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(parameter);
+  auto manager = graph->manager();
+  auto node_users = manager->node_users()[parameter];
+  if (node_users.empty()) {
+    return false;
+  }
+  for (auto node_user : node_users) {
+    auto use_node = node_user.first->cast<CNodePtr>();
+    if (IsValueNode<FuncGraph>(use_node->input(0))) {
+      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
+      auto parameters = graph_sub->parameters();
+      auto parameter_sub = parameters[node_user.second - 1];
+      return IsUsedParameter(graph_sub, parameter_sub);
+    }
+    if (use_node->input(0)->isa<CNode>()) {
+      auto cnode = use_node->input(0)->cast<CNodePtr>();
+      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
+        return true;
+      }
+      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      auto parameters = graph_sub->parameters();
+      auto parameter_sub = parameters[node_user.second - 1];
+      return IsUsedParameter(graph_sub, parameter_sub);
+    }
+    return true;
+  }
+  return true;
+}
+
+static void HandleNoUsedParameter(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  bool full_batch = ParallelContext::GetInstance()->full_batch();
+  if (full_batch) {
+    return;
+  }
+  auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  auto parameters = root->parameters();
+  for (auto &parameter : parameters) {
+    if (IsUsedParameter(root, parameter)) {
+      continue;
+    }
+    auto parameter_shape = GetNodeShape(parameter);
+    if (parameter_shape.empty()) {
+      continue;
+    }
+    Shape slice_shape = parameter_shape[0];
+    if (slice_shape.empty()) {
+      continue;
+    }
+    slice_shape[0] = slice_shape[0] / dev_num;
+    auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
+    auto abstract = parameter->abstract();
+    MS_EXCEPTION_IF_NULL(abstract);
+    auto abstract_cloned = abstract->Clone();
+    MS_EXCEPTION_IF_NULL(abstract_cloned);
+    abstract_cloned->set_shape(slice_shape_ptr);
+    parameter->set_abstract(abstract_cloned);
+  }
+}
+
 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(optimizer);
@@ -3103,6 +3166,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   // cover Parallel shape
   CoverSliceShape(root);
 
+  // handle input is not used
+  HandleNoUsedParameter(root);
+
   // set the shape for optimizer's clone tensor
   SetClonedTensorShapeForOptimizer(root);
 
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index 00f60b39b6..4375853608 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -161,6 +161,8 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
 
 ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
 
+bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);
+
 void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
                              const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
 }  // namespace parallel
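
For readers outside the MindSpore tree, the recursive walk added as IsUsedParameter can be modeled by a minimal standalone sketch. Everything below (Graph, User, IsUsed) is an illustrative stand-in rather than MindSpore API, and for clarity the sketch inspects every user of a parameter, whereas the patch returns after examining the first one.

#include <cstddef>
#include <vector>

struct Graph;

// One user of a parameter: either a direct consumer (callee == nullptr)
// or a call that forwards the parameter into a subgraph at param_index.
struct User {
  const Graph *callee = nullptr;
  std::size_t param_index = 0;
};

struct Graph {
  // users_of[i] lists the users of this graph's parameter i.
  std::vector<std::vector<User>> users_of;
};

// A parameter counts as used if some user consumes it directly, or
// forwards it into a subgraph whose matching parameter is itself used.
bool IsUsed(const Graph &g, std::size_t param) {
  const auto &users = g.users_of[param];
  if (users.empty()) {
    return false;  // no users at all
  }
  for (const auto &u : users) {
    if (u.callee == nullptr) {
      return true;  // consumed by a real operator
    }
    if (IsUsed(*u.callee, u.param_index)) {
      return true;  // used somewhere down the call chain
    }
  }
  return false;  // only forwarded into subgraphs that ignore it
}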
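
The shape fix in HandleNoUsedParameter is small: when a top-level parameter is unused and full_batch is off, only its first (batch) dimension is divided by the stage-0 device count, so its abstract agrees with the data-parallel slice that used parameters receive. A sketch of that arithmetic, again with illustrative names (Shape, SliceFirstDim, the sample sizes in main):

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Divide the first dimension by the device count, as the patch does
// before rewriting the parameter's abstract shape.
Shape SliceFirstDim(Shape shape, int64_t dev_num) {
  if (!shape.empty() && dev_num > 0) {
    shape[0] = shape[0] / dev_num;
  }
  return shape;
}

int main() {
  const Shape full = {32, 128};                 // full-batch parameter shape
  const Shape sliced = SliceFirstDim(full, 8);  // 8 devices in stage 0
  std::cout << sliced[0] << " x " << sliced[1] << std::endl;  // 4 x 128
  return 0;
}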