You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. #include "utils.h"
  2. #ifdef WIN32
  3. #include <stdio.h>
  4. #include <windows.h>
  5. #endif
  6. #include <pybind11/operators.h>
  7. #include <atomic>
  8. #include <cstdint>
  9. #include <shared_mutex>
  10. #include "./imperative_rt.h"
  11. #include "megbrain/common.h"
  12. #include "megbrain/comp_node.h"
  13. #include "megbrain/imperative/blob_manager.h"
  14. #include "megbrain/imperative/persistent_cache.h"
  15. #include "megbrain/imperative/profiler.h"
  16. #include "megbrain/imperative/tensor_sanity_check.h"
  17. #include "megbrain/serialization/helper.h"
  18. #include "megbrain/utils/persistent_cache.h"
  19. #if MGB_ENABLE_OPR_MM
  20. #include "megbrain/opr/mm_handler.h"
  21. #endif
  22. namespace py = pybind11;
  23. namespace {
  24. bool g_global_finalized = false;
  25. class LoggerWrapper {
  26. public:
  27. using LogLevel = mgb::LogLevel;
  28. using LogHandler = mgb::LogHandler;
  29. static void set_log_handler(py::object logger_p) {
  30. logger = logger_p;
  31. mgb::set_log_handler(py_log_handler);
  32. }
  33. static LogLevel set_log_level(LogLevel log_level) {
  34. return mgb::set_log_level(log_level);
  35. }
  36. private:
  37. static py::object logger;
  38. static void py_log_handler(
  39. mgb::LogLevel level, const char* file, const char* func, int line,
  40. const char* fmt, va_list ap) {
  41. using mgb::LogLevel;
  42. MGB_MARK_USED_VAR(file);
  43. MGB_MARK_USED_VAR(func);
  44. MGB_MARK_USED_VAR(line);
  45. if (g_global_finalized)
  46. return;
  47. const char* py_type;
  48. switch (level) {
  49. case LogLevel::DEBUG:
  50. py_type = "debug";
  51. break;
  52. case LogLevel::INFO:
  53. py_type = "info";
  54. break;
  55. case LogLevel::WARN:
  56. py_type = "warning";
  57. break;
  58. case LogLevel::ERROR:
  59. py_type = "error";
  60. break;
  61. default:
  62. throw std::runtime_error("bad log level");
  63. }
  64. std::string msg = mgb::svsprintf(fmt, ap);
  65. auto do_log = [msg = msg, py_type]() {
  66. if (logger.is_none())
  67. return;
  68. py::object _call = logger.attr(py_type);
  69. _call(py::str(msg));
  70. };
  71. if (PyGILState_Check()) {
  72. do_log();
  73. } else {
  74. py_task_q.add_task(do_log);
  75. }
  76. }
  77. };
  78. py::object LoggerWrapper::logger = py::none{};
  79. uint32_t _get_dtype_num(py::object dtype) {
  80. return static_cast<uint32_t>(npy::dtype_np2mgb(dtype.ptr()).enumv());
  81. }
  82. py::bytes _get_serialized_dtype(py::object dtype) {
  83. std::string sdtype;
  84. auto write = [&sdtype](const void* data, size_t size) {
  85. auto pos = sdtype.size();
  86. sdtype.resize(pos + size);
  87. memcpy(&sdtype[pos], data, size);
  88. };
  89. mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype.ptr()), write);
  90. return py::bytes(sdtype.data(), sdtype.size());
  91. }
  92. int fork_exec_impl(
  93. const std::string& arg0, const std::string& arg1, const std::string& arg2) {
  94. #ifdef WIN32
  95. STARTUPINFO si;
  96. PROCESS_INFORMATION pi;
  97. ZeroMemory(&si, sizeof(si));
  98. si.cb = sizeof(si);
  99. ZeroMemory(&pi, sizeof(pi));
  100. auto args_str = " " + arg1 + " " + arg2;
  101. // Start the child process.
  102. if (!CreateProcess(
  103. arg0.c_str(), // exe name
  104. const_cast<char*>(args_str.c_str()), // Command line
  105. NULL, // Process handle not inheritable
  106. NULL, // Thread handle not inheritable
  107. FALSE, // Set handle inheritance to FALSE
  108. 0, // No creation flags
  109. NULL, // Use parent's environment block
  110. NULL, // Use parent's starting directory
  111. &si, // Pointer to STARTUPINFO structure
  112. &pi) // Pointer to PROCESS_INFORMATION structure
  113. ) {
  114. mgb_log_warn("CreateProcess failed (%lu).\n", GetLastError());
  115. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]\n", arg0.c_str(),
  116. arg1.c_str(), arg2.c_str());
  117. __builtin_trap();
  118. }
  119. return pi.dwProcessId;
  120. #else
  121. auto pid = fork();
  122. if (!pid) {
  123. execl(arg0.c_str(), arg0.c_str(), arg1.c_str(), arg2.c_str(), nullptr);
  124. fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n", arg0.c_str(),
  125. arg1.c_str(), arg2.c_str(), std::strerror(errno));
  126. std::terminate();
  127. }
  128. mgb_assert(pid > 0, "failed to fork: %s", std::strerror(errno));
  129. return pid;
  130. #endif
  131. }
  132. } // namespace
  133. void init_utils(py::module m) {
  134. auto atexit = py::module::import("atexit");
  135. atexit.attr("register")(py::cpp_function([]() { g_global_finalized = true; }));
  136. py::class_<std::atomic<uint64_t>>(m, "AtomicUint64")
  137. .def(py::init<>())
  138. .def(py::init<uint64_t>())
  139. .def("load", [](const std::atomic<uint64_t>& self) { return self.load(); })
  140. .def("store", [](std::atomic<uint64_t>& self,
  141. uint64_t value) { return self.store(value); })
  142. .def("fetch_add", [](std::atomic<uint64_t>& self,
  143. uint64_t value) { return self.fetch_add(value); })
  144. .def("fetch_sub", [](std::atomic<uint64_t>& self,
  145. uint64_t value) { return self.fetch_sub(value); })
  146. .def(py::self += uint64_t())
  147. .def(py::self -= uint64_t());
  148. // FIXME!!! Should add a submodule instead of using a class for logger
  149. py::class_<LoggerWrapper> logger(m, "Logger");
  150. logger.def(py::init<>())
  151. .def_static("set_log_level", &LoggerWrapper::set_log_level)
  152. .def_static("set_log_handler", &LoggerWrapper::set_log_handler);
  153. py::enum_<LoggerWrapper::LogLevel>(logger, "LogLevel")
  154. .value("Debug", LoggerWrapper::LogLevel::DEBUG)
  155. .value("Info", LoggerWrapper::LogLevel::INFO)
  156. .value("Warn", LoggerWrapper::LogLevel::WARN)
  157. .value("Error", LoggerWrapper::LogLevel::ERROR);
  158. m.def("_get_dtype_num", &_get_dtype_num, "Convert numpy dtype to internal dtype");
  159. m.def("_get_serialized_dtype", &_get_serialized_dtype,
  160. "Convert numpy dtype to internal dtype for serialization");
  161. m.def("_get_device_count", &mgb::CompNode::get_device_count,
  162. "Get total number of specific devices on this system");
  163. m.def("_try_coalesce_all_free_memory", &mgb::CompNode::try_coalesce_all_free_memory,
  164. "This function will try it best to free all consecutive free chunks back to "
  165. "operating system");
  166. using mgb::imperative::TensorSanityCheck;
  167. py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
  168. .def(py::init<>())
  169. .def("enable",
  170. [](TensorSanityCheck& checker) -> TensorSanityCheck& {
  171. checker.enable();
  172. return checker;
  173. })
  174. .def("disable", [](TensorSanityCheck& checker) { checker.disable(); });
  175. #if MGB_ENABLE_OPR_MM
  176. m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"),
  177. py::arg("port") = 0);
  178. #else
  179. m.def("create_mm_server", []() {});
  180. #endif
  181. // Debug code, internal only
  182. m.def("_defrag", [](const mgb::CompNode& cn) {
  183. mgb::imperative::BlobManager::inst()->defrag(cn);
  184. });
  185. m.def("_set_fork_exec_path_for_timed_func",
  186. [](const std::string& arg0, const ::std::string arg1) {
  187. using namespace std::placeholders;
  188. mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(std::bind(
  189. fork_exec_impl, std::string{arg0}, std::string{arg1}, _1));
  190. });
  191. m.def("_timed_func_exec_cb", [](const std::string& user_data) {
  192. mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
  193. });
  194. using PersistentCache = mgb::PersistentCache;
  195. using ExtendedPersistentCache =
  196. mgb::imperative::persistent_cache::ExtendedPersistentCache;
  197. struct ConfigurablePersistentCache : mgb::PersistentCache {
  198. struct Config {
  199. std::string type;
  200. std::unordered_map<std::string, std::string> args;
  201. std::string on_success;
  202. std::string on_fail;
  203. };
  204. std::shared_ptr<ExtendedPersistentCache> impl;
  205. std::optional<Config> impl_config;
  206. std::vector<Config> configs;
  207. void add_config(
  208. std::string type, std::unordered_map<std::string, std::string> args,
  209. std::string on_success, std::string on_fail) {
  210. configs.push_back({type, args, on_success, on_fail});
  211. }
  212. std::optional<size_t> clean() { return get_impl()->clear(); }
  213. void load_config() {
  214. std::optional<std::string> err_msg;
  215. for (size_t i = 0; i < configs.size(); ++i) {
  216. auto& config = configs[i];
  217. if (err_msg) {
  218. mgb_log_warn("try fallback to %s cache", config.type.c_str());
  219. } else {
  220. err_msg.emplace();
  221. }
  222. auto cache = ExtendedPersistentCache::make_from_config(
  223. config.type, config.args, *err_msg);
  224. if (!cache) {
  225. mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str());
  226. } else {
  227. impl = cache;
  228. impl_config = config;
  229. break;
  230. }
  231. }
  232. mgb_assert(impl_config.has_value(), "not valid config");
  233. }
  234. std::shared_ptr<ExtendedPersistentCache> get_impl() {
  235. if (!impl) {
  236. load_config();
  237. }
  238. return impl;
  239. }
  240. virtual mgb::Maybe<Blob> get(const std::string& category, const Blob& key) {
  241. return get_impl()->get(category, key);
  242. }
  243. virtual void put(
  244. const std::string& category, const Blob& key, const Blob& value) {
  245. return get_impl()->put(category, key, value);
  246. }
  247. virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); }
  248. py::object py_get(std::string category, std::string key) {
  249. auto value = get_impl()->get(category, {key.data(), key.size()});
  250. if (value.valid()) {
  251. return py::bytes(std::string((const char*)value->ptr, value->size));
  252. } else {
  253. return py::none();
  254. }
  255. }
  256. void py_put(std::string category, std::string key, std::string value) {
  257. get_impl()->put(
  258. category, {key.data(), key.size()}, {value.data(), value.size()});
  259. }
  260. void flush() {
  261. if (impl) {
  262. impl->flush();
  263. }
  264. }
  265. };
  266. auto PyConfigurablePersistentCache =
  267. py::class_<
  268. ConfigurablePersistentCache,
  269. std::shared_ptr<ConfigurablePersistentCache>>(m, "PersistentCache")
  270. .def(py::init<>())
  271. .def("add_config", &ConfigurablePersistentCache::add_config)
  272. .def("reg",
  273. [](std::shared_ptr<ConfigurablePersistentCache> inst) {
  274. PersistentCache::set_impl(inst);
  275. })
  276. .def("clean", &ConfigurablePersistentCache::clean)
  277. .def("get", &ConfigurablePersistentCache::py_get)
  278. .def("put", &ConfigurablePersistentCache::py_put)
  279. .def_readonly("config", &ConfigurablePersistentCache::impl_config)
  280. .def("flush", &ConfigurablePersistentCache::flush);
  281. py::class_<ConfigurablePersistentCache::Config>(
  282. PyConfigurablePersistentCache, "Config")
  283. .def_readwrite("type", &ConfigurablePersistentCache::Config::type)
  284. .def_readwrite("args", &ConfigurablePersistentCache::Config::args)
  285. .def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail)
  286. .def_readwrite(
  287. "on_success", &ConfigurablePersistentCache::Config::on_success);
  288. }