|
|
|
@@ -1202,7 +1202,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons |
|
|
|
if (node->input(index)->isa<CNode>()) { |
|
|
|
auto pre_cnode = node->input(index)->cast<CNodePtr>(); |
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
if (pre_prim->name() == CAST) { |
|
|
|
if ((pre_prim->name() == CAST) || (pre_prim->name() == LOAD)) { |
|
|
|
manager->SetEdge(pre_cnode, 1, next_cnode.second); |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -1217,10 +1217,10 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons |
|
|
|
MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size(); |
|
|
|
} |
|
|
|
std::string instance_name = MIRROR_OP; |
|
|
|
if (IsCastBeforMirror(node, index)) { |
|
|
|
CNodePtr cnode = node->input(index)->cast<CNodePtr>(); |
|
|
|
if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) { |
|
|
|
for (auto &op : backward_op) { |
|
|
|
// insert new node before the node |
|
|
|
CNodePtr cnode = node->input(index)->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
AnfNodePtr pre_node = cnode->input(1); |
|
|
|
InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name); |
|
|
|
|