| @@ -1852,14 +1852,14 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { | |||||
| if (pre_cnode == nullptr) { | if (pre_cnode == nullptr) { | ||||
| return loss_node_info; | return loss_node_info; | ||||
| } | } | ||||
| pre_cnode = HandleDependLoss(pre_cnode); | |||||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||||
| auto prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||||
| // return -> cast | // return -> cast | ||||
| if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { | |||||
| if (prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { | |||||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pre_cnode); | MS_EXCEPTION_IF_NULL(pre_cnode); | ||||
| current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||||
| } | } | ||||
| pre_cnode = HandleDependLoss(pre_cnode); | |||||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||||
| // notice: the GetNext op has not input | // notice: the GetNext op has not input | ||||
| if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { | if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { | ||||
| @@ -2416,6 +2416,12 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG | |||||
| // shape op doesn't have params and attrs. | // shape op doesn't have params and attrs. | ||||
| OperatorParams params; | OperatorParams params; | ||||
| OperatorAttrs attrs; | OperatorAttrs attrs; | ||||
| auto shape_value = GetValueNode(node->input(2))->cast<ValueSequeuePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(shape_value); | |||||
| auto shape = shape_value->value(); | |||||
| if (shape.empty()) { | |||||
| return; | |||||
| } | |||||
| OperatorArgs args = std::make_pair(attrs, params); | OperatorArgs args = std::make_pair(attrs, params); | ||||
| Operator op = std::make_pair(SHAPE_OP, args); | Operator op = std::make_pair(SHAPE_OP, args); | ||||
| InsertNode(op, node, 2, pre_node, root, "shape"); | InsertNode(op, node, 2, pre_node, root, "shape"); | ||||