Browse Source

!3062 fix valuenode simplify

Merge pull request !3062 from riemann_penn/fix_value_node_simplify
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bacbb26e3b
2 changed files with 26 additions and 9 deletions
  1. +16
    -9
      mindspore/ccsrc/frontend/optimizer/clean.cc
  2. +10
    -0
      tests/ut/python/pynative_mode/test_framstruct.py

+ 16
- 9
mindspore/ccsrc/frontend/optimizer/clean.cc View File

@@ -43,26 +43,28 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
return nullptr; return nullptr;
} }


AbstractBasePtr res = t;
if (t->isa<AbstractClass>()) { if (t->isa<AbstractClass>()) {
auto abs_class = dyn_cast<AbstractClass>(t); auto abs_class = dyn_cast<AbstractClass>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto attributes = abs_class->attributes(); auto attributes = abs_class->attributes();
(void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractDictionary>()) {
return std::make_shared<AbstractTuple>(baselist);
}
if (t->isa<AbstractDictionary>()) {
auto abs_dict = dyn_cast<AbstractDictionary>(t); auto abs_dict = dyn_cast<AbstractDictionary>(t);
AbstractBasePtrList baselist; AbstractBasePtrList baselist;
auto elements = abs_dict->elements(); auto elements = abs_dict->elements();
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
res = std::make_shared<AbstractTuple>(baselist);
} else if (t->isa<AbstractList>()) {
auto abs_dict = dyn_cast<AbstractList>(t);
res = std::make_shared<AbstractTuple>(abs_dict->elements());
return std::make_shared<AbstractTuple>(baselist);
}
if (t->isa<AbstractList>()) {
auto abs_list = dyn_cast<AbstractList>(t);
return std::make_shared<AbstractTuple>(abs_list->elements());
} }
return res;

return nullptr;
} }


AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
@@ -376,7 +378,12 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr


for (auto &node : manager->all_nodes()) { for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract()); auto ret = Reabs(node->abstract());
node->set_abstract(ret);
if (ret) {
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
<< ret->ToString();
node->set_abstract(ret);
changed = true;
}
} }
return changed; return changed;
} }


+ 10
- 0
tests/ut/python/pynative_mode/test_framstruct.py View File

@@ -1031,3 +1031,13 @@ def test_grad_if_defer_inline():
inp = Tensor(np.ones([128, 96]).astype(np.float32)) inp = Tensor(np.ones([128, 96]).astype(np.float32))
grads = C.grad_all(network)(inp) grads = C.grad_all(network)(inp)
assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)


def test_dict_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.res = {'1': 10}
def construct(self):
return self.res
Net()()

Loading…
Cancel
Save