GitOrigin-RevId: 050c3864a7
tags/v1.0.0
| @@ -22,11 +22,13 @@ class Device: | |||||
| else: | else: | ||||
| self._cn = CompNode(device) | self._cn = CompNode(device) | ||||
| self.logical_name = self._cn.logical_name | |||||
| def to_c(self): | def to_c(self): | ||||
| return self._cn | return self._cn | ||||
| def __repr__(self): | def __repr__(self): | ||||
| return "{}({})".format(type(self).__qualname__, self) | |||||
| return "{}({})".format(type(self).__qualname__, repr(self._cn)) | |||||
| def __str__(self): | def __str__(self): | ||||
| return str(self._cn) | return str(self._cn) | ||||
| @@ -67,7 +67,7 @@ class Tensor(_Tensor): | |||||
| state = { | state = { | ||||
| "data": self.numpy(), | "data": self.numpy(), | ||||
| "device": str(self.device), | |||||
| "device": self.device.logical_name, | |||||
| "dtype": self.dtype, | "dtype": self.dtype, | ||||
| "qdict": self.q_dict, | "qdict": self.q_dict, | ||||
| } | } | ||||
| @@ -75,13 +75,13 @@ class Tensor(_Tensor): | |||||
| def __setstate__(self, state): | def __setstate__(self, state): | ||||
| data = state.pop("data") | data = state.pop("data") | ||||
| device = state.pop("device") | |||||
| logical_device = state.pop("device") | |||||
| if self.dmap_callback is not None: | if self.dmap_callback is not None: | ||||
| assert isinstance(device, str) | |||||
| device = self.dmap_callback(device) | |||||
| assert isinstance(logical_device, str) | |||||
| logical_device = self.dmap_callback(logical_device) | |||||
| dtype = state.pop("dtype") | dtype = state.pop("dtype") | ||||
| self.q_dict = state.pop("qdict") | self.q_dict = state.pop("qdict") | ||||
| super().__init__(data, dtype=dtype, device=device) | |||||
| super().__init__(data, dtype=dtype, device=logical_device) | |||||
| def detach(self): | def detach(self): | ||||
| r""" | r""" | ||||
| @@ -55,10 +55,16 @@ void init_common(py::module m) { | |||||
| auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | ||||
| .def_property_readonly("logical_name", [](const CompNode& cn) { | |||||
| return cn.to_string_logical(); | |||||
| }) | |||||
| .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | ||||
| .def("_set_default_device", &set_default_device) | .def("_set_default_device", &set_default_device) | ||||
| .def("_get_default_device", &get_default_device) | .def("_get_default_device", &get_default_device) | ||||
| .def("__str__", &CompNode::to_string_logical) | .def("__str__", &CompNode::to_string_logical) | ||||
| .def("__repr__", [](const CompNode& cn) { | |||||
| return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\""); | |||||
| }) | |||||
| .def_static("_sync_all", &CompNode::sync_all) | .def_static("_sync_all", &CompNode::sync_all) | ||||
| .def(py::self == py::self) | .def(py::self == py::self) | ||||
| .def_static("_get_device_count", &CompNode::get_device_count, | .def_static("_get_device_count", &CompNode::get_device_count, | ||||