Browse Source

insert mirror before load

tags/v1.2.0-rc1
yangzhenzhang 5 years ago
parent
commit
6eadd241a0
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 3
- 3
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -1202,7 +1202,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
if (node->input(index)->isa<CNode>()) {
auto pre_cnode = node->input(index)->cast<CNodePtr>();
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST) {
if ((pre_prim->name() == CAST) || (pre_prim->name() == LOAD)) {
manager->SetEdge(pre_cnode, 1, next_cnode.second);
continue;
}
@@ -1217,10 +1217,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size();
}
std::string instance_name = MIRROR_OP;
if (IsCastBeforMirror(node, index)) {
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
for (auto &op : backward_op) {
// insert new node before the node
CNodePtr cnode = node->input(index)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr pre_node = cnode->input(1);
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);


Loading…
Cancel
Save