| @@ -334,6 +334,27 @@ AnfNodePtr EraseDictGetValues(const CNodePtr &node) { | |||
| return inputs[1]; | |||
| } | |||
| AnfNodePtr EraseDictItems(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto &inputs = node->inputs(); | |||
| const size_t expect_inputs_size = 2; | |||
| CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node)); | |||
| const auto &tmp = inputs[0]->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tmp); | |||
| MS_EXCEPTION_IF_NULL(tmp->value()->cast<ValueTuplePtr>()); | |||
| ValuePtrList keys = tmp->value()->cast<ValueTuplePtr>()->value(); | |||
| std::vector<AnfNodePtr> outer_node{NewValueNode(prim::kPrimMakeList)}; | |||
| for (size_t i = 0; i < keys.size(); ++i) { | |||
| std::vector<AnfNodePtr> inner_node; | |||
| inner_node.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| inner_node.push_back(NewValueNode(keys[i])); | |||
| inner_node.push_back(NewCNode( | |||
| std::vector<AnfNodePtr>{NewValueNode(prim::kPrimTupleGetItem), inputs[1], NewValueNode(i)}, node->func_graph())); | |||
| outer_node.push_back(NewCNode(inner_node, node->func_graph())); | |||
| } | |||
| return NewCNode(outer_node, node->func_graph()); | |||
| } | |||
| AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto &inputs = node->inputs(); | |||
| @@ -416,6 +437,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr | |||
| new_node = EraseMakeKeywordArgNode(cnode); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { | |||
| new_node = EraseExtractKeywordArg(cnode); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimDictItems)) { | |||
| new_node = EraseDictItems(cnode); | |||
| } | |||
| if (new_node != nullptr) { | |||
| @@ -143,6 +143,7 @@ BuiltInTypeMap &GetMethodMap() { | |||
| {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, | |||
| {"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys, | |||
| {"values", prim::kPrimDictGetValues}, // P.dict_getvalues, | |||
| {"items", prim::kPrimDictItems}, // P.dict_items | |||
| {"__bool__", std::string("dict_bool")} // C.dict_bool | |||
| }}, | |||
| {kObjectTypeTensorType, | |||
| @@ -116,6 +116,8 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -312,6 +312,21 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| AbstractBasePtrList items; | |||
| std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(items), [](const AbstractAttribute &item) { | |||
| return std::make_shared<AbstractTuple>( | |||
| AbstractBasePtrList{std::make_shared<AbstractScalar>(item.first), item.second}); | |||
| }); | |||
| return std::make_shared<AbstractList>(items); | |||
| } | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a list and an object of a subclass of AbstractBase. | |||
| @@ -151,6 +151,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}}, | |||
| {prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}}, | |||
| {prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}}, | |||
| {prim::kPrimDictItems, {InferImplDictItems, nullptr, true}}, | |||
| {prim::kPrimListAppend, {InferImplListAppend, nullptr, true}}, | |||
| {prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}}, | |||
| {prim::kPrimListLen, {InferImplListLen, nullptr, true}}, | |||
| @@ -618,6 +618,7 @@ inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_g | |||
| inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||
| inline const PrimitivePtr kPrimDictGetKeys = std::make_shared<Primitive>("dict_getkeys"); | |||
| inline const PrimitivePtr kPrimDictGetValues = std::make_shared<Primitive>("dict_getvalues"); | |||
| inline const PrimitivePtr kPrimDictItems = std::make_shared<Primitive>("dict_items"); | |||
| inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||
| inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||
| @@ -15,7 +15,6 @@ | |||
| """Implementation for internal polymorphism `mul` operations.""" | |||
| from . import _constexpr_utils as const_utils | |||
| from . import _compile_utils as utils | |||
| from ...composite import base | |||
| from ... import functional as F | |||
| @@ -80,7 +79,12 @@ def _list_mul_scalar(x, y): | |||
| Outputs: | |||
| List. | |||
| """ | |||
| return const_utils.sequence_mul_int(x, y) | |||
| res = [] | |||
| i = 0 | |||
| while i < y: | |||
| res += x | |||
| i += 1 | |||
| return res | |||
| @mul.register("Number", "List") | |||
| @@ -91,7 +95,12 @@ def _scalar_mul_list(x, y): | |||
| Outputs: | |||
| List. | |||
| """ | |||
| return const_utils.sequence_mul_int(y, x) | |||
| res = [] | |||
| i = 0 | |||
| while i < x: | |||
| res += y | |||
| i += 1 | |||
| return res | |||
| @mul.register("Tuple", "Number") | |||
| @@ -102,7 +111,12 @@ def _tuple_mul_scalar(x, y): | |||
| Outputs: | |||
| Tuple. | |||
| """ | |||
| return const_utils.sequence_mul_int(x, y) | |||
| res = () | |||
| i = 0 | |||
| while i < y: | |||
| res += x | |||
| i += 1 | |||
| return res | |||
| @mul.register("Number", "Tuple") | |||
| @@ -113,7 +127,12 @@ def _scalar_mul_tuple(x, y): | |||
| Outputs: | |||
| Tuple. | |||
| """ | |||
| return const_utils.sequence_mul_int(y, x) | |||
| res = () | |||
| i = 0 | |||
| while i < x: | |||
| res += y | |||
| i += 1 | |||
| return res | |||
| @mul.register("Tensor", "Tuple") | |||
| @@ -172,3 +172,20 @@ def test_dict_set_item_create_new(): | |||
| x = Tensor(np.ones([2, 2, 3], np.float32)) | |||
| net = DictSetNet() | |||
| _ = net(x) | |||
| def test_dict_items(): | |||
| """ | |||
| Description: test_dict_items | |||
| Expectation: the results are as expected | |||
| """ | |||
| class DictItemsNet(Cell): | |||
| def __init__(self): | |||
| super(DictItemsNet, self).__init__() | |||
| def construct(self, x): | |||
| return x.items() | |||
| x = {"1": Tensor(1), "2": {"test": (1, 2)}} | |||
| net = DictItemsNet() | |||
| _ = net(x) | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test list mul number """ | |||
| import numpy as np | |||
| from mindspore import Tensor, context | |||
| from mindspore import nn | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.list_ = [Tensor([1, 2, 3])] | |||
| self.number1 = 5 | |||
| self.number2 = 0 | |||
| def construct(self): | |||
| return self.list_ * self.number1, self.list_ * self.number2 | |||
| def test_list_mul_number(): | |||
| """ | |||
| Description: test_list_mul_number | |||
| Expectation: the results are as expected | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Net() | |||
| expect_ret0 = [Tensor([1, 2, 3])] * 5 | |||
| expect_ret1 = [Tensor([1, 2, 3])] * 0 | |||
| assert isinstance(net()[0], list) | |||
| assert isinstance(net()[1], list) | |||
| for i in range(len(net()[0])): | |||
| assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy()) | |||
| assert net()[1] == expect_ret1 | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test number mul list """ | |||
| import numpy as np | |||
| from mindspore import Tensor, context | |||
| from mindspore import nn | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.list_ = [Tensor([1, 2, 3])] | |||
| self.number1 = 5 | |||
| self.number2 = 0 | |||
| def construct(self): | |||
| return self.number1 * self.list_, self.number2 * self.list_ | |||
| def test_number_mul_list(): | |||
| """ | |||
| Description: test_number_mul_list | |||
| Expectation: the results are as expected | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Net() | |||
| expect_ret0 = 5 * [Tensor([1, 2, 3])] | |||
| expect_ret1 = 0 * [Tensor([1, 2, 3])] | |||
| assert isinstance(net()[0], list) | |||
| assert isinstance(net()[1], list) | |||
| for i in range(len(net()[0])): | |||
| assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy()) | |||
| assert net()[1] == expect_ret1 | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test number mul tuple """ | |||
| import numpy as np | |||
| from mindspore import Tensor, context | |||
| from mindspore import nn | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.tuple_ = (Tensor([1, 2, 3]),) | |||
| self.number1 = 5 | |||
| self.number2 = 0 | |||
| def construct(self): | |||
| return self.number1 * self.tuple_, self.number2 * self.tuple_ | |||
| def test_number_mul_tuple(): | |||
| """ | |||
| Description: test_number_mul_tuple | |||
| Expectation: the results are as expected | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Net() | |||
| expect_ret0 = 5 * (Tensor([1, 2, 3]),) | |||
| expect_ret1 = 0 * (Tensor([1, 2, 3]),) | |||
| assert isinstance(net()[0], tuple) | |||
| assert isinstance(net()[1], tuple) | |||
| for i in range(len(net()[0])): | |||
| assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy()) | |||
| assert net()[1] == expect_ret1 | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test tuple mul number """ | |||
| import numpy as np | |||
| from mindspore import Tensor, context | |||
| from mindspore import nn | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.tuple_ = (Tensor([1, 2, 3]),) | |||
| self.number1 = 5 | |||
| self.number2 = 0 | |||
| def construct(self): | |||
| return self.tuple_ * self.number1, self.tuple_ * self.number2 | |||
| def test_tuple_mul_number(): | |||
| """ | |||
| Description: test_tuple_mul_number | |||
| Expectation: the results are as expected | |||
| """ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| net = Net() | |||
| expect_ret0 = (Tensor([1, 2, 3]),) * 5 | |||
| expect_ret1 = (Tensor([1, 2, 3]),) * 0 | |||
| assert isinstance(net()[0], tuple) | |||
| assert isinstance(net()[1], tuple) | |||
| for i in range(len(net()[0])): | |||
| assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy()) | |||
| assert net()[1] == expect_ret1 | |||