Browse Source

!2204 add check for tensor slice before convert to ops

Merge pull request !2204 from zhangbuxue/add_check_for_tensor_slice_before_convert_to_ops
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b9ad407dc2
14 changed files with 60 additions and 37 deletions
  1. +20
    -11
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  2. +2
    -2
      mindspore/ops/composite/multitype_ops/div_impl.py
  3. +1
    -1
      mindspore/ops/composite/multitype_ops/floordiv_impl.py
  4. +1
    -1
      mindspore/ops/composite/multitype_ops/getitem_impl.py
  5. +1
    -1
      mindspore/ops/composite/multitype_ops/greater_equal_impl.py
  6. +1
    -1
      mindspore/ops/composite/multitype_ops/greater_impl.py
  7. +2
    -2
      mindspore/ops/composite/multitype_ops/less_equal_impl.py
  8. +19
    -5
      mindspore/ops/composite/multitype_ops/logic_not_impl.py
  9. +4
    -4
      mindspore/ops/composite/multitype_ops/logical_and_impl.py
  10. +5
    -5
      mindspore/ops/composite/multitype_ops/logical_or_impl.py
  11. +1
    -1
      mindspore/ops/composite/multitype_ops/mod_impl.py
  12. +1
    -1
      mindspore/ops/composite/multitype_ops/mul_impl.py
  13. +1
    -1
      mindspore/ops/composite/multitype_ops/sub_impl.py
  14. +1
    -1
      tests/ut/python/ops/test_tensor_slice.py

+ 20
- 11
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -156,7 +156,6 @@ def generate_updates_from_tensor(data, index, value, op_type):
return value



def tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
@@ -164,16 +163,15 @@ def tensor_getitem(self, index):
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
if isinstance(index, int):
return tensor_index_by_number(self, index)
return tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if isinstance(index, bool):
return tensor_index_by_bool(self, index)
if index is ...:
return self
raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\
got {} with type{}".format(index, type(index)))

raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
f"got {index} with type {type(index)}.")


tensor_operator_registry.register("__getitem__", tensor_getitem)
@@ -199,13 +197,19 @@ def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):

def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index)
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)


def tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number)
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by an integer, the dimension of the tensor cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)

@@ -214,7 +218,7 @@ def tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("bool value as indexing ,false is not supported")
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")


def tensor_index_by_number(data, number):
@@ -224,7 +228,7 @@ def tensor_index_by_number(data, number):
return tensor_index_by_bool(data, number)
if number_type == const_utils.INT_:
return tensor_index_by_integer(data, number)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")


def tensor_index_by_tensor(data, tensor_index):
@@ -233,13 +237,18 @@ def tensor_index_by_tensor(data, tensor_index):
const_utils.TENSOR_GETITEM)
if dtype_valid:
return F.gather(data, tensor_index, 0)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
return const_utils.raise_index_error("For 'tensor getitem', "
"the index tensor data type only support mstype.int32.")


def tensor_index_by_tuple_slice(data, t):
"""Tensor getitem by a tuple of slice"""
shape = F.shape(data)
if len(t) > len(shape):
const_utils.raise_index_error("When tensor is indexed by a tuple, "
"the length of the tuple cannot be greater than the dimension of the tensor.")
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(F.shape(data), t)
const_utils.get_stride_info_from_tuple(shape, t)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)




+ 2
- 2
mindspore/ops/composite/multitype_ops/div_impl.py View File

@@ -47,8 +47,8 @@ def _div_tensor(x, y):
Two tensors divide by element.

Args:
x (Tensor): x
y (Tensor): The dtype is same as x.
x (Tensor): The first input tensor.
y (Tensor): The second input tensor.

Returns:
Tensor, has the same dtype as x.


+ 1
- 1
mindspore/ops/composite/multitype_ops/floordiv_impl.py View File

@@ -34,7 +34,7 @@ def _floordiv_scalar(x, y):

@floordiv.register("Tensor", "Tensor")
def _floordiv_tensor(x, y):
"""Returns x // y where x and y are all tensors and have save dtype."""
"""Returns x // y where x and y are all tensors."""
return F.tensor_floordiv(x, y)




+ 1
- 1
mindspore/ops/composite/multitype_ops/getitem_impl.py View File

@@ -164,7 +164,7 @@ def _tensor_getitem_by_number(data, number_index):
@getitem.register("Tensor", "None")
def _tensor_getitem_by_none(data, index):
"""
For none indexing , expand data with one dim
For none indexing , expand data with one dim.

Inputs:
data (Tensor): A tensor.


+ 1
- 1
mindspore/ops/composite/multitype_ops/greater_equal_impl.py View File

@@ -25,7 +25,7 @@ greater_equal = base.MultitypeFuncGraph("greater_equal")
@greater_equal.register("Number", "Number")
def _greater_equal_scala(x, y):
"""
Determine whether x is greater equal than y
Determine whether x is greater equal than y.

Args:
x(Number): Number.


+ 1
- 1
mindspore/ops/composite/multitype_ops/greater_impl.py View File

@@ -48,6 +48,6 @@ def _greater_tensor(x, y):
y(Tensor): Tensor.

Returns:
tensor, return operation of x and y by P.Greater
tensor, return operation of x and y by P.Greater.
"""
return F.tensor_gt(x, y)

+ 2
- 2
mindspore/ops/composite/multitype_ops/less_equal_impl.py View File

@@ -25,7 +25,7 @@ less_equal = base.MultitypeFuncGraph("less_equal")
@less_equal.register("Number", "Number")
def _less_equal_scala(x, y):
"""
Determine whether x is less equal than y
Determine whether x is less equal than y.

Args:
x(Number): Number.
@@ -41,7 +41,7 @@ def _less_equal_scala(x, y):
@less_equal.register("Tensor", "Tensor")
def _less_equal_tensor(x, y):
"""
Determine whether tensor x is less equal than tensor y elementwise
Determine whether tensor x is less equal than tensor y elementwise.

Args:
x(Tensor): Tensor.


+ 19
- 5
mindspore/ops/composite/multitype_ops/logic_not_impl.py View File

@@ -25,13 +25,13 @@ logical_not = base.MultitypeFuncGraph("logical_not")
@logical_not.register("Number")
def _logical_not_scala(x):
"""
Return logical not operation result of x
Return logical not operation result of x.

Args:
x(Number): Number.

Returns:
bool, Return logical not operation result of x
bool, Return logical not operation result of x.
"""
return F.bool_not(x.__bool__())

@@ -39,10 +39,24 @@ def _logical_not_scala(x):
@logical_not.register("Tensor")
def _logical_not_tensor(x):
"""
Return logical not operation result of x
Return logical not operation result of x.
Args:
x(Tensor): Tensor.
Returns:
Tensor, Return logical not operation result of x
Tensor, Return logical not operation result of x.
"""
return F.logical_not(x)
return F.logical_not(x)


@logical_not.register("Tuple")
def _logical_not_tuple(x):
"""
Return logical not operation result of a tuple object.

Args:
x(Tuple): The input tuple.

Returns:
bool, Return logical not operation result of x.
"""
return F.bool_not(x.__bool__())

+ 4
- 4
mindspore/ops/composite/multitype_ops/logical_and_impl.py View File

@@ -25,14 +25,14 @@ logical_and = base.MultitypeFuncGraph("logical_and")
@logical_and.register("Number", "Number")
def _logical_and_scala(x, y):
"""
Return logical and operation result of x and y
Return logical and operation result of x and y.

Args:
x(Number): Number.
y(Number): Number.

Returns:
bool, Return logical and operation result of x and y
bool, Return logical and operation result of x and y.
"""
return F.bool_and(x.__bool__(), y.__bool__())

@@ -40,13 +40,13 @@ def _logical_and_scala(x, y):
@logical_and.register("Tensor", "Tensor")
def _logical_and_tensor(x, y):
"""
Return logical and operation result of x and y
Return logical and operation result of x and y.

Args:
x(Tensor): Tensor.
y(Tensor): Tensor.

Returns:
Tensor, Return logical and operation result of x and y
Tensor, Return logical and operation result of x and y.
"""
return F.logical_and(x, y)

+ 5
- 5
mindspore/ops/composite/multitype_ops/logical_or_impl.py View File

@@ -25,14 +25,14 @@ logical_or = base.MultitypeFuncGraph("logical_or")
@logical_or.register("Number", "Number")
def _logical_or_scala(x, y):
"""
Return logical or operation result of x and y
Return logical or operation result of x and y.

Args:
x(Number): Number.
y(Number): Number.

Returns:
bool, Return logical or operation result of x and y
bool, Return logical or operation result of x and y.
"""
return F.bool_or(x.__bool__(), y.__bool__())

@@ -40,13 +40,13 @@ def _logical_or_scala(x, y):
@logical_or.register("Tensor", "Tensor")
def _logical_or_tensor(x, y):
"""
Return logical operation or result of x and y
Return logical operation or result of x and y.

Args:
x(Tensor): Tensor.
y(Tensor): Tensor.

Returns:
Tensor, Return logical operation or result of x and y
Tensor, Return logical operation or result of x and y.
"""
return F.logical_or(x, y)
return F.logical_or(x, y)

+ 1
- 1
mindspore/ops/composite/multitype_ops/mod_impl.py View File

@@ -34,7 +34,7 @@ def _mod_scalar(x, y):

@mod.register("Tensor", "Tensor")
def _mod_tensor(x, y):
"""Returns x % y where x and y are all tensors and have save dtype."""
"""Returns x % y where x and y are all tensors."""
return F.tensor_mod(x, y)




+ 1
- 1
mindspore/ops/composite/multitype_ops/mul_impl.py View File

@@ -40,7 +40,7 @@ def _mul_scalar(x, y):
@mul.register("Tensor", "Tensor")
def _mul_tensor(x, y):
"""
Returns x * y by element-wise where x and y are all tensors and have same dtype.
Returns x * y by element-wise where x and y are all tensors.

Outputs:
Tensor, has the same dtype as x.


+ 1
- 1
mindspore/ops/composite/multitype_ops/sub_impl.py View File

@@ -34,7 +34,7 @@ def _sub_scalar(x, y):

@sub.register("Tensor", "Tensor")
def _sub_tensor(x, y):
"""Returns x - y where x and y are all tensors and have save dtype."""
"""Returns x - y where x and y are all tensors."""
return F.tensor_sub(x, y)




+ 1
- 1
tests/ut/python/ops/test_tensor_slice.py View File

@@ -1139,7 +1139,7 @@ raise_error_set = [

@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_exec():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
return test_cases




Loading…
Cancel
Save