From 3189868a155c7280c93af6392da84cf856648cd9 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Mon, 7 Dec 2020 15:54:13 +0800 Subject: [PATCH] Assignment enable index smaller than 0 --- mindspore/ccsrc/frontend/optimizer/clean.cc | 2 +- mindspore/core/abstract/prim_structures.cc | 14 ++--- mindspore/core/ir/pattern_matcher.h | 3 ++ .../pipeline/parse/test_sequence_assign.py | 54 +++++++++++++++++++ 4 files changed, 63 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc index 60f6314c50..7d825410e1 100644 --- a/mindspore/ccsrc/frontend/optimizer/clean.cc +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -330,7 +330,7 @@ AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) { const int64_t DEPTH_MAX = 5; if (depth > DEPTH_MAX) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; + MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels."; } std::vector elements; for (const auto &it : value_list->value()) { diff --git a/mindspore/core/abstract/prim_structures.cc b/mindspore/core/abstract/prim_structures.cc index 7e6633a73f..e00ad2a07f 100644 --- a/mindspore/core/abstract/prim_structures.cc +++ b/mindspore/core/abstract/prim_structures.cc @@ -163,18 +163,14 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra << index_value->ToString(); } int64_t idx_v = GetValue(index_value); - if (idx_v < 0) { - MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v - << "."; - } - - size_t uidx_v = LongToSize(idx_v); AbstractBasePtrList elements = queue->elements(); std::size_t nelems = elements.size(); - if (uidx_v >= nelems) { - MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 - << "."; + int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems); + if (idx_t < 0 || idx_t >= SizeToLong(nelems)) { + MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems + << "," << nelems - 1 << "]."; } + size_t uidx_v = LongToSize(idx_t); elements[uidx_v] = args_spec_list[2]; return std::make_shared(elements); } diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 71cb4a1d91..f14b4f98ec 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -755,6 +755,9 @@ class PConstant : public PBase > { ShapeVector tensor_shape = tensor_abstract->shape()->shape(); auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + if (new_tensor_ptr->DataSize() < tensor_ptr->DataSize()) { + MS_EXCEPTION(ValueError) << "DataSize of new_tensor_ptr is smaller than DataSize of tensor_ptr"; + } if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) || (tensor_type == TypeId::kNumberTypeFloat64)) { float *data = reinterpret_cast(tensor_ptr->data_c()); diff --git a/tests/ut/python/pipeline/parse/test_sequence_assign.py b/tests/ut/python/pipeline/parse/test_sequence_assign.py index 5474dff92a..29b7bc11a5 100644 --- a/tests/ut/python/pipeline/parse/test_sequence_assign.py +++ b/tests/ut/python/pipeline/parse/test_sequence_assign.py @@ -37,6 +37,23 @@ def test_list_index_1D(): assert out[2] == [3, 3, 3] +def test_list_neg_index_1D(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + list_ = [[1], [2, 2], [3, 3, 3]] + list_[-3] = [100] + return list_ + + net = Net() + out = net() + assert out[0] == [100] + assert out[1] == [2, 2] + assert out[2] == [3, 3, 3] + + def test_list_index_2D(): class Net(nn.Cell): def __init__(self): @@ -55,6 +72,24 @@ def test_list_index_2D(): assert out[2] == [3, 3, 3] +def test_list_neg_index_2D(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + list_ = [[1], [2, 2], [3, 3, 3]] + list_[1][-2] = 200 + list_[1][-1] = 201 + return list_ + + net = Net() + out = net() + assert out[0] == [1] + assert out[1] == [200, 201] + assert out[2] == [3, 3, 3] + + def test_list_index_3D(): class Net(nn.Cell): def __init__(self): @@ -74,6 +109,25 @@ def test_list_index_3D(): assert out[2] == [[300, 301, 302]] +def test_list_neg_index_3D(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self): + list_ = [[1], [2, 2], [[3, 3, 3]]] + list_[2][0][-3] = 300 + list_[2][0][-2] = 301 + list_[2][0][-1] = 302 + return list_ + + net = Net() + out = net() + assert out[0] == [1] + assert out[1] == [2, 2] + assert out[2] == [[300, 301, 302]] + + def test_list_index_1D_parameter(): class Net(nn.Cell): def __init__(self):