From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qhtags/v1.1.0
| @@ -330,7 +330,7 @@ AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { | |||||
| ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) { | ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) { | ||||
| const int64_t DEPTH_MAX = 5; | const int64_t DEPTH_MAX = 5; | ||||
| if (depth > DEPTH_MAX) { | if (depth > DEPTH_MAX) { | ||||
| MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; | |||||
| MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels."; | |||||
| } | } | ||||
| std::vector<ValuePtr> elements; | std::vector<ValuePtr> elements; | ||||
| for (const auto &it : value_list->value()) { | for (const auto &it : value_list->value()) { | ||||
| @@ -163,18 +163,14 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra | |||||
| << index_value->ToString(); | << index_value->ToString(); | ||||
| } | } | ||||
| int64_t idx_v = GetValue<int64_t>(index_value); | int64_t idx_v = GetValue<int64_t>(index_value); | ||||
| if (idx_v < 0) { | |||||
| MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v | |||||
| << "."; | |||||
| } | |||||
| size_t uidx_v = LongToSize(idx_v); | |||||
| AbstractBasePtrList elements = queue->elements(); | AbstractBasePtrList elements = queue->elements(); | ||||
| std::size_t nelems = elements.size(); | std::size_t nelems = elements.size(); | ||||
| if (uidx_v >= nelems) { | |||||
| MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 | |||||
| << "."; | |||||
| int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems); | |||||
| if (idx_t < 0 || idx_t >= SizeToLong(nelems)) { | |||||
| MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << idx_v << " to set out of range: [-" << nelems | |||||
| << "," << nelems - 1 << "]."; | |||||
| } | } | ||||
| size_t uidx_v = LongToSize(idx_t); | |||||
| elements[uidx_v] = args_spec_list[2]; | elements[uidx_v] = args_spec_list[2]; | ||||
| return std::make_shared<T>(elements); | return std::make_shared<T>(elements); | ||||
| } | } | ||||
| @@ -755,6 +755,9 @@ class PConstant : public PBase<PConstant<T> > { | |||||
| ShapeVector tensor_shape = tensor_abstract->shape()->shape(); | ShapeVector tensor_shape = tensor_abstract->shape()->shape(); | ||||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | ||||
| size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | ||||
| if (new_tensor_ptr->DataSize() < tensor_ptr->DataSize()) { | |||||
| MS_EXCEPTION(ValueError) << "DataSize of new_tensor_ptr is smaller than DataSize of tensor_ptr"; | |||||
| } | |||||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) || | if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat) || | ||||
| (tensor_type == TypeId::kNumberTypeFloat64)) { | (tensor_type == TypeId::kNumberTypeFloat64)) { | ||||
| float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | ||||
| @@ -37,6 +37,23 @@ def test_list_index_1D(): | |||||
| assert out[2] == [3, 3, 3] | assert out[2] == [3, 3, 3] | ||||
| def test_list_neg_index_1D(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| def construct(self): | |||||
| list_ = [[1], [2, 2], [3, 3, 3]] | |||||
| list_[-3] = [100] | |||||
| return list_ | |||||
| net = Net() | |||||
| out = net() | |||||
| assert out[0] == [100] | |||||
| assert out[1] == [2, 2] | |||||
| assert out[2] == [3, 3, 3] | |||||
| def test_list_index_2D(): | def test_list_index_2D(): | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -55,6 +72,24 @@ def test_list_index_2D(): | |||||
| assert out[2] == [3, 3, 3] | assert out[2] == [3, 3, 3] | ||||
| def test_list_neg_index_2D(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| def construct(self): | |||||
| list_ = [[1], [2, 2], [3, 3, 3]] | |||||
| list_[1][-2] = 200 | |||||
| list_[1][-1] = 201 | |||||
| return list_ | |||||
| net = Net() | |||||
| out = net() | |||||
| assert out[0] == [1] | |||||
| assert out[1] == [200, 201] | |||||
| assert out[2] == [3, 3, 3] | |||||
| def test_list_index_3D(): | def test_list_index_3D(): | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -74,6 +109,25 @@ def test_list_index_3D(): | |||||
| assert out[2] == [[300, 301, 302]] | assert out[2] == [[300, 301, 302]] | ||||
| def test_list_neg_index_3D(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| def construct(self): | |||||
| list_ = [[1], [2, 2], [[3, 3, 3]]] | |||||
| list_[2][0][-3] = 300 | |||||
| list_[2][0][-2] = 301 | |||||
| list_[2][0][-1] = 302 | |||||
| return list_ | |||||
| net = Net() | |||||
| out = net() | |||||
| assert out[0] == [1] | |||||
| assert out[1] == [2, 2] | |||||
| assert out[2] == [[300, 301, 302]] | |||||
| def test_list_index_1D_parameter(): | def test_list_index_1D_parameter(): | ||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||