|
|
|
@@ -876,6 +876,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { |
|
|
|
// if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead |
|
|
|
if (next_cnode.first) { |
|
|
|
MS_EXCEPTION_IF_NULL(next_cnode.second); |
|
|
|
// param->cast->op, insert mirror before cast |
|
|
|
if (node->input(index)->isa<CNode>()) { |
|
|
|
auto pre_cnode = node->input(index)->cast<CNodePtr>(); |
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
if (pre_prim->name() == CAST) { |
|
|
|
manager->SetEdge(pre_cnode, 1, next_cnode.second); |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
manager->SetEdge(node, SizeToInt(index), next_cnode.second); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|