Browse Source

!13342 Throw exception when tensor with 0 shape is constructed

From: @liangzhibo
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
b61aa9b4cf
8 changed files with 63 additions and 12 deletions
  1. +35
    -8
      mindspore/core/ir/tensor.cc
  2. +2
    -0
      mindspore/core/ir/tensor.h
  3. +1
    -0
      tests/st/ops/cpu/test_dropout_grad_op.py
  4. +9
    -0
      tests/st/ops/gpu/test_add_op.py
  5. +1
    -1
      tests/st/pynative/test_tensor_index.py
  6. +11
    -0
      tests/ut/python/ir/test_tensor.py
  7. +2
    -1
      tests/ut/python/ops/test_ops.py
  8. +2
    -2
      tests/ut/python/ops/test_tensor_slice.py

+ 35
- 8
mindspore/core/ir/tensor.cc View File

@@ -465,7 +465,9 @@ Tensor::Tensor(const Tensor &tensor)
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}
device_event_(tensor.device_event_) {
CheckShape(tensor.shape_);
}

Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),
@@ -479,29 +481,43 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}
device_event_(tensor.device_event_) {
CheckShape(tensor.shape_);
}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {
CheckShape(shape);
}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape)
: Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}
: Tensor(data_type, shape, MakeTensorData(data_type, shape)) {
CheckShape(shape);
}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {
CheckShape(shape);
}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}
: Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {
CheckShape(shape);
}

Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
id_(MakeId()) {}
id_(MakeId()) {
CheckShape(shape_);
}

Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
: MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
data_(MakeTensorData(data_type_, shape_, input.data(), input.size())),
id_(MakeId()) {}
id_(MakeId()) {
CheckShape(shape_);
}

Tensor::Tensor(int64_t input, const TypePtr &data_type)
: MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
@@ -606,6 +622,17 @@ std::string Tensor::ToStringRepr() const {
return buf.str();
}

void Tensor::CheckShape(const ShapeVector &shape) const {
// Check tensor's shape, ignore one-dimensional tensor, including empty tensor with shape=(0,).
if (shape.size() > 1) {
for (const auto &s : shape) {
if (s == 0) {
MS_EXCEPTION(ValueError) << "Zero is not supported in the shape of Tensor. ";
}
}
}
}

void Tensor::data_sync(bool need_wait) const {
if (need_wait) {
Wait();


+ 2
- 0
mindspore/core/ir/tensor.h View File

@@ -280,6 +280,8 @@ class Tensor : public MetaTensor {

std::string ToStringRepr() const;

void CheckShape(const ShapeVector &shape) const;

bool is_init() const { return init_flag_; }
void set_init_flag(bool flag) { init_flag_ = flag; }



+ 1
- 0
tests/st/ops/cpu/test_dropout_grad_op.py View File

@@ -107,6 +107,7 @@ def test_dropout_grad_004():
assert np.all(np.abs(diff) < error)


@pytest.mark.skip(reason='0 in shape is not support')
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard


+ 9
- 0
tests/st/ops/gpu/test_add_op.py View File

@@ -127,30 +127,39 @@ def add(nptype):
assert (output[2].asnumpy() == expect2).all()
assert (output[3].asnumpy() == expect3).all()


@pytest.mark.skip(reason='0 in shape is not support')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_add_float64():
add(np.float64)


@pytest.mark.skip(reason='0 in shape is not support')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_add_float32():
add(np.float32)


@pytest.mark.skip(reason='0 in shape is not support')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_add_float16():
add(np.float16)


@pytest.mark.skip(reason='0 in shape is not support')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_add_int64():
add(np.int64)

@pytest.mark.skip(reason='0 in shape is not support')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard


+ 1
- 1
tests/st/pynative/test_tensor_index.py View File

@@ -787,7 +787,7 @@ def test_tensor_assign_exception():
tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with pytest.raises(IndexError):
with pytest.raises(ValueError):
net_e2(t, 2)

# Error for A[Slice] = U, U is a Tensor


+ 11
- 0
tests/ut/python/ir/test_tensor.py View File

@@ -67,6 +67,17 @@ def test_tensor():
assert isinstance(t4, ms.Tensor)
assert t4.dtype == ms.int64

def test_tensor_empty():
t = ms.Tensor(np.ones(0), ms.float32)
assert isinstance(t, ms.Tensor)
assert t.shape == (0,)


def test_tensor_shape_has_zero():
with pytest.raises(ValueError):
t = ms.Tensor(np.ones((1, 0)), ms.float32)
print(t)


def test_tensor_type_float16():
t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16))


+ 2
- 1
tests/ut/python/ops/test_ops.py View File

@@ -14,6 +14,7 @@
# ============================================================================
""" test ops """
import functools
import pytest

import numpy as np

@@ -890,7 +891,7 @@ class StridedSliceNet(nn.Cell):
out_3 = self.strided_slice_3(x, self.begins, self.ends, self.strides) + self.const_3
return out_0, out_1, out_2, out_3

@pytest.mark.skip(reason='0 in shape is not support')
def test_strided_slice_const():
class StridedSLiceConstNet(nn.Cell):
"""StridedSLiceConstNet net definition"""


+ 2
- 2
tests/ut/python/ops/test_tensor_slice.py View File

@@ -464,8 +464,8 @@ def test_tensor_assign():
net(Ta, b, Tck)
net2(t, b, tck)
# Error for A[Slice] = Number
# 1. A[Slice] = Number, Slice error
with pytest.raises(IndexError):
# 1. A[Slice] = Number, 0 in shape
with pytest.raises(ValueError):
net_e2(t, Tensor(2, mstype.int32))

# Error for A[Slice] = U, U is a Tensor


Loading…
Cancel
Save