From c5b5a6719ca30f98acadede7ec1f5181f911d6f9 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Tue, 10 Nov 2020 15:14:40 +0800 Subject: [PATCH] Enable multi-dimensional list value assignment --- mindspore/ccsrc/pipeline/jit/parse/parse.cc | 15 ++- tests/ut/python/ops/test_tensor_slice.py | 8 +- .../pipeline/parse/test_sequence_assign.py | 116 ++++++++++++++++++ 3 files changed, 132 insertions(+), 7 deletions(-) create mode 100644 tests/ut/python/pipeline/parse/test_sequence_assign.py diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index fcb9be0b15..2a01068bb3 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1501,7 +1501,7 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje AnfNodePtr slice_node = ParseExprNode(block, slice_obj); CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node}); // getitem apply should return the sequence data structure itself - std::string var_name = ""; + std::string var_name; if (ast_->IsClassMember(value_obj)) { std::string attr_name = value_obj.attr("attr").cast(); var_name = "self." + attr_name; @@ -1515,9 +1515,18 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje << py::str(obj).cast() << "' with type '" << py::str(obj_type).cast() << "'."; } - } else { - var_name = value_obj.attr("id").cast(); + block->WriteVariable(var_name, setitem_app); + return; + } + if (AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, value_obj))) == + AST_SUB_TYPE_SUBSCRIPT) { + HandleAssignSubscript(block, value_obj, setitem_app); + return; + } + if (!py::hasattr(value_obj, "id")) { + MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast(); } + var_name = value_obj.attr("id").cast(); block->WriteVariable(var_name, setitem_app); } diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 0856d6c12d..de8190d0cc 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -672,14 +672,14 @@ def test_tensor_assign_bool_index(): with pytest.raises(ValueError): net2(Ta, u_tensor_error) net3 = TensorAssignWithBoolTensorIndexError() - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net3(Ta, Tb, Tc, u_tensor) - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net3(Ta, Tb, Tc, Tensor(u_scalar, mstype.int32)) net4 = TensorAssignWithBoolTensorIndex2Error() - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net4(Ta, u_tensor) - with pytest.raises(AttributeError): + with pytest.raises(IndexError): net4(Ta, Tensor(u_scalar, mstype.int32)) diff --git a/tests/ut/python/pipeline/parse/test_sequence_assign.py b/tests/ut/python/pipeline/parse/test_sequence_assign.py new file mode 100644 index 0000000000..5474dff92a --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_sequence_assign.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test enumerate""" +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + +context.set_context(mode=context.GRAPH_MODE) + + +def test_list_index_1D(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + list_ = [[1], [2, 2], [3, 3, 3]] + list_[0] = [100] + return list_ + + net = Net() + out = net() + assert out[0] == [100] + assert out[1] == [2, 2] + assert out[2] == [3, 3, 3] + + +def test_list_index_2D(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + list_ = [[1], [2, 2], [3, 3, 3]] + list_[1][0] = 200 + list_[1][1] = 201 + return list_ + + net = Net() + out = net() + assert out[0] == [1] + assert out[1] == [200, 201] + assert out[2] == [3, 3, 3] + + +def test_list_index_3D(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + list_ = [[1], [2, 2], [[3, 3, 3]]] + list_[2][0][0] = 300 + list_[2][0][1] = 301 + list_[2][0][2] = 302 + return list_ + + net = Net() + out = net() + assert out[0] == [1] + assert out[1] == [2, 2] + assert out[2] == [[300, 301, 302]] + + +def test_list_index_1D_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + list_ = [x] + list_[0] = 100 + return list_ + + net = Net() + net(Tensor(0)) + + +def test_list_index_2D_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + list_ = [[x, x]] + list_[0][0] = 100 + return list_ + + net = Net() + net(Tensor(0)) + + +def test_list_index_3D_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + list_ = [[[x, x]]] + list_[0][0][0] = 100 + return list_ + + net = Net() + net(Tensor(0))