You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. /**
  2. * \file imperative/python/src/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "utils.h"
  12. #ifdef WIN32
  13. #include <stdio.h>
  14. #include <windows.h>
  15. #endif
  16. #include <pybind11/operators.h>
  17. #include <atomic>
  18. #include <cstdint>
  19. #include "./imperative_rt.h"
  20. #include "megbrain/common.h"
  21. #include "megbrain/comp_node.h"
  22. #include "megbrain/imperative/blob_manager.h"
  23. #include "megbrain/imperative/profiler.h"
  24. #include "megbrain/imperative/tensor_sanity_check.h"
  25. #include "megbrain/serialization/helper.h"
  26. #if MGB_ENABLE_OPR_MM
  27. #include "megbrain/opr/mm_handler.h"
  28. #endif
  29. namespace py = pybind11;
  30. namespace {
  31. bool g_global_finalized = false;
  32. class LoggerWrapper {
  33. public:
  34. using LogLevel = mgb::LogLevel;
  35. using LogHandler = mgb::LogHandler;
  36. static void set_log_handler(py::object logger_p) {
  37. logger = logger_p;
  38. mgb::set_log_handler(py_log_handler);
  39. }
  40. static LogLevel set_log_level(LogLevel log_level) {
  41. return mgb::set_log_level(log_level);
  42. }
  43. private:
  44. static py::object logger;
  45. static void py_log_handler(mgb::LogLevel level, const char* file,
  46. const char* func, int line, const char* fmt,
  47. va_list ap) {
  48. using mgb::LogLevel;
  49. MGB_MARK_USED_VAR(file);
  50. MGB_MARK_USED_VAR(func);
  51. MGB_MARK_USED_VAR(line);
  52. if (g_global_finalized)
  53. return;
  54. const char* py_type;
  55. switch (level) {
  56. case LogLevel::DEBUG:
  57. py_type = "debug";
  58. break;
  59. case LogLevel::INFO:
  60. py_type = "info";
  61. break;
  62. case LogLevel::WARN:
  63. py_type = "warning";
  64. break;
  65. case LogLevel::ERROR:
  66. py_type = "error";
  67. break;
  68. default:
  69. throw std::runtime_error("bad log level");
  70. }
  71. std::string msg = mgb::svsprintf(fmt, ap);
  72. auto do_log = [msg = msg, py_type]() {
  73. if (logger.is_none())
  74. return;
  75. py::object _call = logger.attr(py_type);
  76. _call(py::str(msg));
  77. };
  78. if (PyGILState_Check()) {
  79. do_log();
  80. } else {
  81. py_task_q.add_task(do_log);
  82. }
  83. }
  84. };
  85. py::object LoggerWrapper::logger = py::none{};
  86. uint32_t _get_dtype_num(py::object dtype) {
  87. return static_cast<uint32_t>(npy::dtype_np2mgb(dtype.ptr()).enumv());
  88. }
  89. py::bytes _get_serialized_dtype(py::object dtype) {
  90. std::string sdtype;
  91. auto write = [&sdtype](const void* data, size_t size) {
  92. auto pos = sdtype.size();
  93. sdtype.resize(pos + size);
  94. memcpy(&sdtype[pos], data, size);
  95. };
  96. mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype.ptr()), write);
  97. return py::bytes(sdtype.data(), sdtype.size());
  98. }
  99. int fork_exec_impl(const std::string& arg0, const std::string& arg1,
  100. const std::string& arg2) {
  101. #ifdef WIN32
  102. STARTUPINFO si;
  103. PROCESS_INFORMATION pi;
  104. ZeroMemory(&si, sizeof(si));
  105. si.cb = sizeof(si);
  106. ZeroMemory(&pi, sizeof(pi));
  107. auto args_str = " " + arg1 + " " + arg2;
  108. // Start the child process.
  109. if (!CreateProcess(arg0.c_str(), // exe name
  110. const_cast<char*>(args_str.c_str()), // Command line
  111. NULL, // Process handle not inheritable
  112. NULL, // Thread handle not inheritable
  113. FALSE, // Set handle inheritance to FALSE
  114. 0, // No creation flags
  115. NULL, // Use parent's environment block
  116. NULL, // Use parent's starting directory
  117. &si, // Pointer to STARTUPINFO structure
  118. &pi) // Pointer to PROCESS_INFORMATION structure
  119. ) {
  120. mgb_log_warn("CreateProcess failed (%lu).\n", GetLastError());
  121. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]\n",
  122. arg0.c_str(), arg1.c_str(), arg2.c_str());
  123. __builtin_trap();
  124. }
  125. return pi.dwProcessId;
  126. #else
  127. auto pid = fork();
  128. if (!pid) {
  129. execl(arg0.c_str(), arg0.c_str(), arg1.c_str(), arg2.c_str(), nullptr);
  130. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n",
  131. arg0.c_str(), arg1.c_str(), arg2.c_str(), std::strerror(errno));
  132. std::terminate();
  133. }
  134. mgb_assert(pid > 0, "failed to fork: %s", std::strerror(errno));
  135. return pid;
  136. #endif
  137. }
  138. } // namespace
  139. void init_utils(py::module m) {
  140. auto atexit = py::module::import("atexit");
  141. atexit.attr("register")(py::cpp_function([]() {
  142. g_global_finalized = true;
  143. }));
  144. py::class_<std::atomic<uint64_t>>(m, "AtomicUint64")
  145. .def(py::init<>())
  146. .def(py::init<uint64_t>())
  147. .def("load",
  148. [](const std::atomic<uint64_t>& self) { return self.load(); })
  149. .def("store", [](std::atomic<uint64_t>& self,
  150. uint64_t value) { return self.store(value); })
  151. .def("fetch_add",
  152. [](std::atomic<uint64_t>& self, uint64_t value) {
  153. return self.fetch_add(value);
  154. })
  155. .def("fetch_sub",
  156. [](std::atomic<uint64_t>& self, uint64_t value) {
  157. return self.fetch_sub(value);
  158. })
  159. .def(py::self += uint64_t())
  160. .def(py::self -= uint64_t());
  161. // FIXME!!! Should add a submodule instead of using a class for logger
  162. py::class_<LoggerWrapper> logger(m, "Logger");
  163. logger.def(py::init<>())
  164. .def_static("set_log_level", &LoggerWrapper::set_log_level)
  165. .def_static("set_log_handler", &LoggerWrapper::set_log_handler);
  166. py::enum_<LoggerWrapper::LogLevel>(logger, "LogLevel")
  167. .value("Debug", LoggerWrapper::LogLevel::DEBUG)
  168. .value("Info", LoggerWrapper::LogLevel::INFO)
  169. .value("Warn", LoggerWrapper::LogLevel::WARN)
  170. .value("Error", LoggerWrapper::LogLevel::ERROR);
  171. m.def("_get_dtype_num", &_get_dtype_num,
  172. "Convert numpy dtype to internal dtype");
  173. m.def("_get_serialized_dtype", &_get_serialized_dtype,
  174. "Convert numpy dtype to internal dtype for serialization");
  175. m.def("_get_device_count", &mgb::CompNode::get_device_count,
  176. "Get total number of specific devices on this system");
  177. using mgb::imperative::ProfileEntry;
  178. py::class_<ProfileEntry>(m, "ProfileEntry")
  179. .def_readwrite("op", &ProfileEntry::op)
  180. .def_readwrite("host", &ProfileEntry::host)
  181. .def_readwrite("device_list", &ProfileEntry::device_list)
  182. .def_readwrite("inputs", &ProfileEntry::inputs)
  183. .def_readwrite("outputs", &ProfileEntry::outputs)
  184. .def_readwrite("id", &ProfileEntry::id)
  185. .def_readwrite("parent", &ProfileEntry::parent)
  186. .def_readwrite("memory", &ProfileEntry::memory)
  187. .def_readwrite("computation", &ProfileEntry::computation)
  188. .def_property_readonly("param", [](ProfileEntry& self)->std::string{
  189. if(self.param){
  190. return self.param->to_string();
  191. } else {
  192. return {};
  193. }
  194. });
  195. py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl")
  196. .def(py::init<>())
  197. .def("start", &mgb::imperative::Profiler::start)
  198. .def("stop", &mgb::imperative::Profiler::stop)
  199. .def("clear", &mgb::imperative::Profiler::clear)
  200. .def("dump", &mgb::imperative::Profiler::get_profile);
  201. using mgb::imperative::TensorSanityCheck;
  202. py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
  203. .def(py::init<>())
  204. .def("enable",
  205. [](TensorSanityCheck& checker) -> TensorSanityCheck& {
  206. checker.enable();
  207. return checker;
  208. })
  209. .def("disable",
  210. [](TensorSanityCheck& checker) {
  211. checker.disable();
  212. });
  213. #if MGB_ENABLE_OPR_MM
  214. m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
  215. py::arg("port") = 0);
  216. #else
  217. m.def("create_mm_server", []() {});
  218. #endif
  219. // Debug code, internal only
  220. m.def("_set_defrag", [](bool enable) {
  221. mgb::imperative::BlobManager::inst()->set_enable(enable);
  222. });
  223. m.def("_defrag", [](const mgb::CompNode& cn) {
  224. mgb::imperative::BlobManager::inst()->defrag(cn);
  225. });
  226. m.def("_set_fork_exec_path_for_timed_func", [](const std::string& arg0,
  227. const ::std::string arg1) {
  228. using namespace std::placeholders;
  229. mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(std::bind(
  230. fork_exec_impl, std::string{arg0}, std::string{arg1}, _1));
  231. });
  232. m.def("_timed_func_exec_cb", [](const std::string& user_data){
  233. mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
  234. });
  235. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。