Browse Source

fix bug if input not used

tags/v1.1.0
lichenever 5 years ago
parent
commit
7c7006f347
2 changed files with 68 additions and 0 deletions
  1. +66
    -0
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +2
    -0
      mindspore/ccsrc/frontend/parallel/step_parallel.h

+ 66
- 0
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -54,6 +54,7 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
// g_RefMap, for CNode B input i is a RefKey[Parameter C], // g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i) // it will be one item in map with key: C, and value: (B, i)
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap; static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
static void HandleNoUsedParameter(const FuncGraphPtr &root);


void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
if (new_node_input.empty()) { if (new_node_input.empty()) {
@@ -3032,6 +3033,68 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
} }
} }


bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(parameter);
auto manager = graph->manager();
auto node_users = manager->node_users()[parameter];
if (node_users.empty()) {
return false;
}
for (auto node_user : node_users) {
auto use_node = node_user.first->cast<CNodePtr>();
if (IsValueNode<FuncGraph>(use_node->input(0))) {
auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
auto parameters = graph_sub->parameters();
auto parameter_sub = parameters[node_user.second - 1];
return IsUsedParameter(graph_sub, parameter_sub);
}
if (use_node->input(0)->isa<CNode>()) {
auto cnode = use_node->input(0)->cast<CNodePtr>();
if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
return true;
}
auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
auto parameters = graph_sub->parameters();
auto parameter_sub = parameters[node_user.second - 1];
return IsUsedParameter(graph_sub, parameter_sub);
}
return true;
}
return true;
}

static void HandleNoUsedParameter(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
bool full_batch = ParallelContext::GetInstance()->full_batch();
if (full_batch) {
return;
}
auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
auto parameters = root->parameters();
for (auto &parameter : parameters) {
if (IsUsedParameter(root, parameter)) {
continue;
}
auto parameter_shape = GetNodeShape(parameter);
if (parameter_shape.empty()) {
continue;
}
Shape slice_shape = parameter_shape[0];
if (slice_shape.empty()) {
continue;
}
slice_shape[0] = slice_shape[0] / dev_num;
auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
auto abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
auto abstract_cloned = abstract->Clone();
MS_EXCEPTION_IF_NULL(abstract_cloned);
abstract_cloned->set_shape(slice_shape_ptr);
parameter->set_abstract(abstract_cloned);
}
}

bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(optimizer);
@@ -3103,6 +3166,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// cover Parallel shape // cover Parallel shape
CoverSliceShape(root); CoverSliceShape(root);


// handle input is not used
HandleNoUsedParameter(root);

// set the shape for optimizer's clone tensor // set the shape for optimizer's clone tensor
SetClonedTensorShapeForOptimizer(root); SetClonedTensorShapeForOptimizer(root);




+ 2
- 0
mindspore/ccsrc/frontend/parallel/step_parallel.h View File

@@ -161,6 +161,8 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);


ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)); ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));


bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);

void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator, void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index); const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
} // namespace parallel } // namespace parallel


Loading…
Cancel
Save