Merge pull request !27757 from 张清华/opt_cell_list_getattr2tags/v1.6.0
| @@ -36,45 +36,41 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr | |||
| constexpr auto recursive_level = 3; | |||
| MS_LOG(DEBUG) << "getattr_operand_node: " << getattr_operand_node->DebugString(recursive_level); | |||
| // {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, ...}} | |||
| // {prim::GetAttr, {{prim::Resolve, ..., 'getitem'}, {prim::Resolve, ...}, index}, attr} | |||
| auto getitem_cnode = getattr_operand_node->cast<CNodePtr>(); | |||
| if (getitem_cnode != nullptr) { | |||
| constexpr size_t getitem_inputs_size = 3; | |||
| if (getitem_cnode != nullptr && getitem_cnode->size() == getitem_inputs_size) { | |||
| constexpr size_t prim_index = 0; | |||
| auto primitive_node = getitem_cnode->input(prim_index); | |||
| auto resolved_getitem_node = primitive_node; | |||
| if (IsPrimitiveCNode(primitive_node, prim::kPrimResolve)) { | |||
| auto resolve_getitem_cnode = primitive_node->cast<CNodePtr>(); | |||
| auto resolve_getitem_node = getitem_cnode->input(prim_index); | |||
| constexpr size_t resolve_index = 1; | |||
| auto resolve_node = getitem_cnode->input(resolve_index); | |||
| if (IsPrimitiveCNode(resolve_getitem_node, prim::kPrimResolve) && | |||
| IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) { | |||
| auto resolve_getitem_cnode = resolve_getitem_node->cast<CNodePtr>(); | |||
| auto resolve_getitem_symbol = GetValueNode<parse::SymbolPtr>(resolve_getitem_cnode->input(2)); | |||
| constexpr auto getitem_symbol = "getitem"; | |||
| if (resolve_getitem_symbol->symbol() == getitem_symbol) { | |||
| auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1)); | |||
| resolved_getitem_node = | |||
| ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node); | |||
| } | |||
| } | |||
| bool is_getattr_getitem = false; | |||
| auto do_signature = dyn_cast<prim::DoSignaturePrimitive>(GetValueNode(resolved_getitem_node)); | |||
| if (do_signature != nullptr) { | |||
| auto &func_value = do_signature->function(); | |||
| // The function 'func_value' must be the MultitypeFuncGraph of 'getitem'. | |||
| auto multitype_fg_value = dyn_cast<prim::MultitypeFuncGraph>(func_value); | |||
| constexpr auto getitem_symbol = "getitem"; | |||
| if (multitype_fg_value != nullptr && multitype_fg_value->name() == getitem_symbol) { | |||
| is_getattr_getitem = true; | |||
| } | |||
| } | |||
| if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimTupleGetItem)) { | |||
| is_getattr_getitem = true; | |||
| } | |||
| if (is_getattr_getitem) { | |||
| constexpr size_t resolve_index = 1; | |||
| auto resolve_node = getitem_cnode->input(resolve_index); | |||
| constexpr size_t position_index = 2; | |||
| auto index_node = getitem_cnode->input(position_index); | |||
| if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve) && index_node->isa<ValueNode>()) { | |||
| constexpr size_t position_index = 2; | |||
| auto index_node = getitem_cnode->input(position_index); | |||
| auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); | |||
| auto py_item = parse::GetItemObjectFromSequence(name_space, symbol, resolve_node, index_node); | |||
| return parse::ResolveCellWithAttr(optimizer->manager(), py_item, resolve_node, attr); | |||
| auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node); | |||
| if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) { | |||
| std::vector<AnfNodePtr> inputs; | |||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| auto sequence = obj.cast<py::sequence>(); | |||
| for (size_t i = 0; i < sequence.size(); ++i) { | |||
| auto res = parse::ResolveCellWithAttr(optimizer->manager(), sequence[i], resolve_node, attr); | |||
| inputs.emplace_back(res); | |||
| } | |||
| auto make_tuple_node = getitem_cnode->func_graph()->NewCNodeInOrder(inputs); | |||
| auto resolve_getitem_name_space = GetValueNode<parse::NameSpacePtr>(resolve_getitem_cnode->input(1)); | |||
| auto resolved_getitem_node = | |||
| ResolveSymbol(optimizer->manager(), resolve_getitem_name_space, resolve_getitem_symbol, node); | |||
| auto out = | |||
| getitem_cnode->func_graph()->NewCNodeInOrder({resolved_getitem_node, make_tuple_node, index_node}); | |||
| return out; | |||
| } | |||
| return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr); | |||
| } | |||
| } | |||
| } | |||
| @@ -311,9 +311,9 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons | |||
| } | |||
| } // namespace | |||
| // Get python object with index from a list. | |||
| py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, | |||
| const AnfNodePtr &index_node) { | |||
| // Get python object with index from a list or the whole list if the index is not fixed. | |||
| py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, | |||
| const AnfNodePtr &index_node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info())); | |||
| if (node->func_graph() == nullptr) { | |||
| @@ -329,15 +329,18 @@ py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const Symbo | |||
| MS_LOG(EXCEPTION) << "Should not get item from non-sequence type, obj: " << py::str(obj); | |||
| } | |||
| const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE; | |||
| const std::string module = "mindspore._extends.parse.parser"; | |||
| MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index_node: " << index_node->ToString(); | |||
| auto imm_value = GetValueNode<Int64ImmPtr>(index_node); | |||
| if (imm_value == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Expect an int64 value node, node: " << node->DebugString() | |||
| << ", index_node: " << index_node->DebugString(); | |||
| MS_LOG(DEBUG) << "The index is not a value node, so we return the whole list, node: " << node->DebugString() | |||
| << ", index_node: " << index_node->DebugString(); | |||
| // Index is not fixed, return the whole list. | |||
| return obj; | |||
| } | |||
| // It index is a value node, get the item of index directly. | |||
| const std::string fn = PYTHON_MOD_GET_ITEM_FROM_SEQUENCE; | |||
| const std::string module = "mindspore._extends.parse.parser"; | |||
| int index = imm_value->value(); | |||
| MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", index: " << index; | |||
| py::object item_obj = parse::python_adapter::GetPyFn(module, fn)(obj, py::int_(index)); | |||
| return item_obj; | |||
| } | |||
| @@ -179,9 +179,9 @@ class SymbolResolver { | |||
| }; | |||
| using SymbolResolverPtr = std::shared_ptr<SymbolResolver>; | |||
| // Get python object with index from a list. | |||
| py::object GetItemObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, | |||
| const AnfNodePtr &index_node); | |||
| // Get python object with index from a list or the whole list if the index is not fixed. | |||
| py::object GetObjectFromSequence(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node, | |||
| const AnfNodePtr &index_node); | |||
| std::pair<parse::NameSpacePtr, parse::SymbolPtr> GetNamespaceAndSymbol(const AnfNodePtr &node); | |||
| // Get resolved python object by namespace and symbol. | |||
| @@ -13,13 +13,12 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test a list of cell, and getattr by its item """ | |||
| import numpy as np | |||
| from mindspore import context, nn, dtype, Tensor | |||
| from mindspore.ops import operations as P | |||
| class Actor(nn.Cell): | |||
| def __init__(self): | |||
| super(Actor, self).__init__() | |||
| def act(self, x, y): | |||
| return x + y | |||
| @@ -44,4 +43,40 @@ def test_list_item_getattr(): | |||
| trainer = Trainer(actor_list) | |||
| x = Tensor([3], dtype=dtype.float32) | |||
| y = Tensor([6], dtype=dtype.float32) | |||
| print(trainer(x, y)) | |||
| res = trainer(x, y) | |||
| print(f'res: {res}') | |||
| expect_res = Tensor([9], dtype=dtype.float32) | |||
| assert np.array_equal(res.asnumpy(), expect_res.asnumpy()) | |||
| class Trainer2(nn.Cell): | |||
| def __init__(self, net_list): | |||
| super(Trainer2, self).__init__() | |||
| self.net_list = net_list | |||
| self.less = P.Less() | |||
| self.zero_float = Tensor(0, dtype=dtype.float32) | |||
| def construct(self, x, y): | |||
| sum_value = self.zero_float | |||
| num_actor = 0 | |||
| while num_actor < 3: | |||
| sum_value += self.net_list[num_actor].act(x, y) | |||
| num_actor += 1 | |||
| return sum_value | |||
| def test_list_item_getattr2(): | |||
| """ | |||
| Feature: getattr by the item from list of cell with a Tensor variable. | |||
| Description: Support RL use method in graph mode. | |||
| Expectation: No exception. | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| actor_list = [Actor(), Actor(), Actor()] | |||
| trainer = Trainer2(actor_list) | |||
| x = Tensor([3], dtype=dtype.float32) | |||
| y = Tensor([6], dtype=dtype.float32) | |||
| res = trainer(x, y) | |||
| print(f'res: {res}') | |||
| expect_res = Tensor([27], dtype=dtype.float32) | |||
| assert np.array_equal(res.asnumpy(), expect_res.asnumpy()) | |||