From ee45c047755eb202103f78bbc1cec1c0adcc56bd Mon Sep 17 00:00:00 2001 From: l00591931 Date: Thu, 26 Nov 2020 19:29:58 +0800 Subject: [PATCH] Add pynative view and expand_as --- mindspore/_checkparam.py | 2 +- mindspore/common/tensor.py | 33 +++++++++++++++++++++++++++ mindspore/nn/layer/basic.py | 8 +++---- mindspore/ops/functional.py | 2 ++ mindspore/ops/operations/array_ops.py | 8 +++---- 5 files changed, 44 insertions(+), 9 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 462615653a..ed71c859ff 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -125,7 +125,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): - number = check_is_number(number, int, "bias", "bias_class") """ prim_name = f'in \'{prim_name}\'' if prim_name else '' - arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' + arg_name = f'\'{arg_name}\'' if arg_name else 'Input value' if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool): if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index ffc84f04f2..fdec3169f4 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -292,6 +292,39 @@ class Tensor(Tensor_): return tensor_operator_registry.get('any')(keep_dims)(self, axis) + def view(self, *shape): + """ + Reshape the tensor according to the input shape. + + Args: + shape (Union(list(int), *int)): Dimension of the output tensor. + + Returns: + Tensor, has the same dimension as the input shape. + """ + if not shape: + raise ValueError("The shape variable should not be empty") + if isinstance(shape[0], tuple): + if len(shape) != 1: + raise ValueError(f"Only one tuple is needed, but got {shape}") + shape = shape[0] + return tensor_operator_registry.get('reshape')()(self, shape) + + + def expand_as(self, x): + """ + Expand the dimension of target tensor to the dimension of input tensor. + + Args: + shape (Tensor): The input tensor. The shape of input tensor must obey + the broadcasting rule. + + Returns: + Tensor, has the same dimension as input tensor. + """ + return tensor_operator_registry.get('broadcast_to')(x.shape)(self) + + class RowTensor: """ A sparse representation of a set of tensor slices at given indices. diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 862de34688..8df3b18501 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -599,14 +599,14 @@ class Interpolate(Cell): Inputs: - **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape: - math:'(batch, channels, height, width)', with data type of float32 or float64. + math:`(batch, channels, height, width)`, with data type of float16 or float32. Outputs: Resized tensor. - If size is set, the result is 4-D tensor with shape:math:'(batch, channels, new_height, new_width)' + If size is set, the result is 4-D tensor with shape:math:`(batch, channels, new_height, new_width)` in float32. - If scale is set, the result is 4-D tensor with shape:math:'(batch, channels, scale_factor * height, - scale_factor * width)' in float32 + If scale is set, the result is 4-D tensor with shape:math:`(batch, channels, scale_factor * height, + scale_factor * width)` in float32 Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 1a5f7a08f2..9e2451aac3 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -173,6 +173,8 @@ tensor_operator_registry.register('__pow__', tensor_pow) tensor_operator_registry.register('__floordiv__', tensor_floordiv) tensor_operator_registry.register('all', P.ReduceAll) tensor_operator_registry.register('any', P.ReduceAny) +tensor_operator_registry.register('reshape', P.Reshape) +tensor_operator_registry.register('broadcast_to', P.BroadcastTo) # ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__ne__', not_equal) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 9d6f778380..f86f6bc8d5 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1107,7 +1107,7 @@ class Fill(PrimitiveWithInfer): for i, item in enumerate(dims['value']): validator.check_positive_int(item, f'dims[{i}]', self.name) valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint32, mstype.uint64, + mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name) x_nptype = mstype.dtype_to_nptype(dtype['value']) @@ -1160,7 +1160,7 @@ class Ones(PrimitiveWithInfer): for i, item in enumerate(shape): validator.check_non_negative_int(item, shape[i], self.name) valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint32, mstype.uint64, + mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) x_nptype = mstype.dtype_to_nptype(dtype['value']) @@ -1214,7 +1214,7 @@ class Zeros(PrimitiveWithInfer): for i, item in enumerate(shape): validator.check_non_negative_int(item, shape[i], self.name) valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint32, mstype.uint64, + mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) x_nptype = mstype.dtype_to_nptype(dtype['value']) @@ -1269,7 +1269,7 @@ class SequenceMask(PrimitiveWithInfer): def __infer__(self, lengths, dtype, max_length=None): validator.check_value_type("shape", lengths['value'], [tuple, list], self.name) valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint32, mstype.uint64, + mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_subclass("dtype", dtype['value'], valid_types, self.name) nptype = mstype.dtype_to_nptype(dtype['value'])