From: @yanglf1121 Reviewed-by: @kingxian,@zhunaipan Signed-off-by: @kingxiantags/v1.1.0
| @@ -635,7 +635,7 @@ def check_input_data(*data, data_class): | |||
| f' either a single' | |||
| f' or a list of {data_class.__name__},' | |||
| f' but got part data type is {str(type(item))}.') | |||
| if hasattr(item, "size") and item.size() == 0: | |||
| if hasattr(item, "size") and item.size == 0: | |||
| msg = "Please provide non-empty data." | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| @@ -31,6 +31,8 @@ shape_ = P.Shape() | |||
| reshape_ = P.Reshape() | |||
| dtype_ = P.DType() | |||
| abs_ = P.Abs() | |||
| ndim_ = P.Rank() | |||
| size_ = P.Size() | |||
| def mean(x, axis=(), keep_dims=False): | |||
| """ | |||
| @@ -192,6 +192,8 @@ BuiltInTypeMap &GetAttrMap() { | |||
| { | |||
| {"shape", std::string("shape_")}, // C.shape_ | |||
| {"dtype", std::string("dtype_")}, // C.dtype_ | |||
| {"size", std::string("size_")}, // C.size_ | |||
| {"ndim", std::string("ndim_")}, // C.ndim_ | |||
| }}, | |||
| {kObjectTypeRowTensorType, | |||
| { | |||
| @@ -370,6 +370,17 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| >>> data.shape() | |||
| (3, 3) | |||
| )mydelimiter") | |||
| .def_property_readonly("_size", &Tensor::DataSize, R"mydelimiter( | |||
| Get tensor's data size. | |||
| Returns: | |||
| int, the size of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.size | |||
| 6 | |||
| )mydelimiter") | |||
| .def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter( | |||
| Creates a Tensor from a numpy.ndarray without copy. | |||
| @@ -396,17 +407,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| array([[1., 1., 1.], | |||
| [1., 1., 1.]]) | |||
| )mydelimiter") | |||
| .def("size", &Tensor::DataSize, R"mydelimiter( | |||
| Get tensor's data size. | |||
| Returns: | |||
| int, the size of tensor. | |||
| Examples: | |||
| >>> data = mindspore.Tensor(np.ones((2, 3))) | |||
| >>> data.size() | |||
| 6 | |||
| )mydelimiter") | |||
| .def("is_init", &Tensor::is_init, R"mydelimiter( | |||
| Get tensor init_flag. | |||
| @@ -237,6 +237,16 @@ class Tensor(Tensor_): | |||
| """The dtype of tensor is a mindspore type.""" | |||
| return self._dtype | |||
| @property | |||
| def size(self): | |||
| """The size reflects the total number of elements in tensor.""" | |||
| return self._size | |||
| @property | |||
| def ndim(self): | |||
| """The ndim of tensor is an integer.""" | |||
| return len(self._shape) | |||
| @property | |||
| def virtual_flag(self): | |||
| """Mark tensor is virtual.""" | |||
| @@ -277,7 +277,7 @@ class Dense(Cell): | |||
| self.shape_op = P.Shape() | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("Weight init shape error.") | |||
| self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | |||
| @@ -285,7 +285,7 @@ class Dense(Cell): | |||
| self.bias = None | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("Bias init shape error.") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | |||
| self.bias_add = P.BiasAdd() | |||
| @@ -318,7 +318,7 @@ class BatchNorm1d(_BatchNorm): | |||
| input_dims='1d') | |||
| def _check_data_dim(self, x): | |||
| if x.dim() != 2: | |||
| if x.ndim != 2: | |||
| pass | |||
| @@ -415,7 +415,7 @@ class BatchNorm2d(_BatchNorm): | |||
| data_format=data_format) | |||
| def _check_data_dim(self, x): | |||
| if x.dim() != 4: | |||
| if x.ndim != 4: | |||
| pass | |||
| @@ -1013,7 +1013,7 @@ class DenseQuant(Cell): | |||
| self.has_bias = Validator.check_bool(has_bias) | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -1022,7 +1022,7 @@ class DenseQuant(Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer( | |||
| @@ -278,17 +278,17 @@ class Optimizer(Cell): | |||
| learning_rate = float(learning_rate) | |||
| validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name) | |||
| return learning_rate | |||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: | |||
| if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0: | |||
| return learning_rate | |||
| self.dynamic_lr = True | |||
| if isinstance(learning_rate, Iterable): | |||
| return Tensor(np.array(list(learning_rate)).astype(np.float32)) | |||
| if isinstance(learning_rate, Tensor): | |||
| if learning_rate.dim() > 1: | |||
| if learning_rate.ndim > 1: | |||
| raise ValueError("The dim of `Tensor` type Learning rate should be a 0 or 1," | |||
| f"but got {learning_rate.dim()}.") | |||
| if learning_rate.dim() == 1 and learning_rate.size() < 2: | |||
| f"but got {learning_rate.ndim}.") | |||
| if learning_rate.ndim == 1 and learning_rate.size < 2: | |||
| logger.warning("If use `Tensor` type dynamic learning rate, please make sure that the number" | |||
| "of elements in the tensor passed is greater than 1.") | |||
| return learning_rate | |||
| @@ -303,12 +303,12 @@ class Optimizer(Cell): | |||
| if self.is_group_lr and self.dynamic_lr: | |||
| learning_rate = _ConvertToCell(learning_rate) | |||
| return learning_rate | |||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: | |||
| if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0: | |||
| learning_rate = Parameter(learning_rate, name) | |||
| if self.is_group_lr and self.dynamic_lr: | |||
| learning_rate = _ConvertToCell(learning_rate) | |||
| return learning_rate | |||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1: | |||
| if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1: | |||
| return _IteratorLearningRate(learning_rate, name) | |||
| return learning_rate | |||
| @@ -338,8 +338,8 @@ class Optimizer(Cell): | |||
| def _parse_group_params(self, parameters, learning_rate): | |||
| """Parse group params.""" | |||
| self._check_group_params(parameters) | |||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 1: | |||
| tensor_lr_length = learning_rate.size() | |||
| if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1: | |||
| tensor_lr_length = learning_rate.size | |||
| else: | |||
| tensor_lr_length = 0 | |||
| @@ -357,8 +357,8 @@ class Optimizer(Cell): | |||
| self.is_group_lr = True | |||
| group_lr = self._preprocess_single_lr(group_param['lr']) | |||
| if isinstance(group_lr, Tensor) and group_lr.dim() == 1: | |||
| group_lr_length = group_lr.size() | |||
| if isinstance(group_lr, Tensor) and group_lr.ndim == 1: | |||
| group_lr_length = group_lr.size | |||
| if tensor_lr_length == 0: | |||
| tensor_lr_length = group_lr_length | |||
| elif group_lr_length != tensor_lr_length: | |||
| @@ -617,9 +617,9 @@ class _IteratorLearningRate(LearningRateSchedule): | |||
| def __init__(self, learning_rate, name): | |||
| super(_IteratorLearningRate, self).__init__() | |||
| if isinstance(learning_rate, Tensor): | |||
| if learning_rate.dim() != 1: | |||
| if learning_rate.ndim != 1: | |||
| raise ValueError("The dim of `Tensor` type dynamic learning rate should be a 1," | |||
| f"but got {learning_rate.dim()}.") | |||
| f"but got {learning_rate.ndim}.") | |||
| else: | |||
| raise TypeError("Learning rate should be Tensor.") | |||
| @@ -143,7 +143,7 @@ class ExpandDims(PrimitiveWithInfer): | |||
| - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. | |||
| - **axis** (int) - Specifies the dimension index at which to expand | |||
| the shape of `input_x`. The value of axis must be in the range | |||
| `[-input_x.dim()-1, input_x.dim()]`. Only constant value is allowed. | |||
| `[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed. | |||
| Outputs: | |||
| Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the | |||
| @@ -597,7 +597,7 @@ class Squeeze(PrimitiveWithInfer): | |||
| Returns a tensor with the same type but dimensions of 1 are removed based on `axis`. | |||
| Note: | |||
| The dimension index starts at 0 and must be in the range `[-input.dim(), input.dim())`. | |||
| The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim`. | |||
| Raises: | |||
| ValueError: If the corresponding dimension of the specified axis does not equal to 1. | |||
| @@ -238,7 +238,7 @@ def _load_tensor_by_layout(tensor, layout): | |||
| group = layout[5] | |||
| if uniform_split == 0: | |||
| raise RuntimeError("The load tensor only support uniform split now") | |||
| if tensor.size() == 1: | |||
| if tensor.size == 1: | |||
| return tensor | |||
| tensor_slice = _load_tensor(tensor, dev_mat, tensor_map) | |||
| if group: | |||
| @@ -294,7 +294,7 @@ class Dense_Thor_GPU(Cell): | |||
| self.has_bias = Validator.check_bool(has_bias) | |||
| self.thor = True | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -302,7 +302,7 @@ class Dense_Thor_GPU(Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels])) | |||
| @@ -639,7 +639,7 @@ class Dense_Thor(Cell): | |||
| self.thor = True | |||
| self.batch_size = batch_size | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -647,7 +647,7 @@ class Dense_Thor(Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels])) | |||
| @@ -50,16 +50,16 @@ def convert(weights_file, output_file): | |||
| var = params[i+2] | |||
| gamma = params[i+3] | |||
| beta = params[i+4] | |||
| beta_data = weights[index: index+beta.size()].reshape(beta.shape) | |||
| index += beta.size() | |||
| gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape) | |||
| index += gamma.size() | |||
| mean_data = weights[index: index+mean.size()].reshape(mean.shape) | |||
| index += mean.size() | |||
| var_data = weights[index: index + var.size()].reshape(var.shape) | |||
| index += var.size() | |||
| weight_data = weights[index: index+weight.size()].reshape(weight.shape) | |||
| index += weight.size() | |||
| beta_data = weights[index: index+beta.size].reshape(beta.shape) | |||
| index += beta.size | |||
| gamma_data = weights[index: index+gamma.size].reshape(gamma.shape) | |||
| index += gamma.size | |||
| mean_data = weights[index: index+mean.size].reshape(mean.shape) | |||
| index += mean.size | |||
| var_data = weights[index: index + var.size].reshape(var.shape) | |||
| index += var.size | |||
| weight_data = weights[index: index+weight.size].reshape(weight.shape) | |||
| index += weight.size | |||
| param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape, | |||
| 'data': Tensor(weight_data)}) | |||
| @@ -76,7 +76,7 @@ class GNNFeatureTransform(nn.Cell): | |||
| self.has_bias = has_bias | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -84,7 +84,7 @@ class GNNFeatureTransform(nn.Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels])) | |||
| @@ -162,7 +162,7 @@ class Dense_Thor(Cell): | |||
| self.has_bias = Validator.check_bool(has_bias) | |||
| self.thor = True | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \ | |||
| weight_init.shape()[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -170,7 +170,7 @@ class Dense_Thor(Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels])) | |||
| @@ -45,7 +45,7 @@ class DenseLayer(nn.Cell): | |||
| self.has_bias = validator.check_bool(has_bias) | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \ | |||
| weight_init.shape()[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -53,7 +53,7 @@ class DenseLayer(nn.Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels])) | |||
| @@ -78,7 +78,7 @@ class GNNFeatureTransform(nn.Cell): | |||
| self.has_bias = Validator.check_bool(has_bias) | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -86,7 +86,7 @@ class GNNFeatureTransform(nn.Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | |||
| @@ -342,7 +342,7 @@ class Dense_Thor(Cell): | |||
| self.has_bias = Validator.check_bool(has_bias) | |||
| self.thor = True | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| @@ -350,7 +350,7 @@ class Dense_Thor(Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | |||
| @@ -35,7 +35,7 @@ class Net(nn.Cell): | |||
| self.biasAdd = P.BiasAdd() | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != output_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != output_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer( | |||
| @@ -33,7 +33,7 @@ class Net(nn.Cell): | |||
| self.biasAdd = P.BiasAdd() | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != output_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != output_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer( | |||
| @@ -65,7 +65,7 @@ def test_bias_add(test_with_simu): | |||
| self.biasAdd = P.BiasAdd() | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != output_channels: | |||
| if bias_init.ndim != 1 or bias_init.shape[0] != output_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer( | |||
| @@ -24,7 +24,7 @@ from ..ut_filter import non_graph_engine | |||
| def _attribute(tensor, shape_, size_, dtype_): | |||
| result = (tensor.shape == shape_) and \ | |||
| (tensor.size() == size_) and \ | |||
| (tensor.size == size_) and \ | |||
| (tensor.dtype == dtype_) | |||
| return result | |||
| @@ -60,13 +60,13 @@ def test_tensor_mul(): | |||
| def test_tensor_dim(): | |||
| arr = np.ones((1, 6)) | |||
| b = ms.Tensor(arr) | |||
| assert b.dim() == 2 | |||
| assert b.ndim == 2 | |||
| def test_tensor_size(): | |||
| arr = np.ones((1, 6)) | |||
| b = ms.Tensor(arr) | |||
| assert arr.size == b.size() | |||
| assert arr.size == b.size | |||
| def test_dtype(): | |||