Browse Source

Add pynative view and expand_as

tags/v1.1.0
l00591931 5 years ago
parent
commit
ee45c04775
5 changed files with 44 additions and 9 deletions
  1. +1
    -1
      mindspore/_checkparam.py
  2. +33
    -0
      mindspore/common/tensor.py
  3. +4
    -4
      mindspore/nn/layer/basic.py
  4. +2
    -0
      mindspore/ops/functional.py
  5. +4
    -4
      mindspore/ops/operations/array_ops.py

+ 1
- 1
mindspore/_checkparam.py View File

@@ -125,7 +125,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
- number = check_is_number(number, int, "bias", "bias_class")
"""
prim_name = f'in \'{prim_name}\'' if prim_name else ''
arg_name = f'\'{prim_name}\'' if arg_name else 'Input value'
arg_name = f'\'{arg_name}\'' if arg_name else 'Input value'
if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool):
if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.')


+ 33
- 0
mindspore/common/tensor.py View File

@@ -292,6 +292,39 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('any')(keep_dims)(self, axis)


def view(self, *shape):
"""
Reshape the tensor according to the input shape.

Args:
shape (Union(list(int), *int)): Dimension of the output tensor.

Returns:
Tensor, has the same dimension as the input shape.
"""
if not shape:
raise ValueError("The shape variable should not be empty")
if isinstance(shape[0], tuple):
if len(shape) != 1:
raise ValueError(f"Only one tuple is needed, but got {shape}")
shape = shape[0]
return tensor_operator_registry.get('reshape')()(self, shape)


def expand_as(self, x):
"""
Expand the dimension of target tensor to the dimension of input tensor.

Args:
shape (Tensor): The input tensor. The shape of input tensor must obey
the broadcasting rule.

Returns:
Tensor, has the same dimension as input tensor.
"""
return tensor_operator_registry.get('broadcast_to')(x.shape)(self)


class RowTensor:
"""
A sparse representation of a set of tensor slices at given indices.


+ 4
- 4
mindspore/nn/layer/basic.py View File

@@ -599,14 +599,14 @@ class Interpolate(Cell):

Inputs:
- **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape:
math:'(batch, channels, height, width)', with data type of float32 or float64.
math:`(batch, channels, height, width)`, with data type of float16 or float32.

Outputs:
Resized tensor.
If size is set, the result is 4-D tensor with shape:math:'(batch, channels, new_height, new_width)'
If size is set, the result is 4-D tensor with shape:math:`(batch, channels, new_height, new_width)`
in float32.
If scale is set, the result is 4-D tensor with shape:math:'(batch, channels, scale_factor * height,
scale_factor * width)' in float32
If scale is set, the result is 4-D tensor with shape:math:`(batch, channels, scale_factor * height,
scale_factor * width)` in float32

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``


+ 2
- 0
mindspore/ops/functional.py View File

@@ -173,6 +173,8 @@ tensor_operator_registry.register('__pow__', tensor_pow)
tensor_operator_registry.register('__floordiv__', tensor_floordiv)
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)


+ 4
- 4
mindspore/ops/operations/array_ops.py View File

@@ -1107,7 +1107,7 @@ class Fill(PrimitiveWithInfer):
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
@@ -1160,7 +1160,7 @@ class Ones(PrimitiveWithInfer):
for i, item in enumerate(shape):
validator.check_non_negative_int(item, shape[i], self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
@@ -1214,7 +1214,7 @@ class Zeros(PrimitiveWithInfer):
for i, item in enumerate(shape):
validator.check_non_negative_int(item, shape[i], self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
@@ -1269,7 +1269,7 @@ class SequenceMask(PrimitiveWithInfer):
def __infer__(self, lengths, dtype, max_length=None):
validator.check_value_type("shape", lengths['value'], [tuple, list], self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_subclass("dtype", dtype['value'], valid_types, self.name)
nptype = mstype.dtype_to_nptype(dtype['value'])


Loading…
Cancel
Save