GitOrigin-RevId: 722c4debfa
tags/v1.0.0-rc1
| @@ -17,8 +17,6 @@ __all__ = [ | |||||
| "set_default_device", | "set_default_device", | ||||
| ] | ] | ||||
| _default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux") | |||||
| def _valid_device(inp): | def _valid_device(inp): | ||||
| if isinstance(inp, str) and len(inp) == 4: | if isinstance(inp, str) and len(inp) == 4: | ||||
| @@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"): | |||||
| It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | ||||
| """ | """ | ||||
| global _default_device # pylint: disable=global-statement | |||||
| assert _valid_device(device), "Invalid device name {}".format(device) | assert _valid_device(device), "Invalid device name {}".format(device) | ||||
| _default_device = device | |||||
| CompNode._set_default_device(device) | |||||
| def get_default_device() -> str: | def get_default_device() -> str: | ||||
| @@ -86,4 +83,7 @@ def get_default_device() -> str: | |||||
| It returns the value set by :func:`~.set_default_device`. | It returns the value set by :func:`~.set_default_device`. | ||||
| """ | """ | ||||
| return _default_device | |||||
| return CompNode._get_default_device() | |||||
| set_default_device(os.getenv("MGE_DEFAULT_DEVICE", "xpux")) | |||||
| @@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) { | |||||
| &XTensorND::template copy_from_fixlayout<HostTensorStorage>)); | &XTensorND::template copy_from_fixlayout<HostTensorStorage>)); | ||||
| } | } | ||||
| std::string default_device = "xpux"; | |||||
| } // namespace | } // namespace | ||||
| void set_default_device(const std::string &device) { | |||||
| default_device = device; | |||||
| } | |||||
| std::string get_default_device() { | |||||
| return default_device; | |||||
| } | |||||
| void init_common(py::module m) { | void init_common(py::module m) { | ||||
| auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | ||||
| .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | ||||
| .def("_set_default_device", &set_default_device) | |||||
| .def("_get_default_device", &get_default_device) | |||||
| .def("__str__", &CompNode::to_string_logical) | .def("__str__", &CompNode::to_string_logical) | ||||
| .def_static("_sync_all", &CompNode::sync_all) | .def_static("_sync_all", &CompNode::sync_all) | ||||
| .def(py::self == py::self) | .def(py::self == py::self) | ||||
| @@ -14,3 +14,6 @@ | |||||
| #include "./helper.h" | #include "./helper.h" | ||||
| void init_common(pybind11::module m); | void init_common(pybind11::module m); | ||||
| void set_default_device(const std::string &device); | |||||
| std::string get_default_device(); | |||||
| @@ -19,6 +19,7 @@ | |||||
| #include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
| #include "./helper.h" | #include "./helper.h" | ||||
| #include "megbrain/plugin/profiler.h" | #include "megbrain/plugin/profiler.h" | ||||
| #include "./common.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| @@ -230,7 +231,7 @@ void init_graph_rt(py::module m) { | |||||
| m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | ||||
| if (!cn.valid()) { | if (!cn.valid()) { | ||||
| cn = CompNode::load("xpux"); | |||||
| cn = CompNode::load(get_default_device()); | |||||
| } | } | ||||
| auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | ||||
| return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "megbrain/imperative/interpreter.h" | #include "megbrain/imperative/interpreter.h" | ||||
| #include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
| #include "./helper.h" | #include "./helper.h" | ||||
| #include "./common.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| @@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) { | |||||
| py::class_<Interpreter::Channel>(m, "Interpreter") | py::class_<Interpreter::Channel>(m, "Interpreter") | ||||
| .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { | .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { | ||||
| if (!cn.valid()) { | if (!cn.valid()) { | ||||
| cn = CompNode::load("xpux"); | |||||
| cn = CompNode::load(get_default_device()); | |||||
| } | } | ||||
| constexpr int size_threshhold = TensorShape::MAX_NDIM; | constexpr int size_threshhold = TensorShape::MAX_NDIM; | ||||
| if (data.size() > size_threshhold) { | if (data.size() > size_threshhold) { | ||||