Browse Source

!6888 [AutoParallel]Support GPT2 to compile graph

Merge pull request !6888 from lichen/support_gpt2_compile_graph
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f0fcc3653b
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 3
- 2
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -2528,7 +2528,7 @@ ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareN
for (auto &candidate : candidate_set) {
auto candidate_node = candidate.first;
auto c = candidate_node->cast<CNodePtr>();
if (c == nullptr || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) {
if (c == nullptr || !c->has_user_data<OperatorInfo>()) {
continue;
}
(void)parameter_user_info.second.second.insert(candidate);
@@ -2688,7 +2688,6 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)

// mark the forward cnodes, parallel only care these nodes
MarkForwardCNode(root);
HandleRootReshape(all_nodes);

if (FindCommunicationOp(all_nodes)) {
MS_LOG(EXCEPTION) << "The graph contain communication op";
@@ -2699,6 +2698,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
ReshapeInit(all_nodes);
}

HandleRootReshape(all_nodes);

HandleForwardMakeTupleAndMakeList(all_nodes);

// if the input or parameter has multiple users, check whether its split strategies are consistent.


Loading…
Cancel
Save