GitOrigin-RevId: f9caf17d24
tags/v1.0.0-rc1
| @@ -70,7 +70,7 @@ def set_default_device(device: str = "xpux"): | |||||
| multi-threading parallelism at the operator level. For example, | multi-threading parallelism at the operator level. For example, | ||||
| 'multithread4' will compute with 4 threads. | 'multithread4' will compute with 4 threads. | ||||
| The default value is 'xpux' to specify any device available. | |||||
| The default value is 'xpux', which selects any available device. When both gpu and cpu are available, gpu takes priority. | |||||
| It can also be set by the environment variable `MGE_DEFAULT_DEVICE`. | It can also be set by the environment variable `MGE_DEFAULT_DEVICE`. | ||||
| """ | """ | ||||
| @@ -11,6 +11,8 @@ | |||||
| import collections | import collections | ||||
| from .core import Tensor as _Tensor | from .core import Tensor as _Tensor | ||||
| from .core.ops.builtin import Copy | |||||
| from .core.tensor.core import apply | |||||
| from .device import get_default_device | from .device import get_default_device | ||||
| @@ -30,6 +32,9 @@ class Tensor(_Tensor): | |||||
| def reset_zero(self): | def reset_zero(self): | ||||
| self *= 0 | self *= 0 | ||||
| def to(self, cn): | |||||
| return apply(Copy(comp_node=cn), self)[0] | |||||
| def __getstate__(self): | def __getstate__(self): | ||||
| r""" __getstate__ will be called for pickle serialization or deep copy | r""" __getstate__ will be called for pickle serialization or deep copy | ||||
| """ | """ | ||||
| @@ -322,6 +322,8 @@ def copy_test(dst, src): | |||||
| x = tensor(data, device=src) | x = tensor(data, device=src) | ||||
| y = F.copy(x, dst) | y = F.copy(x, dst) | ||||
| assert np.allclose(data, y.numpy()) | assert np.allclose(data, y.numpy()) | ||||
| z = x.to(dst) | |||||
| assert np.allclose(data, z.numpy()) | |||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||