From 61857b69ed9ad696b2e131255510fb1d482d887f Mon Sep 17 00:00:00 2001 From: buxue Date: Wed, 17 Jun 2020 11:01:53 +0800 Subject: [PATCH] add check for tensor slice before convert to ops --- .../composite/multitype_ops/_compile_utils.py | 31 ++++++++++++------- .../ops/composite/multitype_ops/div_impl.py | 4 +-- .../composite/multitype_ops/floordiv_impl.py | 2 +- .../composite/multitype_ops/getitem_impl.py | 2 +- .../multitype_ops/greater_equal_impl.py | 2 +- .../composite/multitype_ops/greater_impl.py | 2 +- .../multitype_ops/less_equal_impl.py | 4 +-- .../composite/multitype_ops/logic_not_impl.py | 24 +++++++++++--- .../multitype_ops/logical_and_impl.py | 8 ++--- .../multitype_ops/logical_or_impl.py | 10 +++--- .../ops/composite/multitype_ops/mod_impl.py | 2 +- .../ops/composite/multitype_ops/mul_impl.py | 2 +- .../ops/composite/multitype_ops/sub_impl.py | 2 +- tests/ut/python/ops/test_tensor_slice.py | 2 +- 14 files changed, 60 insertions(+), 37 deletions(-) diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index f99fec64d0..826cb9500c 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -156,7 +156,6 @@ def generate_updates_from_tensor(data, index, value, op_type): return value - def tensor_getitem(self, index): """Handle tensor getitem""" if isinstance(index, Tensor): @@ -164,16 +163,15 @@ def tensor_getitem(self, index): if isinstance(index, tuple): return tensor_index_by_tuple(self, index) if isinstance(index, int): - return tensor_index_by_number(self, index) + return tensor_index_by_integer(self, index) if isinstance(index, slice): return tensor_index_by_slice(self, index) if isinstance(index, bool): return tensor_index_by_bool(self, index) if index is ...: return self - raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\ - got {} with type{}".format(index, type(index))) - + raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, " + f"got {index} with type {type(index)}.") tensor_operator_registry.register("__getitem__", tensor_getitem) @@ -199,13 +197,19 @@ def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): def tensor_index_by_slice(data, slice_index): """Tensor getitem by a single slice""" - begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index) + shape = F.shape(data) + if not shape: + const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor cannot be 0.") + begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index) return F.strided_slice(data, begin_strides, end_strides, step_strides) def tensor_index_by_integer(data, number): """Tensor getitem by a single integer number""" - begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number) + shape = F.shape(data) + if not shape: + const_utils.raise_index_error("When tensor is indexed by an integer, the dimension of the tensor cannot be 0.") + begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number) shrink_axis_mask = 1 return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) @@ -214,7 +218,7 @@ def tensor_index_by_bool(data, bool_value): """Tensor getitem by a single bool value""" if bool_value: return F.expand_dims(data, 0) - return const_utils.raise_index_error("bool value as indexing ,false is not supported") + return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.") def tensor_index_by_number(data, number): @@ -224,7 +228,7 @@ def tensor_index_by_number(data, number): return tensor_index_by_bool(data, number) if number_type == const_utils.INT_: return tensor_index_by_integer(data, number) - return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") + return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") def tensor_index_by_tensor(data, tensor_index): @@ -233,13 +237,18 @@ def tensor_index_by_tensor(data, tensor_index): const_utils.TENSOR_GETITEM) if dtype_valid: return F.gather(data, tensor_index, 0) - return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool") + return const_utils.raise_index_error("For 'tensor getitem', " + "the index tensor data type only support mstype.int32.") def tensor_index_by_tuple_slice(data, t): """Tensor getitem by a tuple of slice""" + shape = F.shape(data) + if len(t) > len(shape): + const_utils.raise_index_error("When tensor is indexed by a tuple, " + "the length of the tuple cannot be greater than the dimension of the tensor.") begin_strides, end_strides, step_strides, shrink_axis_mask = \ - const_utils.get_stride_info_from_tuple(F.shape(data), t) + const_utils.get_stride_info_from_tuple(shape, t) return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) diff --git a/mindspore/ops/composite/multitype_ops/div_impl.py b/mindspore/ops/composite/multitype_ops/div_impl.py index c37fcb9c36..85a4e035c0 100644 --- a/mindspore/ops/composite/multitype_ops/div_impl.py +++ b/mindspore/ops/composite/multitype_ops/div_impl.py @@ -47,8 +47,8 @@ def _div_tensor(x, y): Two tensors divide by element. Args: - x (Tensor): x - y (Tensor): The dtype is same as x. + x (Tensor): The first input tensor. + y (Tensor): The second input tensor. Returns: Tensor, has the same dtype as x. diff --git a/mindspore/ops/composite/multitype_ops/floordiv_impl.py b/mindspore/ops/composite/multitype_ops/floordiv_impl.py index c1a47f881f..8e9e941309 100644 --- a/mindspore/ops/composite/multitype_ops/floordiv_impl.py +++ b/mindspore/ops/composite/multitype_ops/floordiv_impl.py @@ -34,7 +34,7 @@ def _floordiv_scalar(x, y): @floordiv.register("Tensor", "Tensor") def _floordiv_tensor(x, y): - """Returns x // y where x and y are all tensors and have save dtype.""" + """Returns x // y where x and y are all tensors.""" return F.tensor_floordiv(x, y) diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 7bc9fd2805..ffd5ea4d62 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -164,7 +164,7 @@ def _tensor_getitem_by_number(data, number_index): @getitem.register("Tensor", "None") def _tensor_getitem_by_none(data, index): """ - For none indexing , expand data with one dim + For none indexing , expand data with one dim. Inputs: data (Tensor): A tensor. diff --git a/mindspore/ops/composite/multitype_ops/greater_equal_impl.py b/mindspore/ops/composite/multitype_ops/greater_equal_impl.py index 2073abb762..93f1acbc54 100644 --- a/mindspore/ops/composite/multitype_ops/greater_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/greater_equal_impl.py @@ -25,7 +25,7 @@ greater_equal = base.MultitypeFuncGraph("greater_equal") @greater_equal.register("Number", "Number") def _greater_equal_scala(x, y): """ - Determine whether x is greater equal than y + Determine whether x is greater equal than y. Args: x(Number): Number. diff --git a/mindspore/ops/composite/multitype_ops/greater_impl.py b/mindspore/ops/composite/multitype_ops/greater_impl.py index ec65704f62..2f3a2dbb83 100644 --- a/mindspore/ops/composite/multitype_ops/greater_impl.py +++ b/mindspore/ops/composite/multitype_ops/greater_impl.py @@ -48,6 +48,6 @@ def _greater_tensor(x, y): y(Tensor): Tensor. Returns: - tensor, return operation of x and y by P.Greater + tensor, return operation of x and y by P.Greater. """ return F.tensor_gt(x, y) diff --git a/mindspore/ops/composite/multitype_ops/less_equal_impl.py b/mindspore/ops/composite/multitype_ops/less_equal_impl.py index dc1438da2c..5927c4b349 100644 --- a/mindspore/ops/composite/multitype_ops/less_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/less_equal_impl.py @@ -25,7 +25,7 @@ less_equal = base.MultitypeFuncGraph("less_equal") @less_equal.register("Number", "Number") def _less_equal_scala(x, y): """ - Determine whether x is less equal than y + Determine whether x is less equal than y. Args: x(Number): Number. @@ -41,7 +41,7 @@ def _less_equal_scala(x, y): @less_equal.register("Tensor", "Tensor") def _less_equal_tensor(x, y): """ - Determine whether tensor x is less equal than tensor y elementwise + Determine whether tensor x is less equal than tensor y elementwise. Args: x(Tensor): Tensor. diff --git a/mindspore/ops/composite/multitype_ops/logic_not_impl.py b/mindspore/ops/composite/multitype_ops/logic_not_impl.py index 35ae766433..6705145a64 100644 --- a/mindspore/ops/composite/multitype_ops/logic_not_impl.py +++ b/mindspore/ops/composite/multitype_ops/logic_not_impl.py @@ -25,13 +25,13 @@ logical_not = base.MultitypeFuncGraph("logical_not") @logical_not.register("Number") def _logical_not_scala(x): """ - Return logical not operation result of x + Return logical not operation result of x. Args: x(Number): Number. Returns: - bool, Return logical not operation result of x + bool, Return logical not operation result of x. """ return F.bool_not(x.__bool__()) @@ -39,10 +39,24 @@ def _logical_not_scala(x): @logical_not.register("Tensor") def _logical_not_tensor(x): """ - Return logical not operation result of x + Return logical not operation result of x. Args: x(Tensor): Tensor. Returns: - Tensor, Return logical not operation result of x + Tensor, Return logical not operation result of x. """ - return F.logical_not(x) + return F.logical_not(x) + + +@logical_not.register("Tuple") +def _logical_not_tuple(x): + """ + Return logical not operation result of a tuple object. + + Args: + x(Tuple): The input tuple. + + Returns: + bool, Return logical not operation result of x. + """ + return F.bool_not(x.__bool__()) diff --git a/mindspore/ops/composite/multitype_ops/logical_and_impl.py b/mindspore/ops/composite/multitype_ops/logical_and_impl.py index 324ce3a78d..79001f43e8 100644 --- a/mindspore/ops/composite/multitype_ops/logical_and_impl.py +++ b/mindspore/ops/composite/multitype_ops/logical_and_impl.py @@ -25,14 +25,14 @@ logical_and = base.MultitypeFuncGraph("logical_and") @logical_and.register("Number", "Number") def _logical_and_scala(x, y): """ - Return logical and operation result of x and y + Return logical and operation result of x and y. Args: x(Number): Number. y(Number): Number. Returns: - bool, Return logical and operation result of x and y + bool, Return logical and operation result of x and y. """ return F.bool_and(x.__bool__(), y.__bool__()) @@ -40,13 +40,13 @@ def _logical_and_scala(x, y): @logical_and.register("Tensor", "Tensor") def _logical_and_tensor(x, y): """ - Return logical and operation result of x and y + Return logical and operation result of x and y. Args: x(Tensor): Tensor. y(Tensor): Tensor. Returns: - Tensor, Return logical and operation result of x and y + Tensor, Return logical and operation result of x and y. """ return F.logical_and(x, y) diff --git a/mindspore/ops/composite/multitype_ops/logical_or_impl.py b/mindspore/ops/composite/multitype_ops/logical_or_impl.py index fd106f7685..6d070d5cbf 100644 --- a/mindspore/ops/composite/multitype_ops/logical_or_impl.py +++ b/mindspore/ops/composite/multitype_ops/logical_or_impl.py @@ -25,14 +25,14 @@ logical_or = base.MultitypeFuncGraph("logical_or") @logical_or.register("Number", "Number") def _logical_or_scala(x, y): """ - Return logical or operation result of x and y + Return logical or operation result of x and y. Args: x(Number): Number. y(Number): Number. Returns: - bool, Return logical or operation result of x and y + bool, Return logical or operation result of x and y. """ return F.bool_or(x.__bool__(), y.__bool__()) @@ -40,13 +40,13 @@ def _logical_or_scala(x, y): @logical_or.register("Tensor", "Tensor") def _logical_or_tensor(x, y): """ - Return logical operation or result of x and y + Return logical operation or result of x and y. Args: x(Tensor): Tensor. y(Tensor): Tensor. Returns: - Tensor, Return logical operation or result of x and y + Tensor, Return logical operation or result of x and y. """ - return F.logical_or(x, y) + return F.logical_or(x, y) diff --git a/mindspore/ops/composite/multitype_ops/mod_impl.py b/mindspore/ops/composite/multitype_ops/mod_impl.py index e9947677ac..4b6a13bbc8 100644 --- a/mindspore/ops/composite/multitype_ops/mod_impl.py +++ b/mindspore/ops/composite/multitype_ops/mod_impl.py @@ -34,7 +34,7 @@ def _mod_scalar(x, y): @mod.register("Tensor", "Tensor") def _mod_tensor(x, y): - """Returns x % y where x and y are all tensors and have save dtype.""" + """Returns x % y where x and y are all tensors.""" return F.tensor_mod(x, y) diff --git a/mindspore/ops/composite/multitype_ops/mul_impl.py b/mindspore/ops/composite/multitype_ops/mul_impl.py index ce9ec391af..b5535df135 100644 --- a/mindspore/ops/composite/multitype_ops/mul_impl.py +++ b/mindspore/ops/composite/multitype_ops/mul_impl.py @@ -40,7 +40,7 @@ def _mul_scalar(x, y): @mul.register("Tensor", "Tensor") def _mul_tensor(x, y): """ - Returns x * y by element-wise where x and y are all tensors and have same dtype. + Returns x * y by element-wise where x and y are all tensors. Outputs: Tensor, has the same dtype as x. diff --git a/mindspore/ops/composite/multitype_ops/sub_impl.py b/mindspore/ops/composite/multitype_ops/sub_impl.py index 431a58b991..864b8678d4 100644 --- a/mindspore/ops/composite/multitype_ops/sub_impl.py +++ b/mindspore/ops/composite/multitype_ops/sub_impl.py @@ -34,7 +34,7 @@ def _sub_scalar(x, y): @sub.register("Tensor", "Tensor") def _sub_tensor(x, y): - """Returns x - y where x and y are all tensors and have save dtype.""" + """Returns x - y where x and y are all tensors.""" return F.tensor_sub(x, y) diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 55159f19ae..e6b078fd5c 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -1139,7 +1139,7 @@ raise_error_set = [ @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config) def test_exec(): - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) return test_cases