Browse Source

Enable multi-dimensional list value assignment

tags/v1.1.0
l00591931 5 years ago
parent
commit
c5b5a6719c
3 changed files with 132 additions and 7 deletions
  1. +12
    -3
      mindspore/ccsrc/pipeline/jit/parse/parse.cc
  2. +4
    -4
      tests/ut/python/ops/test_tensor_slice.py
  3. +116
    -0
      tests/ut/python/pipeline/parse/test_sequence_assign.py

+ 12
- 3
mindspore/ccsrc/pipeline/jit/parse/parse.cc View File

@@ -1501,7 +1501,7 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node});
// getitem apply should return the sequence data structure itself
std::string var_name = "";
std::string var_name;
if (ast_->IsClassMember(value_obj)) {
std::string attr_name = value_obj.attr("attr").cast<std::string>();
var_name = "self." + attr_name;
@@ -1515,9 +1515,18 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
<< py::str(obj).cast<std::string>() << "' with type '"
<< py::str(obj_type).cast<std::string>() << "'.";
}
} else {
var_name = value_obj.attr("id").cast<std::string>();
block->WriteVariable(var_name, setitem_app);
return;
}
if (AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
AST_SUB_TYPE_SUBSCRIPT) {
HandleAssignSubscript(block, value_obj, setitem_app);
return;
}
if (!py::hasattr(value_obj, "id")) {
MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>();
}
var_name = value_obj.attr("id").cast<std::string>();
block->WriteVariable(var_name, setitem_app);
}



+ 4
- 4
tests/ut/python/ops/test_tensor_slice.py View File

@@ -672,14 +672,14 @@ def test_tensor_assign_bool_index():
with pytest.raises(ValueError):
net2(Ta, u_tensor_error)
net3 = TensorAssignWithBoolTensorIndexError()
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net3(Ta, Tb, Tc, u_tensor)
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net3(Ta, Tb, Tc, Tensor(u_scalar, mstype.int32))
net4 = TensorAssignWithBoolTensorIndex2Error()
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net4(Ta, u_tensor)
with pytest.raises(AttributeError):
with pytest.raises(IndexError):
net4(Ta, Tensor(u_scalar, mstype.int32))




+ 116
- 0
tests/ut/python/pipeline/parse/test_sequence_assign.py View File

@@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test enumerate"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def test_list_index_1D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[0] = [100]
return list_

net = Net()
out = net()
assert out[0] == [100]
assert out[1] == [2, 2]
assert out[2] == [3, 3, 3]


def test_list_index_2D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[1][0] = 200
list_[1][1] = 201
return list_

net = Net()
out = net()
assert out[0] == [1]
assert out[1] == [200, 201]
assert out[2] == [3, 3, 3]


def test_list_index_3D():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self):
list_ = [[1], [2, 2], [[3, 3, 3]]]
list_[2][0][0] = 300
list_[2][0][1] = 301
list_[2][0][2] = 302
return list_

net = Net()
out = net()
assert out[0] == [1]
assert out[1] == [2, 2]
assert out[2] == [[300, 301, 302]]


def test_list_index_1D_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
list_ = [x]
list_[0] = 100
return list_

net = Net()
net(Tensor(0))


def test_list_index_2D_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
list_ = [[x, x]]
list_[0][0] = 100
return list_

net = Net()
net(Tensor(0))


def test_list_index_3D_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
list_ = [[[x, x]]]
list_[0][0][0] = 100
return list_

net = Net()
net(Tensor(0))

Loading…
Cancel
Save