Browse Source

add data_sync in Tensor.__repr__

tags/v1.1.0
huanghui 5 years ago
parent
commit
1c5cba7b81
2 changed files with 2 additions and 4 deletions
  1. +1
    -0
      mindspore/ccsrc/pybind_api/ir/tensor_py.cc
  2. +1
    -4
      mindspore/common/tensor.py

+ 1
- 0
mindspore/ccsrc/pybind_api/ir/tensor_py.cc View File

@@ -461,6 +461,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
mindspore.int32 mindspore.int32
)mydelimiter") )mydelimiter")
.def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr) .def("set_cast_dtype", &Tensor::set_cast_dtype, py::arg("dtype") = nullptr)
.def("data_sync", &Tensor::data_sync)
.def("__str__", &Tensor::ToString) .def("__str__", &Tensor::ToString)
.def("__repr__", &Tensor::ToStringRepr) .def("__repr__", &Tensor::ToStringRepr)
.def(py::pickle( .def(py::pickle(


+ 1
- 4
mindspore/common/tensor.py View File

@@ -88,6 +88,7 @@ class Tensor(Tensor_):
self._virtual_flag = False self._virtual_flag = False


def __repr__(self): def __repr__(self):
Tensor_.data_sync(self, False)
return Tensor_.__repr__(self) return Tensor_.__repr__(self)


def __add__(self, other): def __add__(self, other):
@@ -293,7 +294,6 @@ class Tensor(Tensor_):
axis = () axis = ()
return tensor_operator_registry.get('any')(keep_dims)(self, axis) return tensor_operator_registry.get('any')(keep_dims)(self, axis)



def view(self, *shape): def view(self, *shape):
r""" r"""
Reshape the tensor according to the input shape. Reshape the tensor according to the input shape.
@@ -312,7 +312,6 @@ class Tensor(Tensor_):
shape = shape[0] shape = shape[0]
return tensor_operator_registry.get('reshape')()(self, shape) return tensor_operator_registry.get('reshape')()(self, shape)



def expand_as(self, x): def expand_as(self, x):
""" """
Expand the dimension of target tensor to the dimension of input tensor. Expand the dimension of target tensor to the dimension of input tensor.
@@ -326,7 +325,6 @@ class Tensor(Tensor_):
""" """
return tensor_operator_registry.get('broadcast_to')(x.shape)(self) return tensor_operator_registry.get('broadcast_to')(x.shape)(self)



def abs(self): def abs(self):
""" """
Return absolute value element-wisely. Return absolute value element-wisely.
@@ -336,7 +334,6 @@ class Tensor(Tensor_):
""" """
return tensor_operator_registry.get('abs')()(self) return tensor_operator_registry.get('abs')()(self)



def mean(self, axis=(), keep_dims=False): def mean(self, axis=(), keep_dims=False):
""" """
Reduce a dimension of a tensor by averaging all elements in the dimension. Reduce a dimension of a tensor by averaging all elements in the dimension.


Loading…
Cancel
Save