|
|
|
@@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { |
|
|
|
if (pre_cnode == nullptr) { |
|
|
|
return loss_node_info; |
|
|
|
} |
|
|
|
pre_cnode = HandleDependLoss(pre_cnode); |
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
// return -> cast |
|
|
|
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { |
|
|
|
if (prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { |
|
|
|
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode); |
|
|
|
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
} |
|
|
|
pre_cnode = HandleDependLoss(pre_cnode); |
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); |
|
|
|
|
|
|
|
// notice: the GetNext op has not input |
|
|
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { |
|
|
|
@@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG |
|
|
|
// shape op doesn't have params and attrs. |
|
|
|
OperatorParams params; |
|
|
|
OperatorAttrs attrs; |
|
|
|
auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(shape_value); |
|
|
|
auto shape = shape_value->value(); |
|
|
|
if (shape.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
OperatorArgs args = std::make_pair(attrs, params); |
|
|
|
Operator op = std::make_pair(SHAPE_OP, args); |
|
|
|
InsertNode(op, node, 2, pre_node, root, "shape"); |
|
|
|
|