Browse Source

modify Tensor shape

tags/v1.2.0-rc1
lilei 4 years ago
parent
commit
63eb3ed2d9
2 changed files with 13 additions and 0 deletions
  1. +5
    -0
      mindspore/ccsrc/pybind_api/ir/tensor_py.cc
  2. +8
    -0
      mindspore/common/tensor.py

+ 5
- 0
mindspore/ccsrc/pybind_api/ir/tensor_py.cc View File

@@ -406,6 +406,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const TypePtr &type_ptr, const py::list &shape) {
auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const py::array &input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(input, type_ptr);
}),


+ 8
- 0
mindspore/common/tensor.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Tensor implementation."""
import numbers
import numpy as np

from mindspore import log as logger
@@ -43,6 +44,9 @@ class Tensor(Tensor_):
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None.
init (class:'Initializer'): the information of init data.
'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to
use 'init' interface to initialize parameters in other conditions. If 'init' interface is used
to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data.

Outputs:
Tensor, with the same shape as `input_data`.
@@ -76,6 +80,9 @@ class Tensor(Tensor_):
if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
raise TypeError("input_data and init can not be None at the same time.")

if isinstance(shape, numbers.Number):
shape = (shape,)

# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None:
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
@@ -575,6 +582,7 @@ class Tensor(Tensor_):
def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
"""
Get the tensor format data of this Tensor.
The init_data function can be called once for the same tensor.

Args:
slice_index (int): Slice index of a parameter's slices.


Loading…
Cancel
Save