You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. /**
  2. * \file imperative/python/src/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "utils.h"
  12. #ifdef WIN32
  13. #include <stdio.h>
  14. #include <windows.h>
  15. #endif
  16. #include <pybind11/operators.h>
  17. #include <atomic>
  18. #include <cstdint>
  19. #include <shared_mutex>
  20. #include "./imperative_rt.h"
  21. #include "megbrain/common.h"
  22. #include "megbrain/comp_node.h"
  23. #include "megbrain/imperative/blob_manager.h"
  24. #include "megbrain/imperative/profiler.h"
  25. #include "megbrain/imperative/tensor_sanity_check.h"
  26. #include "megbrain/serialization/helper.h"
  27. #include "megbrain/utils/persistent_cache.h"
  28. #if MGB_ENABLE_OPR_MM
  29. #include "megbrain/opr/mm_handler.h"
  30. #endif
  31. namespace py = pybind11;
  32. namespace {
  //! Set to true by the atexit hook registered in init_utils; once set,
  //! py_log_handler returns early and no longer touches the Python logger.
  //! Atomic because the flag is written by the atexit callback while the log
  //! handler may read it from a thread that does not hold the GIL.
  std::atomic<bool> g_global_finalized{false};
  34. class LoggerWrapper {
  35. public:
  36. using LogLevel = mgb::LogLevel;
  37. using LogHandler = mgb::LogHandler;
  38. static void set_log_handler(py::object logger_p) {
  39. logger = logger_p;
  40. mgb::set_log_handler(py_log_handler);
  41. }
  42. static LogLevel set_log_level(LogLevel log_level) {
  43. return mgb::set_log_level(log_level);
  44. }
  45. private:
  46. static py::object logger;
  47. static void py_log_handler(mgb::LogLevel level, const char* file,
  48. const char* func, int line, const char* fmt,
  49. va_list ap) {
  50. using mgb::LogLevel;
  51. MGB_MARK_USED_VAR(file);
  52. MGB_MARK_USED_VAR(func);
  53. MGB_MARK_USED_VAR(line);
  54. if (g_global_finalized)
  55. return;
  56. const char* py_type;
  57. switch (level) {
  58. case LogLevel::DEBUG:
  59. py_type = "debug";
  60. break;
  61. case LogLevel::INFO:
  62. py_type = "info";
  63. break;
  64. case LogLevel::WARN:
  65. py_type = "warning";
  66. break;
  67. case LogLevel::ERROR:
  68. py_type = "error";
  69. break;
  70. default:
  71. throw std::runtime_error("bad log level");
  72. }
  73. std::string msg = mgb::svsprintf(fmt, ap);
  74. auto do_log = [msg = msg, py_type]() {
  75. if (logger.is_none())
  76. return;
  77. py::object _call = logger.attr(py_type);
  78. _call(py::str(msg));
  79. };
  80. if (PyGILState_Check()) {
  81. do_log();
  82. } else {
  83. py_task_q.add_task(do_log);
  84. }
  85. }
  86. };
  87. py::object LoggerWrapper::logger = py::none{};
  88. uint32_t _get_dtype_num(py::object dtype) {
  89. return static_cast<uint32_t>(npy::dtype_np2mgb(dtype.ptr()).enumv());
  90. }
  91. py::bytes _get_serialized_dtype(py::object dtype) {
  92. std::string sdtype;
  93. auto write = [&sdtype](const void* data, size_t size) {
  94. auto pos = sdtype.size();
  95. sdtype.resize(pos + size);
  96. memcpy(&sdtype[pos], data, size);
  97. };
  98. mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype.ptr()), write);
  99. return py::bytes(sdtype.data(), sdtype.size());
  100. }
  101. int fork_exec_impl(const std::string& arg0, const std::string& arg1,
  102. const std::string& arg2) {
  103. #ifdef WIN32
  104. STARTUPINFO si;
  105. PROCESS_INFORMATION pi;
  106. ZeroMemory(&si, sizeof(si));
  107. si.cb = sizeof(si);
  108. ZeroMemory(&pi, sizeof(pi));
  109. auto args_str = " " + arg1 + " " + arg2;
  110. // Start the child process.
  111. if (!CreateProcess(arg0.c_str(), // exe name
  112. const_cast<char*>(args_str.c_str()), // Command line
  113. NULL, // Process handle not inheritable
  114. NULL, // Thread handle not inheritable
  115. FALSE, // Set handle inheritance to FALSE
  116. 0, // No creation flags
  117. NULL, // Use parent's environment block
  118. NULL, // Use parent's starting directory
  119. &si, // Pointer to STARTUPINFO structure
  120. &pi) // Pointer to PROCESS_INFORMATION structure
  121. ) {
  122. mgb_log_warn("CreateProcess failed (%lu).\n", GetLastError());
  123. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]\n",
  124. arg0.c_str(), arg1.c_str(), arg2.c_str());
  125. __builtin_trap();
  126. }
  127. return pi.dwProcessId;
  128. #else
  129. auto pid = fork();
  130. if (!pid) {
  131. execl(arg0.c_str(), arg0.c_str(), arg1.c_str(), arg2.c_str(), nullptr);
  132. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n",
  133. arg0.c_str(), arg1.c_str(), arg2.c_str(), std::strerror(errno));
  134. std::terminate();
  135. }
  136. mgb_assert(pid > 0, "failed to fork: %s", std::strerror(errno));
  137. return pid;
  138. #endif
  139. }
  140. } // namespace
  141. void init_utils(py::module m) {
  142. auto atexit = py::module::import("atexit");
  143. atexit.attr("register")(py::cpp_function([]() {
  144. g_global_finalized = true;
  145. }));
  146. py::class_<std::atomic<uint64_t>>(m, "AtomicUint64")
  147. .def(py::init<>())
  148. .def(py::init<uint64_t>())
  149. .def("load",
  150. [](const std::atomic<uint64_t>& self) { return self.load(); })
  151. .def("store", [](std::atomic<uint64_t>& self,
  152. uint64_t value) { return self.store(value); })
  153. .def("fetch_add",
  154. [](std::atomic<uint64_t>& self, uint64_t value) {
  155. return self.fetch_add(value);
  156. })
  157. .def("fetch_sub",
  158. [](std::atomic<uint64_t>& self, uint64_t value) {
  159. return self.fetch_sub(value);
  160. })
  161. .def(py::self += uint64_t())
  162. .def(py::self -= uint64_t());
  163. // FIXME!!! Should add a submodule instead of using a class for logger
  164. py::class_<LoggerWrapper> logger(m, "Logger");
  165. logger.def(py::init<>())
  166. .def_static("set_log_level", &LoggerWrapper::set_log_level)
  167. .def_static("set_log_handler", &LoggerWrapper::set_log_handler);
  168. py::enum_<LoggerWrapper::LogLevel>(logger, "LogLevel")
  169. .value("Debug", LoggerWrapper::LogLevel::DEBUG)
  170. .value("Info", LoggerWrapper::LogLevel::INFO)
  171. .value("Warn", LoggerWrapper::LogLevel::WARN)
  172. .value("Error", LoggerWrapper::LogLevel::ERROR);
  173. m.def("_get_dtype_num", &_get_dtype_num,
  174. "Convert numpy dtype to internal dtype");
  175. m.def("_get_serialized_dtype", &_get_serialized_dtype,
  176. "Convert numpy dtype to internal dtype for serialization");
  177. m.def("_get_device_count", &mgb::CompNode::get_device_count,
  178. "Get total number of specific devices on this system");
  179. using mgb::imperative::TensorSanityCheck;
  180. py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
  181. .def(py::init<>())
  182. .def("enable",
  183. [](TensorSanityCheck& checker) -> TensorSanityCheck& {
  184. checker.enable();
  185. return checker;
  186. })
  187. .def("disable",
  188. [](TensorSanityCheck& checker) {
  189. checker.disable();
  190. });
  191. #if MGB_ENABLE_OPR_MM
  192. m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
  193. py::arg("port") = 0);
  194. #else
  195. m.def("create_mm_server", []() {});
  196. #endif
  197. // Debug code, internal only
  198. m.def("_set_defrag", [](bool enable) {
  199. mgb::imperative::BlobManager::inst()->set_enable(enable);
  200. });
  201. m.def("_defrag", [](const mgb::CompNode& cn) {
  202. mgb::imperative::BlobManager::inst()->defrag(cn);
  203. });
  204. m.def("_set_fork_exec_path_for_timed_func", [](const std::string& arg0,
  205. const ::std::string arg1) {
  206. using namespace std::placeholders;
  207. mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(std::bind(
  208. fork_exec_impl, std::string{arg0}, std::string{arg1}, _1));
  209. });
  210. m.def("_timed_func_exec_cb", [](const std::string& user_data){
  211. mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
  212. });
  213. using mgb::PersistentCache;
  214. class PyPersistentCache: public mgb::PersistentCache {
  215. private:
  216. using KeyPair = std::pair<std::string, std::string>;
  217. using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>;
  218. std::shared_mutex m_mutex;
  219. std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
  220. static size_t hash_key_pair(const KeyPair& kp) {
  221. std::hash<std::string> hasher;
  222. return hasher(kp.first) ^ hasher(kp.second);
  223. }
  224. std::string blob_to_str(const Blob& key) {
  225. return std::string(reinterpret_cast<const char*>(key.ptr), key.size);
  226. }
  227. BlobPtr copy_blob(const Blob& blob) {
  228. auto blob_deleter = [](Blob* blob){
  229. if (blob) {
  230. std::free(const_cast<void*>(blob->ptr));
  231. delete blob;
  232. }
  233. };
  234. auto blob_ptr = BlobPtr{ new Blob(), blob_deleter };
  235. blob_ptr->ptr = std::malloc(blob.size);
  236. std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size);
  237. blob_ptr->size = blob.size;
  238. return blob_ptr;
  239. }
  240. BlobPtr str_to_blob(const std::string& str) {
  241. auto blob = Blob{ str.data(), str.size() };
  242. return copy_blob(blob);
  243. }
  244. std::unique_ptr<Blob, void(*)(Blob*)> empty_blob() {
  245. return BlobPtr{ nullptr, [](Blob* blob){} };
  246. }
  247. public:
  248. mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
  249. auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> {
  250. PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
  251. };
  252. KeyPair kp = { category, blob_to_str(key) };
  253. std::shared_lock<decltype(m_mutex)> rlock;
  254. auto iter = m_local_cache.find(kp);
  255. if (iter == m_local_cache.end()) {
  256. auto py_ret = py_get(category, key);
  257. if (!py_ret.valid()) {
  258. iter = m_local_cache.insert({kp, empty_blob()}).first;
  259. } else {
  260. iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first;
  261. }
  262. }
  263. if (iter->second) {
  264. return *iter->second;
  265. } else {
  266. return {};
  267. }
  268. }
  269. void put(const std::string& category, const Blob& key, const Blob& value) override {
  270. KeyPair kp = { category, blob_to_str(key) };
  271. std::unique_lock<decltype(m_mutex)> wlock;
  272. m_local_cache.insert_or_assign(kp, copy_blob(value));
  273. PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
  274. }
  275. };
  276. py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(m, "PersistentCache")
  277. .def(py::init<>())
  278. .def("get", &PersistentCache::get)
  279. .def("put", &PersistentCache::put)
  280. .def("reg", &PersistentCache::set_impl);
  281. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台