diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 80c4a936d8..b7324bb4ea 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -412,6 +412,21 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar } return target; } + +PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) { + if (node == nullptr) { + return nullptr; + } + auto value_node = node->cast(); + if (value_node == nullptr) { + return nullptr; + } + auto value = value_node->value(); + if (value == nullptr || !value->isa()) { + return nullptr; + } + return value->cast(); +} } // namespace std::string GetCNodeTarget(const AnfNodePtr &node) { @@ -429,29 +444,29 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { return *ud_target.get(); } auto attr_input = cnode->input(0); - if (attr_input == nullptr) { - return default_target; - } - auto value_node = attr_input->cast(); - if (value_node == nullptr) { - return default_target; - } - auto value = value_node->value(); - if (value == nullptr) { + auto primitive = GetPrimitiveFromValueNode(attr_input); + if (primitive == nullptr) { return default_target; } - if (!value->isa()) { - return default_target; - } - auto primitive = value->cast(); auto att_target = primitive->GetAttr(primitive_target); if (att_target != nullptr) { return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target); } if (IsPrimitiveCNode(node, prim::kPrimDepend)) { auto &inputs = cnode->inputs(); - if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) { - return GetCNodeTarget(inputs[1]); + if (inputs.size() >= 3) { + size_t use_index = 1; + if (!inputs[use_index]->isa()) { + use_index = 2; + } + if (!IsPrimitiveCNode(inputs[use_index], prim::kPrimMakeTuple)) { + return GetCNodeTarget(inputs[use_index]); + } + } + } else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) { + auto &inputs = cnode->inputs(); + if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) { + return GetCNodeTarget(inputs[2]); } } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { return GetMaketupleNodeTarget(cnode);