Browse Source

get keys and values from dictionary & set tuple to dictionary

tags/v1.1.0
simson 5 years ago
parent
commit
3b21822824
8 changed files with 140 additions and 0 deletions
  1. +9
    -0
      mindspore/ccsrc/frontend/optimizer/clean.cc
  2. +2
    -0
      mindspore/ccsrc/pipeline/jit/resource.cc
  3. +4
    -0
      mindspore/core/abstract/infer_functions.h
  4. +26
    -0
      mindspore/core/abstract/prim_structures.cc
  5. +2
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  6. +2
    -0
      mindspore/core/base/core_ops.h
  7. +14
    -0
      mindspore/ops/composite/multitype_ops/setitem_impl.py
  8. +81
    -0
      tests/ut/python/pipeline/parse/test_dictionary.py

+ 9
- 0
mindspore/ccsrc/frontend/optimizer/clean.cc View File

@@ -304,6 +304,13 @@ AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
return inputs[2];
}

AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() == 2 && "DictGetValues should have two inputs");
return inputs[1];
}

AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inputs = node->inputs();
@@ -374,6 +381,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
new_node = ConvertDictGetItemToTupleGetItem(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
new_node = ConvertDictSetItemToTupleSetItem(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
new_node = EraseDictGetValues(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
new_node = EraseMakeDictNode(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {


+ 2
- 0
mindspore/ccsrc/pipeline/jit/resource.cc View File

@@ -141,6 +141,8 @@ BuiltInTypeMap &GetMethodMap() {
{"__len__", prim::kPrimDictLen}, // P.dict_len
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"__bool__", std::string("dict_bool")} // C.dict_bool
}},
{kObjectTypeTensorType,


+ 4
- 0
mindspore/core/abstract/infer_functions.h View File

@@ -131,6 +131,10 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 26
- 0
mindspore/core/abstract/prim_structures.cc View File

@@ -249,6 +249,32 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
return std::make_shared<AbstractDictionary>(dict_elems);
}

AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList keys;
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
[](const AbstractAttribute &item) { return std::make_shared<AbstractScalar>(item.first); });
return std::make_shared<AbstractTuple>(keys);
}

AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList values;
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values),
[](const AbstractAttribute &item) { return item.second; });
return std::make_shared<AbstractTuple>(values);
}

AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a list and an object of a subclass of AbstractBase.


+ 2
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -72,6 +72,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}},
{prim::kPrimDictGetValues, {InferImplDictGetValues, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},


+ 2
- 0
mindspore/core/base/core_ops.h View File

@@ -279,6 +279,8 @@ inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_g
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimDictGetKeys = std::make_shared<Primitive>("dict_getkeys");
inline const PrimitivePtr kPrimDictGetValues = std::make_shared<Primitive>("dict_getvalues");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");



+ 14
- 0
mindspore/ops/composite/multitype_ops/setitem_impl.py View File

@@ -132,6 +132,20 @@ def _dict_setitem_with_number(data, key, value):
"""
return F.dict_setitem(data, key, value)

@setitem.register("Dictionary", "String", "Tuple")
def _dict_setitem_with_tuple(data, key, value):
"""
Assigns value to dictionary.

Inputs:
data (dict): Data of type dict.
key (str): Key of the data.
value (Tuple): Value given.

Outputs:
dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)

@setitem.register("Tensor", "Tensor", "Tensor")
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):


+ 81
- 0
tests/ut/python/pipeline/parse/test_dictionary.py View File

@@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dictionary """
import numpy as np

from mindspore import Tensor
from mindspore.nn import Cell


class Net1(Cell):
def __init__(self):
super().__init__()

def construct(self, x):
dic = {'x': 0, 'y': 1}
output = []
for i in dic.keys():
output.append(i)
for j in dic.values():
output.append(j)
return output

class Net2(Cell):
def __init__(self):
super().__init__()

def construct(self, x):
dic = {'x': x, 'y': 1}
output = []
for i in dic.keys():
output.append(i)
for j in dic.values():
output.append(j)
return output

class Net3(Cell):
def __init__(self):
super().__init__()

def construct(self, x):
dic = {'x': 0}
dic['y'] = (0, 1)
output = []
for i in dic.keys():
output.append(i)
for j in dic.values():
output.append(j)
return output

def test_dict1():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net1()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, 1)


def test_dict2():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net2()
net(input_me)

def test_dict3():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net3()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, (0, 1))

Loading…
Cancel
Save