Browse Source

!15282 get target for updatestate

From: @kisnwang
Reviewed-by: @chujinjin,@zhoufeng54
Signed-off-by: @zhoufeng54
pull/15282/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
e7dfb7beeb
1 changed files with 30 additions and 15 deletions
  1. +30
    -15
      mindspore/core/ir/anf.cc

+ 30
- 15
mindspore/core/ir/anf.cc View File

@@ -412,6 +412,21 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar
}
return target;
}

PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
auto value = value_node->value();
if (value == nullptr || !value->isa<Primitive>()) {
return nullptr;
}
return value->cast<PrimitivePtr>();
}
} // namespace

std::string GetCNodeTarget(const AnfNodePtr &node) {
@@ -429,29 +444,29 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
return *ud_target.get();
}
auto attr_input = cnode->input(0);
if (attr_input == nullptr) {
return default_target;
}
auto value_node = attr_input->cast<ValueNodePtr>();
if (value_node == nullptr) {
return default_target;
}
auto value = value_node->value();
if (value == nullptr) {
auto primitive = GetPrimitiveFromValueNode(attr_input);
if (primitive == nullptr) {
return default_target;
}
if (!value->isa<Primitive>()) {
return default_target;
}
auto primitive = value->cast<PrimitivePtr>();
auto att_target = primitive->GetAttr(primitive_target);
if (att_target != nullptr) {
return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target);
}
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
auto &inputs = cnode->inputs();
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[1]);
if (inputs.size() >= 3) {
size_t use_index = 1;
if (!inputs[use_index]->isa<CNode>()) {
use_index = 2;
}
if (!IsPrimitiveCNode(inputs[use_index], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[use_index]);
}
}
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
auto &inputs = cnode->inputs();
if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[2], prim::kPrimMakeTuple)) {
return GetCNodeTarget(inputs[2]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return GetMaketupleNodeTarget(cnode);


Loading…
Cancel
Save