Browse Source

modify tensor comment

tags/v1.6.0
changzherui 4 years ago
parent
commit
1600e01178
2 changed files with 15 additions and 12 deletions
  1. +11
    -12
      mindspore/common/tensor.py
  2. +4
    -0
      mindspore/train/callback/_callback.py

+ 11
- 12
mindspore/common/tensor.py View File

@@ -567,7 +567,7 @@ class Tensor(Tensor_):
return tensor_operator_registry.get('any')(keep_dims)(self, axis)

def view(self, *shape):
r"""
"""
Reshape the tensor according to the input shape.

Args:
@@ -675,7 +675,7 @@ class Tensor(Tensor_):

def transpose(self, *axes):
r"""
Return a view of the tensor with axes transposed.
Return a tensor with axes transposed.

- For a 1-D tensor, this has no effect, as a transposed vector is simply the same vector.
- For a 2-D tensor, this is a standard matrix transpose.
@@ -696,7 +696,7 @@ class Tensor(Tensor_):

Raises:
TypeError: If input arguments have types not specified above.
ValueError: If the number of `axes` is not equal to a.ndim.
ValueError: If the number of `axes` is not equal to Tensor's ndim.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -727,8 +727,8 @@ class Tensor(Tensor_):
Tensor, with new specified shape.

Raises:
TypeError: If new_shape is not integer, list or tuple.
ValueError: If new_shape is not compatible with the original shape.
TypeError: If new shape is not integer, list or tuple.
ValueError: If new shape is not compatible with the original shape.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -863,7 +863,7 @@ class Tensor(Tensor_):

def squeeze(self, axis=None):
"""
Remove single-dimensional entries from the shape of a tensor.
Remove the dimension of shape 1 from the Tensor

Args:
axis (Union[None, int, list(int), tuple(int)], optional): Selects a subset of the entries of
@@ -875,7 +875,7 @@ class Tensor(Tensor_):

Raises:
TypeError: If input arguments have types not specified above.
ValueError: If specified axis has shape entry :math:`> 1`.
ValueError: If axis is greater than one.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -1026,7 +1026,7 @@ class Tensor(Tensor_):
>>> print(a.argmin())
0
"""
# P.Argmax only supports float
# P.Argmin only supports float
a = self.astype(mstype.float32)
if axis is None:
a = a.ravel()
@@ -1213,7 +1213,6 @@ class Tensor(Tensor_):
:func:`mindspore.Tensor.max`: Return the maximum of a tensor or maximum along an axis.

Examples:
>>> import numpy as np
>>> from mindspore import Tensor
>>> import mindspore.numpy as np
>>> a = Tensor(np.arange(4).reshape((2,2)).astype('float32'))
@@ -1239,7 +1238,7 @@ class Tensor(Tensor_):
value (Union[None, int, float, bool]): All elements of a will be assigned this value.

Returns:
Tensor, with the original dtype and shape as input tensor.
Tensor, with the original dtype and shape.

Raises:
TypeError: If input arguments have types not specified above.
@@ -1552,7 +1551,7 @@ class Tensor(Tensor_):
second axis.

Returns:
Tensor, if `a` is 2-D, then `a` 1-D array containing the diagonal.
Tensor, if Tensor is 2-D, return a 1-D Tensor containing the diagonal.

Raises:
ValueError: if the input tensor has less than two dimensions.
@@ -1843,7 +1842,7 @@ class Tensor(Tensor_):
shape = v.shape
if sorter is not None:
if sorter.ndim != 1 or sorter.size != a.size:
raise ValueError('sorter must be 1-D array with the same size as `a`')
raise ValueError('sorter must be 1-D array with the same size as the Tensor')
sorter = tensor_operator_registry.get('make_tensor')(sorter)
sorter = sorter.reshape(sorter.shape + (1,))
a = tensor_operator_registry.get('gather_nd')(a, sorter)


+ 4
- 0
mindspore/train/callback/_callback.py View File

@@ -80,6 +80,10 @@ class Callback:
You can use this mechanism to initialize and release resources automatically.

Callback function will execute some operations in the current step or epoch.
To create a custom callback, subclass Callback and override the method associated
with the stage of interest. See
https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html#callback
for more information.

It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`,
`loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,


Loading…
Cancel
Save