Browse Source

iml dict.items() and multiplication op

tags/v1.6.0
zzb 4 years ago
parent
commit
95425b770d
12 changed files with 268 additions and 5 deletions
  1. +23
    -0
      mindspore/ccsrc/frontend/optimizer/clean.cc
  2. +1
    -0
      mindspore/ccsrc/pipeline/jit/resource.cc
  3. +2
    -0
      mindspore/core/abstract/infer_functions.h
  4. +15
    -0
      mindspore/core/abstract/prim_structures.cc
  5. +1
    -0
      mindspore/core/abstract/primitive_infer_map.cc
  6. +1
    -0
      mindspore/core/base/core_ops.h
  7. +24
    -5
      mindspore/ops/composite/multitype_ops/mul_impl.py
  8. +17
    -0
      tests/ut/python/dtype/test_dictionary.py
  9. +46
    -0
      tests/ut/python/pipeline/parse/test_list_mul_number.py
  10. +46
    -0
      tests/ut/python/pipeline/parse/test_number_mul_list.py
  11. +46
    -0
      tests/ut/python/pipeline/parse/test_number_mul_tuple.py
  12. +46
    -0
      tests/ut/python/pipeline/parse/test_tuple_mul_number.py

+ 23
- 0
mindspore/ccsrc/frontend/optimizer/clean.cc View File

@@ -334,6 +334,27 @@ AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
return inputs[1];
}

AnfNodePtr EraseDictItems(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inputs = node->inputs();
const size_t expect_inputs_size = 2;
CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
const auto &tmp = inputs[0]->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(tmp);
MS_EXCEPTION_IF_NULL(tmp->value()->cast<ValueTuplePtr>());
ValuePtrList keys = tmp->value()->cast<ValueTuplePtr>()->value();
std::vector<AnfNodePtr> outer_node{NewValueNode(prim::kPrimMakeList)};
for (size_t i = 0; i < keys.size(); ++i) {
std::vector<AnfNodePtr> inner_node;
inner_node.push_back(NewValueNode(prim::kPrimMakeTuple));
inner_node.push_back(NewValueNode(keys[i]));
inner_node.push_back(NewCNode(
std::vector<AnfNodePtr>{NewValueNode(prim::kPrimTupleGetItem), inputs[1], NewValueNode(i)}, node->func_graph()));
outer_node.push_back(NewCNode(inner_node, node->func_graph()));
}
return NewCNode(outer_node, node->func_graph());
}

AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inputs = node->inputs();
@@ -416,6 +437,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
new_node = EraseMakeKeywordArgNode(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
new_node = EraseExtractKeywordArg(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictItems)) {
new_node = EraseDictItems(cnode);
}

if (new_node != nullptr) {


+ 1
- 0
mindspore/ccsrc/pipeline/jit/resource.cc View File

@@ -143,6 +143,7 @@ BuiltInTypeMap &GetMethodMap() {
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"items", prim::kPrimDictItems}, // P.dict_items
{"__bool__", std::string("dict_bool")} // C.dict_bool
}},
{kObjectTypeTensorType,


+ 2
- 0
mindspore/core/abstract/infer_functions.h View File

@@ -116,6 +116,8 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 15
- 0
mindspore/core/abstract/prim_structures.cc View File

@@ -312,6 +312,21 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv
return std::make_shared<AbstractTuple>(values);
}

AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList items;
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(items), [](const AbstractAttribute &item) {
return std::make_shared<AbstractTuple>(
AbstractBasePtrList{std::make_shared<AbstractScalar>(item.first), item.second});
});
return std::make_shared<AbstractList>(items);
}

AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a list and an object of a subclass of AbstractBase.


+ 1
- 0
mindspore/core/abstract/primitive_infer_map.cc View File

@@ -151,6 +151,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}},
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}},
{prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}},
{prim::kPrimDictItems, {InferImplDictItems, nullptr, true}},
{prim::kPrimListAppend, {InferImplListAppend, nullptr, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}},
{prim::kPrimListLen, {InferImplListLen, nullptr, true}},


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -618,6 +618,7 @@ inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_g
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimDictGetKeys = std::make_shared<Primitive>("dict_getkeys");
inline const PrimitivePtr kPrimDictGetValues = std::make_shared<Primitive>("dict_getvalues");
inline const PrimitivePtr kPrimDictItems = std::make_shared<Primitive>("dict_items");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");



+ 24
- 5
mindspore/ops/composite/multitype_ops/mul_impl.py View File

@@ -15,7 +15,6 @@

"""Implementation for internal polymorphism `mul` operations."""

from . import _constexpr_utils as const_utils
from . import _compile_utils as utils
from ...composite import base
from ... import functional as F
@@ -80,7 +79,12 @@ def _list_mul_scalar(x, y):
Outputs:
List.
"""
return const_utils.sequence_mul_int(x, y)
res = []
i = 0
while i < y:
res += x
i += 1
return res


@mul.register("Number", "List")
@@ -91,7 +95,12 @@ def _scalar_mul_list(x, y):
Outputs:
List.
"""
return const_utils.sequence_mul_int(y, x)
res = []
i = 0
while i < x:
res += y
i += 1
return res


@mul.register("Tuple", "Number")
@@ -102,7 +111,12 @@ def _tuple_mul_scalar(x, y):
Outputs:
Tuple.
"""
return const_utils.sequence_mul_int(x, y)
res = ()
i = 0
while i < y:
res += x
i += 1
return res


@mul.register("Number", "Tuple")
@@ -113,7 +127,12 @@ def _scalar_mul_tuple(x, y):
Outputs:
Tuple.
"""
return const_utils.sequence_mul_int(y, x)
res = ()
i = 0
while i < x:
res += y
i += 1
return res


@mul.register("Tensor", "Tuple")


+ 17
- 0
tests/ut/python/dtype/test_dictionary.py View File

@@ -172,3 +172,20 @@ def test_dict_set_item_create_new():
x = Tensor(np.ones([2, 2, 3], np.float32))
net = DictSetNet()
_ = net(x)


def test_dict_items():
"""
Description: test_dict_items
Expectation: the results are as expected
"""

class DictItemsNet(Cell):
def __init__(self):
super(DictItemsNet, self).__init__()

def construct(self, x):
return x.items()
x = {"1": Tensor(1), "2": {"test": (1, 2)}}
net = DictItemsNet()
_ = net(x)

+ 46
- 0
tests/ut/python/pipeline/parse/test_list_mul_number.py View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test list mul number """

import numpy as np
from mindspore import Tensor, context
from mindspore import nn


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.list_ = [Tensor([1, 2, 3])]
self.number1 = 5
self.number2 = 0

def construct(self):
return self.list_ * self.number1, self.list_ * self.number2

def test_list_mul_number():
"""
Description: test_list_mul_number
Expectation: the results are as expected
"""

context.set_context(mode=context.GRAPH_MODE)
net = Net()
expect_ret0 = [Tensor([1, 2, 3])] * 5
expect_ret1 = [Tensor([1, 2, 3])] * 0
assert isinstance(net()[0], list)
assert isinstance(net()[1], list)
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1

+ 46
- 0
tests/ut/python/pipeline/parse/test_number_mul_list.py View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test number mul list """

import numpy as np
from mindspore import Tensor, context
from mindspore import nn


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.list_ = [Tensor([1, 2, 3])]
self.number1 = 5
self.number2 = 0

def construct(self):
return self.number1 * self.list_, self.number2 * self.list_

def test_number_mul_list():
"""
Description: test_number_mul_list
Expectation: the results are as expected
"""

context.set_context(mode=context.GRAPH_MODE)
net = Net()
expect_ret0 = 5 * [Tensor([1, 2, 3])]
expect_ret1 = 0 * [Tensor([1, 2, 3])]
assert isinstance(net()[0], list)
assert isinstance(net()[1], list)
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1

+ 46
- 0
tests/ut/python/pipeline/parse/test_number_mul_tuple.py View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test number mul tuple """

import numpy as np
from mindspore import Tensor, context
from mindspore import nn


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.tuple_ = (Tensor([1, 2, 3]),)
self.number1 = 5
self.number2 = 0

def construct(self):
return self.number1 * self.tuple_, self.number2 * self.tuple_

def test_number_mul_tuple():
"""
Description: test_number_mul_tuple
Expectation: the results are as expected
"""

context.set_context(mode=context.GRAPH_MODE)
net = Net()
expect_ret0 = 5 * (Tensor([1, 2, 3]),)
expect_ret1 = 0 * (Tensor([1, 2, 3]),)
assert isinstance(net()[0], tuple)
assert isinstance(net()[1], tuple)
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1

+ 46
- 0
tests/ut/python/pipeline/parse/test_tuple_mul_number.py View File

@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test tuple mul number """

import numpy as np
from mindspore import Tensor, context
from mindspore import nn


class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.tuple_ = (Tensor([1, 2, 3]),)
self.number1 = 5
self.number2 = 0

def construct(self):
return self.tuple_ * self.number1, self.tuple_ * self.number2

def test_tuple_mul_number():
"""
Description: test_tuple_mul_number
Expectation: the results are as expected
"""

context.set_context(mode=context.GRAPH_MODE)
net = Net()
expect_ret0 = (Tensor([1, 2, 3]),) * 5
expect_ret1 = (Tensor([1, 2, 3]),) * 0
assert isinstance(net()[0], tuple)
assert isinstance(net()[1], tuple)
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1

Loading…
Cancel
Save