Browse Source

!11935 Modify the type of Tensor's shape

From: @Somnus2020
Reviewed-by: @kingxian,@zhoufeng54
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
3bfcf8947c
2 changed files with 13 additions and 0 deletions
  1. +5
    -0
      mindspore/ccsrc/pybind_api/ir/tensor_py.cc
  2. +8
    -0
      mindspore/common/tensor.py

+ 5
- 0
mindspore/ccsrc/pybind_api/ir/tensor_py.cc View File

@@ -406,6 +406,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const TypePtr &type_ptr, const py::list &shape) {
auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const py::array &input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(input, type_ptr);
}),


+ 8
- 0
mindspore/common/tensor.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Tensor implementation."""
import numbers
import numpy as np

from mindspore import log as logger
@@ -43,6 +44,9 @@ class Tensor(Tensor_):
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None.
init (class:'Initializer'): the information of init data.
'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to
use 'init' interface to initialize parameters in other conditions. If 'init' interface is used
to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data.

Outputs:
Tensor, with the same shape as `input_data`.
@@ -76,6 +80,9 @@ class Tensor(Tensor_):
if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
raise TypeError("input_data and init can not be None at the same time.")

if isinstance(shape, numbers.Number):
shape = (shape,)

# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None:
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
@@ -575,6 +582,7 @@ class Tensor(Tensor_):
def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
"""
Get the tensor format data of this Tensor.
The init_data function can be called once for the same tensor.

Args:
slice_index (int): Slice index of a parameter's slices.


Loading…
Cancel
Save