| @@ -304,6 +304,13 @@ AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { | |||
| return inputs[2]; | |||
| } | |||
| AnfNodePtr EraseDictGetValues(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto &inputs = node->inputs(); | |||
| MS_ASSERT(inputs.size() == 2 && "DictGetValues should have two inputs"); | |||
| return inputs[1]; | |||
| } | |||
| AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto &inputs = node->inputs(); | |||
| @@ -374,6 +381,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr | |||
| new_node = ConvertDictGetItemToTupleGetItem(cnode); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { | |||
| new_node = ConvertDictSetItemToTupleSetItem(cnode); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) { | |||
| new_node = EraseDictGetValues(cnode); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { | |||
| new_node = EraseMakeDictNode(cnode); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { | |||
| @@ -141,6 +141,8 @@ BuiltInTypeMap &GetMethodMap() { | |||
| {"__len__", prim::kPrimDictLen}, // P.dict_len | |||
| {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem | |||
| {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, | |||
| {"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys, | |||
| {"values", prim::kPrimDictGetValues}, // P.dict_getvalues, | |||
| {"__bool__", std::string("dict_bool")} // C.dict_bool | |||
| }}, | |||
| {kObjectTypeTensorType, | |||
| @@ -131,6 +131,10 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -249,6 +249,32 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP | |||
| return std::make_shared<AbstractDictionary>(dict_elems); | |||
| } | |||
| AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| AbstractBasePtrList keys; | |||
| std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys), | |||
| [](const AbstractAttribute &item) { return std::make_shared<AbstractScalar>(item.first); }); | |||
| return std::make_shared<AbstractTuple>(keys); | |||
| } | |||
| AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a dict. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); | |||
| std::vector<AbstractAttribute> dict_elems = dict->elements(); | |||
| AbstractBasePtrList values; | |||
| std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values), | |||
| [](const AbstractAttribute &item) { return item.second; }); | |||
| return std::make_shared<AbstractTuple>(values); | |||
| } | |||
| AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a list and an object of a subclass of AbstractBase. | |||
| @@ -72,6 +72,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimListSetItem, {InferImplListSetItem, true}}, | |||
| {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, | |||
| {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, | |||
| {prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}}, | |||
| {prim::kPrimDictGetValues, {InferImplDictGetValues, true}}, | |||
| {prim::kPrimListAppend, {InferImplListAppend, true}}, | |||
| {prim::kPrimTupleLen, {InferImplTupleLen, true}}, | |||
| {prim::kPrimListLen, {InferImplListLen, true}}, | |||
| @@ -279,6 +279,8 @@ inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_g | |||
| inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem"); | |||
| inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem"); | |||
| inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem"); | |||
| inline const PrimitivePtr kPrimDictGetKeys = std::make_shared<Primitive>("dict_getkeys"); | |||
| inline const PrimitivePtr kPrimDictGetValues = std::make_shared<Primitive>("dict_getvalues"); | |||
| inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append"); | |||
| inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len"); | |||
| @@ -132,6 +132,20 @@ def _dict_setitem_with_number(data, key, value): | |||
| """ | |||
| return F.dict_setitem(data, key, value) | |||
| @setitem.register("Dictionary", "String", "Tuple") | |||
| def _dict_setitem_with_tuple(data, key, value): | |||
| """ | |||
| Assigns value to dictionary. | |||
| Inputs: | |||
| data (dict): Data of type dict. | |||
| key (str): Key of the data. | |||
| value (Tuple): Value given. | |||
| Outputs: | |||
| dict, type is as same as the element type of data. | |||
| """ | |||
| return F.dict_setitem(data, key, value) | |||
| @setitem.register("Tensor", "Tensor", "Tensor") | |||
| def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): | |||
| @@ -0,0 +1,81 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test_dictionary """ | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.nn import Cell | |||
| class Net1(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| dic = {'x': 0, 'y': 1} | |||
| output = [] | |||
| for i in dic.keys(): | |||
| output.append(i) | |||
| for j in dic.values(): | |||
| output.append(j) | |||
| return output | |||
| class Net2(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| dic = {'x': x, 'y': 1} | |||
| output = [] | |||
| for i in dic.keys(): | |||
| output.append(i) | |||
| for j in dic.values(): | |||
| output.append(j) | |||
| return output | |||
| class Net3(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x): | |||
| dic = {'x': 0} | |||
| dic['y'] = (0, 1) | |||
| output = [] | |||
| for i in dic.keys(): | |||
| output.append(i) | |||
| for j in dic.values(): | |||
| output.append(j) | |||
| return output | |||
| def test_dict1(): | |||
| input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) | |||
| input_me = Tensor(input_np) | |||
| net = Net1() | |||
| out_me = net(input_me) | |||
| assert out_me == ('x', 'y', 0, 1) | |||
| def test_dict2(): | |||
| input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) | |||
| input_me = Tensor(input_np) | |||
| net = Net2() | |||
| net(input_me) | |||
| def test_dict3(): | |||
| input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) | |||
| input_me = Tensor(input_np) | |||
| net = Net3() | |||
| out_me = net(input_me) | |||
| assert out_me == ('x', 'y', 0, (0, 1)) | |||