You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. /**
  2. * \file imperative/python/src/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "utils.h"
  12. #ifdef WIN32
  13. #include <stdio.h>
  14. #include <windows.h>
  15. #endif
  16. #include <pybind11/operators.h>
  17. #include <atomic>
  18. #include <cstdint>
  19. #include <shared_mutex>
  20. #include "./imperative_rt.h"
  21. #include "megbrain/common.h"
  22. #include "megbrain/comp_node.h"
  23. #include "megbrain/imperative/blob_manager.h"
  24. #include "megbrain/imperative/persistent_cache.h"
  25. #include "megbrain/imperative/profiler.h"
  26. #include "megbrain/imperative/tensor_sanity_check.h"
  27. #include "megbrain/serialization/helper.h"
  28. #include "megbrain/utils/persistent_cache.h"
  29. #if MGB_ENABLE_OPR_MM
  30. #include "megbrain/opr/mm_handler.h"
  31. #endif
  32. namespace py = pybind11;
  33. namespace {
  34. bool g_global_finalized = false;
  35. class LoggerWrapper {
  36. public:
  37. using LogLevel = mgb::LogLevel;
  38. using LogHandler = mgb::LogHandler;
  39. static void set_log_handler(py::object logger_p) {
  40. logger = logger_p;
  41. mgb::set_log_handler(py_log_handler);
  42. }
  43. static LogLevel set_log_level(LogLevel log_level) {
  44. return mgb::set_log_level(log_level);
  45. }
  46. private:
  47. static py::object logger;
  48. static void py_log_handler(
  49. mgb::LogLevel level, const char* file, const char* func, int line,
  50. const char* fmt, va_list ap) {
  51. using mgb::LogLevel;
  52. MGB_MARK_USED_VAR(file);
  53. MGB_MARK_USED_VAR(func);
  54. MGB_MARK_USED_VAR(line);
  55. if (g_global_finalized)
  56. return;
  57. const char* py_type;
  58. switch (level) {
  59. case LogLevel::DEBUG:
  60. py_type = "debug";
  61. break;
  62. case LogLevel::INFO:
  63. py_type = "info";
  64. break;
  65. case LogLevel::WARN:
  66. py_type = "warning";
  67. break;
  68. case LogLevel::ERROR:
  69. py_type = "error";
  70. break;
  71. default:
  72. throw std::runtime_error("bad log level");
  73. }
  74. std::string msg = mgb::svsprintf(fmt, ap);
  75. auto do_log = [msg = msg, py_type]() {
  76. if (logger.is_none())
  77. return;
  78. py::object _call = logger.attr(py_type);
  79. _call(py::str(msg));
  80. };
  81. if (PyGILState_Check()) {
  82. do_log();
  83. } else {
  84. py_task_q.add_task(do_log);
  85. }
  86. }
  87. };
  88. py::object LoggerWrapper::logger = py::none{};
  89. uint32_t _get_dtype_num(py::object dtype) {
  90. return static_cast<uint32_t>(npy::dtype_np2mgb(dtype.ptr()).enumv());
  91. }
  92. py::bytes _get_serialized_dtype(py::object dtype) {
  93. std::string sdtype;
  94. auto write = [&sdtype](const void* data, size_t size) {
  95. auto pos = sdtype.size();
  96. sdtype.resize(pos + size);
  97. memcpy(&sdtype[pos], data, size);
  98. };
  99. mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype.ptr()), write);
  100. return py::bytes(sdtype.data(), sdtype.size());
  101. }
  102. int fork_exec_impl(
  103. const std::string& arg0, const std::string& arg1, const std::string& arg2) {
  104. #ifdef WIN32
  105. STARTUPINFO si;
  106. PROCESS_INFORMATION pi;
  107. ZeroMemory(&si, sizeof(si));
  108. si.cb = sizeof(si);
  109. ZeroMemory(&pi, sizeof(pi));
  110. auto args_str = " " + arg1 + " " + arg2;
  111. // Start the child process.
  112. if (!CreateProcess(
  113. arg0.c_str(), // exe name
  114. const_cast<char*>(args_str.c_str()), // Command line
  115. NULL, // Process handle not inheritable
  116. NULL, // Thread handle not inheritable
  117. FALSE, // Set handle inheritance to FALSE
  118. 0, // No creation flags
  119. NULL, // Use parent's environment block
  120. NULL, // Use parent's starting directory
  121. &si, // Pointer to STARTUPINFO structure
  122. &pi) // Pointer to PROCESS_INFORMATION structure
  123. ) {
  124. mgb_log_warn("CreateProcess failed (%lu).\n", GetLastError());
  125. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]\n", arg0.c_str(),
  126. arg1.c_str(), arg2.c_str());
  127. __builtin_trap();
  128. }
  129. return pi.dwProcessId;
  130. #else
  131. auto pid = fork();
  132. if (!pid) {
  133. execl(arg0.c_str(), arg0.c_str(), arg1.c_str(), arg2.c_str(), nullptr);
  134. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n", arg0.c_str(),
  135. arg1.c_str(), arg2.c_str(), std::strerror(errno));
  136. std::terminate();
  137. }
  138. mgb_assert(pid > 0, "failed to fork: %s", std::strerror(errno));
  139. return pid;
  140. #endif
  141. }
  142. } // namespace
  143. void init_utils(py::module m) {
  144. auto atexit = py::module::import("atexit");
  145. atexit.attr("register")(py::cpp_function([]() { g_global_finalized = true; }));
  146. py::class_<std::atomic<uint64_t>>(m, "AtomicUint64")
  147. .def(py::init<>())
  148. .def(py::init<uint64_t>())
  149. .def("load", [](const std::atomic<uint64_t>& self) { return self.load(); })
  150. .def("store", [](std::atomic<uint64_t>& self,
  151. uint64_t value) { return self.store(value); })
  152. .def("fetch_add", [](std::atomic<uint64_t>& self,
  153. uint64_t value) { return self.fetch_add(value); })
  154. .def("fetch_sub", [](std::atomic<uint64_t>& self,
  155. uint64_t value) { return self.fetch_sub(value); })
  156. .def(py::self += uint64_t())
  157. .def(py::self -= uint64_t());
  158. // FIXME!!! Should add a submodule instead of using a class for logger
  159. py::class_<LoggerWrapper> logger(m, "Logger");
  160. logger.def(py::init<>())
  161. .def_static("set_log_level", &LoggerWrapper::set_log_level)
  162. .def_static("set_log_handler", &LoggerWrapper::set_log_handler);
  163. py::enum_<LoggerWrapper::LogLevel>(logger, "LogLevel")
  164. .value("Debug", LoggerWrapper::LogLevel::DEBUG)
  165. .value("Info", LoggerWrapper::LogLevel::INFO)
  166. .value("Warn", LoggerWrapper::LogLevel::WARN)
  167. .value("Error", LoggerWrapper::LogLevel::ERROR);
  168. m.def("_get_dtype_num", &_get_dtype_num, "Convert numpy dtype to internal dtype");
  169. m.def("_get_serialized_dtype", &_get_serialized_dtype,
  170. "Convert numpy dtype to internal dtype for serialization");
  171. m.def("_get_device_count", &mgb::CompNode::get_device_count,
  172. "Get total number of specific devices on this system");
  173. m.def("_try_coalesce_all_free_memory", &mgb::CompNode::try_coalesce_all_free_memory,
  174. "This function will try it best to free all consecutive free chunks back to "
  175. "operating system");
  176. using mgb::imperative::TensorSanityCheck;
  177. py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
  178. .def(py::init<>())
  179. .def("enable",
  180. [](TensorSanityCheck& checker) -> TensorSanityCheck& {
  181. checker.enable();
  182. return checker;
  183. })
  184. .def("disable", [](TensorSanityCheck& checker) { checker.disable(); });
  185. #if MGB_ENABLE_OPR_MM
  186. m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
  187. py::arg("port") = 0);
  188. #else
  189. m.def("create_mm_server", []() {});
  190. #endif
  191. // Debug code, internal only
  192. m.def("_defrag", [](const mgb::CompNode& cn) {
  193. mgb::imperative::BlobManager::inst()->defrag(cn);
  194. });
  195. m.def("_set_fork_exec_path_for_timed_func",
  196. [](const std::string& arg0, const ::std::string arg1) {
  197. using namespace std::placeholders;
  198. mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(std::bind(
  199. fork_exec_impl, std::string{arg0}, std::string{arg1}, _1));
  200. });
  201. m.def("_timed_func_exec_cb", [](const std::string& user_data) {
  202. mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
  203. });
  204. using PersistentCache = mgb::PersistentCache;
  205. using ExtendedPersistentCache =
  206. mgb::imperative::persistent_cache::ExtendedPersistentCache;
  207. struct PersistentCacheManager {
  208. std::shared_ptr<ExtendedPersistentCache> instance;
  209. bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) {
  210. if (cache) {
  211. instance = cache;
  212. PersistentCache::set_impl(cache);
  213. return true;
  214. }
  215. return false;
  216. }
  217. bool open_file(std::string path) {
  218. return try_reg(mgb::imperative::persistent_cache::make_in_file(path));
  219. }
  220. std::optional<size_t> clean() {
  221. if (instance) {
  222. return instance->clear();
  223. }
  224. return {};
  225. }
  226. void put(std::string category, std::string key, std::string value) {
  227. PersistentCache::inst().put(
  228. category, {key.data(), key.size()}, {value.data(), value.size()});
  229. }
  230. py::object get(std::string category, std::string key) {
  231. auto value =
  232. PersistentCache::inst().get(category, {key.data(), key.size()});
  233. if (value.valid()) {
  234. return py::bytes(std::string((const char*)value->ptr, value->size));
  235. } else {
  236. return py::none();
  237. }
  238. }
  239. };
  240. py::class_<PersistentCacheManager>(m, "PersistentCacheManager")
  241. .def(py::init<>())
  242. .def("try_open_file", &PersistentCacheManager::open_file)
  243. .def("clean", &PersistentCacheManager::clean)
  244. .def("put", &PersistentCacheManager::put)
  245. .def("get", &PersistentCacheManager::get);
  246. }