Browse Source

!6973 [AutoParallel]Semi_auto_parallel_support_gpt2

Merge pull request !6973 from lichen/semi_auto_support_gpt2
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
31f37a3c01
1 changed files with 9 additions and 0 deletions
  1. +9
    -0
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 9
- 0
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -876,6 +876,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
// if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead
if (next_cnode.first) {
MS_EXCEPTION_IF_NULL(next_cnode.second);
// param->cast->op, insert mirror before cast
if (node->input(index)->isa<CNode>()) {
auto pre_cnode = node->input(index)->cast<CNodePtr>();
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST) {
manager->SetEdge(pre_cnode, 1, next_cnode.second);
continue;
}
}
manager->SetEdge(node, SizeToInt(index), next_cnode.second);
continue;
}


Loading…
Cancel
Save