
modify Tensor

tags/v1.2.0-rc1
lilei 4 years ago
commit 7d9f5f6dba
6 changed files with 75 additions and 9 deletions
  1. mindspore/common/__init__.py (+2, -2)
  2. mindspore/common/parameter.py (+2, -4)
  3. mindspore/common/tensor.py (+33, -3)
  4. tests/ut/python/ir/test_tensor_py.py (+7, -0)
  5. tests/ut/python/pipeline/parse/test_expand_as.py (+16, -0)
  6. tests/ut/python/pipeline/parse/test_view.py (+15, -0)

mindspore/common/__init__.py (+2, -2)

@@ -17,13 +17,13 @@ from . import dtype
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
+from .tensor import Tensor, RowTensor, SparseTensor
 from .seed import set_seed, get_seed

 __all__ = dtype.__all__
 __all__.extend([
-    "MetaTensor", "Tensor", "RowTensor", "SparseTensor",  # tensor
+    "Tensor", "RowTensor", "SparseTensor",  # tensor
     'ms_function',                       # api
     'Parameter', 'ParameterTuple',       # parameter
     "dtype",


mindspore/common/parameter.py (+2, -4)

@@ -100,7 +100,6 @@ class Parameter(Tensor_):
         ...     def construct(self, x):
         ...         out = self.matmul(self.weight, x)
         ...         return out
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
         >>> net = Net()
         >>> x = Tensor(np.ones((2,1)))
         >>> print(net(x))
@@ -113,15 +112,14 @@ class Parameter(Tensor_):
     __base_type__ = {}

     def __new__(cls, default_input, *args, **kwargs):
+        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
         input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
         new_type = Parameter._get_base_class(input_class)
         obj = input_class.__new__(new_type)
         input_class.__init__(obj, *class_init_args)
         # it's better to make the Initializer a kind of tensor.
         obj.init_mode = None
-        obj.is_default_input_init = False
-        if isinstance(default_input, Tensor) and default_input.has_init:
-            obj.is_default_input_init = True
+        obj.is_default_input_init = init_data_flag
         if obj.has_init:
             obj.init_mode = default_input
         return obj
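
The key change: the has_init state of default_input is captured into init_data_flag before Parameter._get_parameter_new_args(default_input) runs, presumably so the flag reflects the original input rather than whatever that call unwraps it into; the old three-line assignment collapses into one. A hedged usage sketch of an initializer-backed Parameter, as exercised by this commit's tests:

import mindspore as ms
import mindspore.common.initializer as init
from mindspore import Parameter

# A Parameter built from an initializer Tensor records
# is_default_input_init=True and keeps the initializer in init_mode,
# so its data can be materialized later by init_data().
w = Parameter(init.initializer('normal', [2, 3], ms.float32), name='w')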


mindspore/common/tensor.py (+33, -3)

@@ -19,11 +19,10 @@ from mindspore import log as logger
 from mindspore.communication.management import get_rank, get_group_size
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
-from .._c_expression import MetaTensor
 from .._c_expression import Tensor as Tensor_
 from .._checkparam import Validator as validator

-__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
+__all__ = ['Tensor', 'RowTensor', 'SparseTensor']

 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)
@@ -41,6 +40,9 @@ class Tensor(Tensor_):
         dtype (:class:`mindspore.dtype`): Input data should be None, bool or numeric type defined in `mindspore.dtype`.
             The argument is used to define the data type of the output tensor. If it is None, the data type of the
             output tensor will be the same as the `input_data`. Default: None.
+        shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
+            the output. Default: None.
+        init (:class:`Initializer`): The information of init data. Default: None.

     Outputs:
         Tensor, with the same shape as `input_data`.
@@ -65,6 +67,12 @@ class Tensor(Tensor_):
         if isinstance(input_data, np_types):
             input_data = np.array(input_data)

+        if input_data is not None and shape is not None and input_data.shape != shape:
+            raise ValueError("input_data.shape and shape should be the same.")
+
+        if init is not None and (shape is None or dtype is None):
+            raise ValueError("init, dtype and shape must have values at the same time.")
+
         if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
             raise TypeError("input_data and init can not be None at the same time.")
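
With the new shape and init parameters, a Tensor can now be declared from an Initializer plus an explicit shape and dtype; the checks above require all three together and reject an explicit shape that disagrees with input_data. A minimal sketch, assuming the One initializer class from mindspore.common.initializer:

import mindspore as ms
from mindspore import Tensor
import mindspore.common.initializer as init

# Declare lazily: shape/dtype/init metadata only, no data buffer yet.
t = Tensor(shape=(2, 3), dtype=ms.float32, init=init.One())
data = t.init_data()  # materializes the ones when actually needed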


@@ -306,10 +314,12 @@ class Tensor(Tensor_):

     def asnumpy(self):
         """Convert tensor to numpy array."""
+        self.init_check()
         return Tensor_.asnumpy(self)

     def _flush_from_cache(self):
         """Flush cache data to host if tensor is cache enable."""
+        self.init_check()
         Tensor_._flush_from_cache(self)

     def all(self, axis=(), keep_dims=False):
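
The same one-line guard is threaded through the data-accessing methods in the hunks that follow (all, any, view, expand_as, abs, mean, transpose, reshape, ravel, swapaxes, squeeze, astype): each calls init_check() before touching the underlying buffer, so a tensor created from an Initializer is materialized transparently on first use.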
@@ -327,6 +337,7 @@ class Tensor(Tensor_):
             Tensor, has the same data type as x.
         """
+        self.init_check()
         if axis is None:
             axis = ()
         return tensor_operator_registry.get('all')(keep_dims)(self, axis)
@@ -346,6 +357,7 @@ class Tensor(Tensor_):
             Tensor, has the same data type as x.
         """
+        self.init_check()
         if axis is None:
             axis = ()
         return tensor_operator_registry.get('any')(keep_dims)(self, axis)
@@ -360,6 +372,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same dimension as the input shape.
         """
+        self.init_check()
         if not shape:
             raise ValueError("The shape variable should not be empty")
         if isinstance(shape[0], tuple):
@@ -379,6 +392,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same dimension as input tensor.
         """
+        self.init_check()
         return tensor_operator_registry.get('broadcast_to')(x.shape)(self)

     def abs(self):
@@ -388,6 +402,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         return tensor_operator_registry.get('abs')()(self)

     def mean(self, axis=(), keep_dims=False):
@@ -404,6 +419,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         if axis is None:
             axis = ()
         return tensor_operator_registry.get('mean')(keep_dims)(self, axis)
@@ -429,6 +445,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same dimension as input tensor, with axes suitably permuted.
         """
+        self.init_check()
         perm = validator.check_transpose_axis(axes, self.ndim)
         return tensor_operator_registry.get('transpose')()(self, perm)
@@ -446,6 +463,7 @@ class Tensor(Tensor_):
             reshaped_tensor(Tensor): This will be a new view object if possible;
             otherwise, it will be a copy.
         """
+        self.init_check()
         new_shape = validator.check_reshape_shp(shape)
         return tensor_operator_registry.get('reshape')()(self, new_shape)
@@ -457,6 +475,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         reshape_op = tensor_operator_registry.get('reshape')()
         return reshape_op(self, (-1,))
@@ -472,6 +491,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, has the same data type as x.
         """
+        self.init_check()
         reshape_op = tensor_operator_registry.get('reshape')()
         trans_op = tensor_operator_registry.get('transpose')()
@@ -493,6 +513,7 @@ class Tensor(Tensor_):
         Returns:
             Transposed tensor, has the same data type as the original tensor x.
         """
+        self.init_check()
         axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)

         if axis1 == axis2:
@@ -521,6 +542,7 @@ class Tensor(Tensor_):
         Returns:
             Tensor, with all or a subset of the dimensions of length 1 removed.
         """
+        self.init_check()
         if axis is None:
             return tensor_operator_registry.get('squeeze')(self)
         new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
@@ -542,12 +564,18 @@ class Tensor(Tensor_):
         Returns:
             Tensor, with the designated dtype.
         """
+        self.init_check()
         dtype = validator.check_astype_dtype(dtype)
         if not copy and dtype == self.dtype:
             return self
         return tensor_operator_registry.get('cast')(self, dtype)

+    def init_check(self):
+        if self.has_init:
+            self.init_data()
+        return self
+
     def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
         """
         Get the tensor format data of this Tensor.
@@ -601,7 +629,9 @@ class Tensor(Tensor_):
             rank = get_rank(opt_shard_group)
             size = get_group_size(opt_shard_group)
             data = np.split(data, size)[rank]
-        return Tensor(data, dtype=self.dtype)
+        self.init = None
+        self.assign_value(Tensor(data, dtype=self.dtype))
+        return self

     def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):
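
init_check() (added above) triggers init_data() whenever has_init is set, and init_data() now materializes the data in place via assign_value and returns self instead of allocating a fresh Tensor, so object identity is preserved across initialization. A small sketch of the observable behavior, assuming the initializer API used in the tests:

import mindspore as ms
import mindspore.common.initializer as init

t = init.initializer('one', [2, 3], ms.float32)
same = t.init_data()  # fills the tensor in place (assign_value)...
assert same is t      # ...and returns the very same object, not a copy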


tests/ut/python/ir/test_tensor_py.py (+7, -0)

@@ -16,6 +16,7 @@
 import numpy as np

 import mindspore as ms
+import mindspore.common.initializer as init
 from mindspore.common.api import _executor
 from mindspore.nn import Cell
 from mindspore.ops import operations as P
@@ -99,6 +100,12 @@ def test_asnumpy():
     assert a.asnumpy().all() == npd.all()


+def test_initializer_asnumpy():
+    npd = np.ones((2, 3))
+    a = init.initializer('one', [2, 3], ms.int32)
+    assert a.asnumpy().all() == npd.all()
+
+
 def test_print():
     a = ms.Tensor(np.ones((2, 3)))
     a.set_dtype(ms.int32)


tests/ut/python/pipeline/parse/test_expand_as.py (+16, -0)

@@ -13,7 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """ test expand_as"""
+import mindspore as ms
 import mindspore.nn as nn
+import mindspore.common.initializer as init
 from mindspore import Tensor
 from mindspore import context
@@ -34,6 +36,20 @@ def test_expand_as():
     net()


+def test_initializer_expand_as():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.t1 = init.initializer('one', [1, 3], ms.float32)
+            self.t2 = init.initializer('one', [2, 3], ms.float32)
+
+        def construct(self):
+            return self.t1.expand_as(self.t2)
+
+    net = Net()
+    net()
+
+
 def test_expand_as_parameter():
     class Net(nn.Cell):
         def __init__(self):


tests/ut/python/pipeline/parse/test_view.py (+15, -0)

@@ -15,7 +15,9 @@
 """ test view"""
 import pytest

+import mindspore as ms
 import mindspore.nn as nn
+import mindspore.common.initializer as init
 from mindspore import Tensor
 from mindspore import context
@@ -35,6 +37,19 @@ def test_view():
     net()


+def test_view_initializer():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = init.initializer('normal', [2, 3], ms.float32)
+
+        def construct(self):
+            return self.value.view(-1)
+
+    net = Net()
+    net()
+
+
 def test_view_1():
     class Net(nn.Cell):
         def __init__(self):

