From 62dc81678905e742e6f03ba9dcc2673cd2af9822 Mon Sep 17 00:00:00 2001 From: lilei Date: Thu, 4 Mar 2021 14:31:21 +0800 Subject: [PATCH] modify Tensor check --- mindspore/common/tensor.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 2553078688..a1f3bf04ea 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -42,8 +42,8 @@ class Tensor(Tensor_): The argument is used to define the data type of the output tensor. If it is None, the data type of the output tensor will be as same as the `input_data`. Default: None. shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of - output. Default: None. - init (:class:`Initializer`): the information of init data. + output. If `input_data` is available, `shape` doesn't need to be set. Default: None. + init (Initializer): the information of init data. 'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to use 'init' interface to initialize parameters in other conditions. If 'init' interface is used to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data. @@ -52,18 +52,26 @@ class Tensor(Tensor_): Tensor, with the same shape as `input_data`. Examples: + >>> import numpy as np >>> import mindspore as ms - >>> import mindspore.nn as nn + >>> from mindspore.common.tensor import Tensor + >>> from mindspore.common.initializer import One >>> # initialize a tensor with input data - >>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32) + >>> t1 = Tensor(np.zeros([1, 2, 3]), ms.float32) >>> assert isinstance(t1, Tensor) >>> assert t1.shape == (1, 2, 3) - >>> assert t1.dtype == mindspore.float32 + >>> assert t1.dtype == ms.float32 >>> >>> # initialize a tensor with a float scalar >>> t2 = Tensor(0.1) >>> assert isinstance(t2, Tensor) - >>> assert t2.dtype == mindspore.float64 + >>> assert t2.dtype == ms.float64 + ... + >>> # initialize a tensor with init + >>> t3 = Tensor(shape = (1, 3), dtype=ms.float32, init=One()) + >>> assert isinstance(t3, Tensor) + >>> assert t3.shape == (1, 3) + >>> assert t3.dtype == ms.float32 """ def __init__(self, input_data=None, dtype=None, shape=None, init=None): @@ -71,8 +79,8 @@ class Tensor(Tensor_): if isinstance(input_data, np_types): input_data = np.array(input_data) - if input_data is not None and shape is not None and input_data.shape != shape: - raise ValueError("input_data.shape and shape should be same.") + if input_data is not None and shape is not None: + raise ValueError("If input_data is available, shape doesn't need to be set") if init is not None and (shape is None or dtype is None): raise ValueError("init, dtype and shape must have values at the same time.")