diff --git a/mindspore/common/__init__.py b/mindspore/common/__init__.py index 64f7d03e02..39afd6410b 100644 --- a/mindspore/common/__init__.py +++ b/mindspore/common/__init__.py @@ -17,13 +17,13 @@ from . import dtype from .api import ms_function from .dtype import * from .parameter import Parameter, ParameterTuple -from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor +from .tensor import Tensor, RowTensor, SparseTensor from .seed import set_seed, get_seed __all__ = dtype.__all__ __all__.extend([ - "MetaTensor", "Tensor", "RowTensor", "SparseTensor", # tensor + "Tensor", "RowTensor", "SparseTensor", # tensor 'ms_function', # api 'Parameter', 'ParameterTuple', # parameter "dtype", diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index c113f77fe8..4779ed8107 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -100,7 +100,6 @@ class Parameter(Tensor_): ... def construct(self, x): ... out = self.matmul(self.weight, x) ... return out - >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") >>> net = Net() >>> x = Tensor(np.ones((2,1))) >>> print(net(x)) @@ -113,15 +112,14 @@ class Parameter(Tensor_): __base_type__ = {} def __new__(cls, default_input, *args, **kwargs): + init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init) input_class, *class_init_args = Parameter._get_parameter_new_args(default_input) new_type = Parameter._get_base_class(input_class) obj = input_class.__new__(new_type) input_class.__init__(obj, *class_init_args) # it's better to make the Initializer a kind of tensor. obj.init_mode = None - obj.is_default_input_init = False - if isinstance(default_input, Tensor) and default_input.has_init: - obj.is_default_input_init = True + obj.is_default_input_init = init_data_flag if obj.has_init: obj.init_mode = default_input return obj diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index fef71083f1..5af786ab0b 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -19,11 +19,10 @@ from mindspore import log as logger from mindspore.communication.management import get_rank, get_group_size from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry -from .._c_expression import MetaTensor from .._c_expression import Tensor as Tensor_ from .._checkparam import Validator as validator -__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor'] +__all__ = ['Tensor', 'RowTensor', 'SparseTensor'] np_types = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64, np.bool_) @@ -41,6 +40,9 @@ class Tensor(Tensor_): dtype (:class:`mindspore.dtype`): Input data should be None, bool or numeric type defined in `mindspore.dtype`. The argument is used to define the data type of the output tensor. If it is None, the data type of the output tensor will be as same as the `input_data`. Default: None. + shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of + output. Default: None. + init (class:'Initializer'): the information of init data. Outputs: Tensor, with the same shape as `input_data`. @@ -65,6 +67,12 @@ class Tensor(Tensor_): if isinstance(input_data, np_types): input_data = np.array(input_data) + if input_data is not None and shape is not None and input_data.shape != shape: + raise ValueError("input_data.shape and shape should be same.") + + if init is not None and (shape is None or dtype is None): + raise ValueError("init, dtype and shape must have values at the same time.") + if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False: raise TypeError("input_data and init can not be None at the same time.") @@ -306,10 +314,12 @@ class Tensor(Tensor_): def asnumpy(self): """Convert tensor to numpy array.""" + self.init_check() return Tensor_.asnumpy(self) def _flush_from_cache(self): """Flush cache data to host if tensor is cache enable.""" + self.init_check() Tensor_._flush_from_cache(self) def all(self, axis=(), keep_dims=False): @@ -327,6 +337,7 @@ class Tensor(Tensor_): Tensor, has the same data type as x. """ + self.init_check() if axis is None: axis = () return tensor_operator_registry.get('all')(keep_dims)(self, axis) @@ -346,6 +357,7 @@ class Tensor(Tensor_): Tensor, has the same data type as x. """ + self.init_check() if axis is None: axis = () return tensor_operator_registry.get('any')(keep_dims)(self, axis) @@ -360,6 +372,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same dimension as the input shape. """ + self.init_check() if not shape: raise ValueError("The shape variable should not be empty") if isinstance(shape[0], tuple): @@ -379,6 +392,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same dimension as input tensor. """ + self.init_check() return tensor_operator_registry.get('broadcast_to')(x.shape)(self) def abs(self): @@ -388,6 +402,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same data type as x. """ + self.init_check() return tensor_operator_registry.get('abs')()(self) def mean(self, axis=(), keep_dims=False): @@ -404,6 +419,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same data type as x. """ + self.init_check() if axis is None: axis = () return tensor_operator_registry.get('mean')(keep_dims)(self, axis) @@ -429,6 +445,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same dimension as input tensor, with axes suitably permuted. """ + self.init_check() perm = validator.check_transpose_axis(axes, self.ndim) return tensor_operator_registry.get('transpose')()(self, perm) @@ -446,6 +463,7 @@ class Tensor(Tensor_): reshaped_tensor(Tensor): This will be a new view object if possible; otherwise, it will be a copy. """ + self.init_check() new_shape = validator.check_reshape_shp(shape) return tensor_operator_registry.get('reshape')()(self, new_shape) @@ -457,6 +475,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same data type as x. """ + self.init_check() reshape_op = tensor_operator_registry.get('reshape')() return reshape_op(self, (-1,)) @@ -472,6 +491,7 @@ class Tensor(Tensor_): Returns: Tensor, has the same data type as x. """ + self.init_check() reshape_op = tensor_operator_registry.get('reshape')() trans_op = tensor_operator_registry.get('transpose')() @@ -493,6 +513,7 @@ class Tensor(Tensor_): Returns: Transposed tensor, has the same data type as the original tensor x. """ + self.init_check() axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim) if axis1 == axis2: @@ -521,6 +542,7 @@ class Tensor(Tensor_): Returns: Tensor, with all or a subset of the dimensions of length 1 removed. """ + self.init_check() if axis is None: return tensor_operator_registry.get('squeeze')(self) new_shape = validator.prepare_shape_for_squeeze(self.shape, axis) @@ -542,12 +564,18 @@ class Tensor(Tensor_): Returns: Tensor, with the designated dtype. """ + self.init_check() dtype = validator.check_astype_dtype(dtype) if not copy and dtype == self.dtype: return self return tensor_operator_registry.get('cast')(self, dtype) + def init_check(self): + if self.has_init: + self.init_data() + return self + def init_data(self, slice_index=None, shape=None, opt_shard_group=None): """ Get the tensor format data of this Tensor. @@ -601,7 +629,9 @@ class Tensor(Tensor_): rank = get_rank(opt_shard_group) size = get_group_size(opt_shard_group) data = np.split(data, size)[rank] - return Tensor(data, dtype=self.dtype) + self.init = None + self.assign_value(Tensor(data, dtype=self.dtype)) + return self def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None): diff --git a/tests/ut/python/ir/test_tensor_py.py b/tests/ut/python/ir/test_tensor_py.py index 5db563cffe..69842b32c6 100644 --- a/tests/ut/python/ir/test_tensor_py.py +++ b/tests/ut/python/ir/test_tensor_py.py @@ -16,6 +16,7 @@ import numpy as np import mindspore as ms +import mindspore.common.initializer as init from mindspore.common.api import _executor from mindspore.nn import Cell from mindspore.ops import operations as P @@ -99,6 +100,12 @@ def test_asnumpy(): assert a.asnumpy().all() == npd.all() +def test_initializer_asnumpy(): + npd = np.ones((2, 3)) + a = init.initializer('one', [2, 3], ms.int32) + assert a.asnumpy().all() == npd.all() + + def test_print(): a = ms.Tensor(np.ones((2, 3))) a.set_dtype(ms.int32) diff --git a/tests/ut/python/pipeline/parse/test_expand_as.py b/tests/ut/python/pipeline/parse/test_expand_as.py index 36583e065f..9cbe45753b 100644 --- a/tests/ut/python/pipeline/parse/test_expand_as.py +++ b/tests/ut/python/pipeline/parse/test_expand_as.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================ """ test expand_as""" +import mindspore as ms import mindspore.nn as nn +import mindspore.common.initializer as init from mindspore import Tensor from mindspore import context @@ -34,6 +36,20 @@ def test_expand_as(): net() +def test_initializer_expand_as(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.t1 = init.initializer('one', [1, 3], ms.float32) + self.t2 = init.initializer('one', [2, 3], ms.float32) + + def construct(self): + return self.t1.expand_as(self.t2) + + net = Net() + net() + + def test_expand_as_parameter(): class Net(nn.Cell): def __init__(self): diff --git a/tests/ut/python/pipeline/parse/test_view.py b/tests/ut/python/pipeline/parse/test_view.py index 13e085d93c..36b25d1e34 100644 --- a/tests/ut/python/pipeline/parse/test_view.py +++ b/tests/ut/python/pipeline/parse/test_view.py @@ -15,7 +15,9 @@ """ test view""" import pytest +import mindspore as ms import mindspore.nn as nn +import mindspore.common.initializer as init from mindspore import Tensor from mindspore import context @@ -35,6 +37,19 @@ def test_view(): net() +def test_view_initializer(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = init.initializer('normal', [2, 3], ms.float32) + + def construct(self): + return self.value.view(-1) + + net = Net() + net() + + def test_view_1(): class Net(nn.Cell): def __init__(self):