Browse Source

modify Tensor shape

tags/v1.2.0-rc1
lilei 4 years ago
parent
commit
63eb3ed2d9
2 changed files with 13 additions and 0 deletions
  1. +5
    -0
      mindspore/ccsrc/pybind_api/ir/tensor_py.cc
  2. +8
    -0
      mindspore/common/tensor.py

+ 5
- 0
mindspore/ccsrc/pybind_api/ir/tensor_py.cc View File

@@ -406,6 +406,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape)); return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}), }),
py::arg("dtype"), py::arg("shape")) py::arg("dtype"), py::arg("shape"))
.def(py::init([](const TypePtr &type_ptr, const py::list &shape) {
auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const py::array &input, const TypePtr &type_ptr) { .def(py::init([](const py::array &input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(input, type_ptr); return TensorPy::MakeTensor(input, type_ptr);
}), }),


+ 8
- 0
mindspore/common/tensor.py View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Tensor implementation.""" """Tensor implementation."""
import numbers
import numpy as np import numpy as np


from mindspore import log as logger from mindspore import log as logger
@@ -43,6 +44,9 @@ class Tensor(Tensor_):
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None. output. Default: None.
init (class:'Initializer'): the information of init data. init (class:'Initializer'): the information of init data.
'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to
use 'init' interface to initialize parameters in other conditions. If 'init' interface is used
to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data.


Outputs: Outputs:
Tensor, with the same shape as `input_data`. Tensor, with the same shape as `input_data`.
@@ -76,6 +80,9 @@ class Tensor(Tensor_):
if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False: if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
raise TypeError("input_data and init can not be None at the same time.") raise TypeError("input_data and init can not be None at the same time.")


if isinstance(shape, numbers.Number):
shape = (shape,)

# If input_data is tuple/list/numpy.ndarray, it's support in check_type method. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None: if init is None:
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
@@ -575,6 +582,7 @@ class Tensor(Tensor_):
def init_data(self, slice_index=None, shape=None, opt_shard_group=None): def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
""" """
Get the tensor format data of this Tensor. Get the tensor format data of this Tensor.
The init_data function can be called once for the same tensor.


Args: Args:
slice_index (int): Slice index of a parameter's slices. slice_index (int): Slice index of a parameter's slices.


Loading…
Cancel
Save