| @@ -7,6 +7,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import os | import os | ||||
| import re | |||||
| from .core._imperative_rt.common import CompNode, DeviceType | from .core._imperative_rt.common import CompNode, DeviceType | ||||
| from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config | from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config | ||||
| @@ -22,10 +23,8 @@ __all__ = [ | |||||
| def _valid_device(inp): | def _valid_device(inp): | ||||
| if isinstance(inp, str) and len(inp) == 4: | |||||
| if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu": | |||||
| if inp[3] == "x" or inp[3].isdigit(): | |||||
| return True | |||||
| if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp): | |||||
| return True | |||||
| return False | return False | ||||
| @@ -14,7 +14,7 @@ from .core import Tensor as _Tensor | |||||
| from .core.ops.builtin import Copy | from .core.ops.builtin import Copy | ||||
| from .core.tensor.core import apply | from .core.tensor.core import apply | ||||
| from .core.tensor.raw_tensor import as_device | from .core.tensor.raw_tensor import as_device | ||||
| from .device import get_default_device | |||||
| from .device import _valid_device, get_default_device | |||||
| from .utils.deprecation import deprecated | from .utils.deprecation import deprecated | ||||
| @@ -37,6 +37,12 @@ class Tensor(_Tensor): | |||||
| self *= 0 | self *= 0 | ||||
| def to(self, device): | def to(self, device): | ||||
| if isinstance(device, str) and not _valid_device(device): | |||||
| raise ValueError( | |||||
| "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format( | |||||
| device | |||||
| ) | |||||
| ) | |||||
| cn = as_device(device).to_c() | cn = as_device(device).to_c() | ||||
| return apply(Copy(comp_node=cn), self)[0] | return apply(Copy(comp_node=cn), self)[0] | ||||