| @@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
| from .serialization import load, save | |||
| from .tensor import Parameter, Tensor, tensor | |||
| from .utils import comp_graph_tools as cgtools | |||
| from .utils import persistent_cache | |||
| from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer | |||
| from .version import __version__ | |||
| _set_fork_exec_path_for_timed_func( | |||
| @@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func( | |||
| os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||
| ) | |||
| atexit.register(_close) | |||
| del _set_fork_exec_path_for_timed_func | |||
| _exit_handlers = [] | |||
| def _run_exit_handlers(): | |||
| for handler in _exit_handlers: | |||
| for handler in reversed(_exit_handlers): | |||
| handler() | |||
| _exit_handlers.clear() | |||
| @@ -117,6 +115,13 @@ def _atexit(handler): | |||
| _exit_handlers.append(handler) | |||
| _atexit(_close) | |||
| _persistent_cache = _PersistentCacheOnServer() | |||
| _persistent_cache.reg() | |||
| _atexit(_persistent_cache.flush) | |||
| # subpackages | |||
| import megengine.amp | |||
| import megengine.autodiff | |||
| @@ -132,5 +137,3 @@ import megengine.quantization | |||
| import megengine.random | |||
| import megengine.utils | |||
| import megengine.traced_module | |||
| persistent_cache.get_manager() | |||
| @@ -8,87 +8,114 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import argparse | |||
| import contextlib | |||
| import getpass | |||
| import os | |||
| import sys | |||
| import urllib.parse | |||
| from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||
| import filelock | |||
| from ..core._imperative_rt import PersistentCache as _PersistentCache | |||
| from ..logger import get_logger | |||
| from ..version import __version__, git_version | |||
| class PersistentCacheManager(_PersistentCacheManager): | |||
| class PersistentCacheOnServer(_PersistentCache): | |||
| def __init__(self): | |||
| super().__init__() | |||
| if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||
| get_logger().info("fastrun use in-memory cache") | |||
| self.open_memory() | |||
| elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": | |||
| self.open_file() | |||
| else: | |||
| self.open_redis() | |||
| def open_memory(self): | |||
| pass | |||
| cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE") | |||
| if cache_type not in ("FILE", "MEMORY"): | |||
| try: | |||
| redis_config = self.get_redis_config() | |||
| except Exception as exc: | |||
| get_logger().error( | |||
| "failed to connect to cache server {!r}; try fallback to " | |||
| "in-file cache".format(exc) | |||
| ) | |||
| else: | |||
| self.add_config( | |||
| "redis", | |||
| redis_config, | |||
| "fastrun use redis cache", | |||
| "failed to connect to cache server", | |||
| ) | |||
| if cache_type != "MEMORY": | |||
| path = self.get_cache_file(self.get_cache_dir()) | |||
| self.add_config( | |||
| "in-file", | |||
| {"path": path}, | |||
| "fastrun use in-file cache in {}".format(path), | |||
| "failed to create cache file in {}".format(path), | |||
| ) | |||
| self.add_config( | |||
| "in-memory", | |||
| {}, | |||
| "fastrun use in-memory cache", | |||
| "failed to create in-memory cache", | |||
| ) | |||
| def open_file(self): | |||
| def get_cache_dir(self): | |||
| cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||
| try: | |||
| if not cache_dir: | |||
| from ..hub.hub import _get_megengine_home | |||
| if not cache_dir: | |||
| from ..hub.hub import _get_megengine_home | |||
| cache_dir = os.path.expanduser( | |||
| os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||
| ) | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| cache_file = os.path.join(cache_dir, "cache") | |||
| with open(cache_file, "a"): | |||
| pass | |||
| assert self.try_open_file(cache_file), "cannot create file" | |||
| get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||
| except Exception as exc: | |||
| get_logger().error( | |||
| "failed to create cache file in {} {!r}; fallback to " | |||
| "in-memory cache".format(cache_dir, exc) | |||
| cache_dir = os.path.expanduser( | |||
| os.path.join(_get_megengine_home(), "persistent_cache") | |||
| ) | |||
| self.open_memory() | |||
| def open_redis(self): | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| return cache_dir | |||
| def get_cache_file(self, cache_dir): | |||
| cache_file = os.path.join(cache_dir, "cache.bin") | |||
| with open(cache_file, "a"): | |||
| pass | |||
| return cache_file | |||
| @contextlib.contextmanager | |||
| def lock_cache_file(self, cache_dir): | |||
| lock_file = os.path.join(cache_dir, "cache.lock") | |||
| with filelock.FileLock(lock_file): | |||
| yield | |||
| def get_redis_config(self): | |||
| url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||
| if url is None: | |||
| return None | |||
| assert sys.platform != "win32", "redis cache on windows not tested" | |||
| prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | |||
| getpass.getuser(), __version__, git_version | |||
| ) | |||
| url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||
| if url is None: | |||
| self.open_file() | |||
| try: | |||
| assert sys.platform != "win32", "redis cache on windows not tested" | |||
| parse_result = urllib.parse.urlparse(url, scheme="redis") | |||
| assert parse_result.scheme == "redis", "unsupported scheme" | |||
| assert not parse_result.username, "redis conn with username unsupported" | |||
| assert self.try_open_redis( | |||
| parse_result.hostname, parse_result.port, parse_result.password, prefix | |||
| ), "connect failed" | |||
| except Exception as exc: | |||
| get_logger().error( | |||
| "failed to connect to cache server {!r}; try fallback to " | |||
| "in-file cache".format(exc) | |||
| ) | |||
| self.open_file() | |||
| _manager = None | |||
| parse_result = urllib.parse.urlparse(url) | |||
| assert not parse_result.username, "redis conn with username unsupported" | |||
| if parse_result.scheme == "redis": | |||
| assert parse_result.hostname and parse_result.port, "invalid url" | |||
| assert not parse_result.path | |||
| config = { | |||
| "hostname": parse_result.hostname, | |||
| "port": str(parse_result.port), | |||
| } | |||
| elif parse_result.scheme == "redis+socket": | |||
| assert not (parse_result.hostname or parse_result.port) | |||
| assert parse_result.path | |||
| config = { | |||
| "unixsocket": parse_result.path, | |||
| } | |||
| else: | |||
| assert False, "unsupported scheme" | |||
| if parse_result.password is not None: | |||
| config["password"] = parse_result.password | |||
| config["prefix"] = prefix | |||
| return config | |||
| def get_manager(): | |||
| global _manager | |||
| if _manager is None: | |||
| _manager = PersistentCacheManager() | |||
| return _manager | |||
| def flush(self): | |||
| if self.config is not None and self.config.type == "in-file": | |||
| with self.lock_cache_file(self.get_cache_dir()): | |||
| super().flush() | |||
| def _clean(): | |||
| nr_del = get_manager().clean() | |||
| nr_del = PersistentCacheOnServer().clean() | |||
| if nr_del is not None: | |||
| print("{} cache entries deleted".format(nr_del)) | |||
| @@ -4,8 +4,8 @@ pyarrow | |||
| requests | |||
| tabulate | |||
| tqdm | |||
| redispy | |||
| deprecated | |||
| mprop | |||
| wheel | |||
| megfile>=0.0.10 | |||
| megfile>=0.0.10 | |||
| filelock | |||
| @@ -210,7 +210,7 @@ void init_utils(py::module m) { | |||
| .def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); | |||
| #if MGB_ENABLE_OPR_MM | |||
| m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), | |||
| m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"), | |||
| py::arg("port") = 0); | |||
| #else | |||
| m.def("create_mm_server", []() {}); | |||
| @@ -234,51 +234,108 @@ void init_utils(py::module m) { | |||
| using ExtendedPersistentCache = | |||
| mgb::imperative::persistent_cache::ExtendedPersistentCache; | |||
| struct PersistentCacheManager { | |||
| std::shared_ptr<ExtendedPersistentCache> instance; | |||
| struct ConfigurablePersistentCache : mgb::PersistentCache { | |||
| struct Config { | |||
| std::string type; | |||
| std::unordered_map<std::string, std::string> args; | |||
| std::string on_success; | |||
| std::string on_fail; | |||
| }; | |||
| bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||
| if (cache) { | |||
| instance = cache; | |||
| PersistentCache::set_impl(cache); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool open_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||
| return try_reg(mgb::imperative::persistent_cache::make_redis( | |||
| ip, port, password, prefix)); | |||
| std::shared_ptr<ExtendedPersistentCache> impl; | |||
| std::optional<Config> impl_config; | |||
| std::vector<Config> configs; | |||
| void add_config( | |||
| std::string type, std::unordered_map<std::string, std::string> args, | |||
| std::string on_success, std::string on_fail) { | |||
| configs.push_back({type, args, on_success, on_fail}); | |||
| } | |||
| bool open_file(std::string path) { | |||
| return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||
| std::optional<size_t> clean() { return get_impl()->clear(); } | |||
| void load_config() { | |||
| std::optional<std::string> err_msg; | |||
| for (size_t i = 0; i < configs.size(); ++i) { | |||
| auto& config = configs[i]; | |||
| if (err_msg) { | |||
| mgb_log_warn("try fallback to %s cache", config.type.c_str()); | |||
| } else { | |||
| err_msg.emplace(); | |||
| } | |||
| auto cache = ExtendedPersistentCache::make_from_config( | |||
| config.type, config.args, *err_msg); | |||
| if (!cache) { | |||
| mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str()); | |||
| } else { | |||
| impl = cache; | |||
| impl_config = config; | |||
| break; | |||
| } | |||
| } | |||
| mgb_assert(impl_config.has_value(), "not valid config"); | |||
| } | |||
| std::optional<size_t> clean() { | |||
| if (instance) { | |||
| return instance->clear(); | |||
| std::shared_ptr<ExtendedPersistentCache> get_impl() { | |||
| if (!impl) { | |||
| load_config(); | |||
| } | |||
| return {}; | |||
| return impl; | |||
| } | |||
| void put(std::string category, std::string key, std::string value) { | |||
| PersistentCache::inst().put( | |||
| category, {key.data(), key.size()}, {value.data(), value.size()}); | |||
| virtual mgb::Maybe<Blob> get(const std::string& category, const Blob& key) { | |||
| return get_impl()->get(category, key); | |||
| } | |||
| virtual void put( | |||
| const std::string& category, const Blob& key, const Blob& value) { | |||
| return get_impl()->put(category, key, value); | |||
| } | |||
| py::object get(std::string category, std::string key) { | |||
| auto value = | |||
| PersistentCache::inst().get(category, {key.data(), key.size()}); | |||
| virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); } | |||
| py::object py_get(std::string category, std::string key) { | |||
| auto value = get_impl()->get(category, {key.data(), key.size()}); | |||
| if (value.valid()) { | |||
| return py::bytes(std::string((const char*)value->ptr, value->size)); | |||
| } else { | |||
| return py::none(); | |||
| } | |||
| } | |||
| void py_put(std::string category, std::string key, std::string value) { | |||
| get_impl()->put( | |||
| category, {key.data(), key.size()}, {value.data(), value.size()}); | |||
| } | |||
| void flush() { | |||
| if (impl) { | |||
| impl->flush(); | |||
| } | |||
| } | |||
| }; | |||
| py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||
| .def(py::init<>()) | |||
| .def("try_open_redis", &PersistentCacheManager::open_redis) | |||
| .def("try_open_file", &PersistentCacheManager::open_file) | |||
| .def("clean", &PersistentCacheManager::clean) | |||
| .def("put", &PersistentCacheManager::put) | |||
| .def("get", &PersistentCacheManager::get); | |||
| auto PyConfigurablePersistentCache = | |||
| py::class_< | |||
| ConfigurablePersistentCache, | |||
| std::shared_ptr<ConfigurablePersistentCache>>(m, "PersistentCache") | |||
| .def(py::init<>()) | |||
| .def("add_config", &ConfigurablePersistentCache::add_config) | |||
| .def("reg", | |||
| [](std::shared_ptr<ConfigurablePersistentCache> inst) { | |||
| PersistentCache::set_impl(inst); | |||
| }) | |||
| .def("clean", &ConfigurablePersistentCache::clean) | |||
| .def("get", &ConfigurablePersistentCache::py_get) | |||
| .def("put", &ConfigurablePersistentCache::py_put) | |||
| .def_readonly("config", &ConfigurablePersistentCache::impl_config) | |||
| .def("flush", &ConfigurablePersistentCache::flush); | |||
| py::class_<ConfigurablePersistentCache::Config>( | |||
| PyConfigurablePersistentCache, "Config") | |||
| .def_readwrite("type", &ConfigurablePersistentCache::Config::type) | |||
| .def_readwrite("args", &ConfigurablePersistentCache::Config::args) | |||
| .def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail) | |||
| .def_readwrite( | |||
| "on_success", &ConfigurablePersistentCache::Config::on_success); | |||
| } | |||
| @@ -27,7 +27,7 @@ namespace imperative { | |||
| namespace { | |||
| cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& comm = def.cast_final_safe<CollectiveComm>(); | |||
| auto group_client = std::make_shared<GroupClientProxy>( | |||
| auto group_client = std::make_shared<opr::GroupClientProxy>( | |||
| ssprintf("%s:%d", comm.addr.data(), comm.port)); | |||
| SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr); | |||
| auto disable = std::make_shared<DTypeScalar>(); | |||
| @@ -28,7 +28,7 @@ namespace { | |||
| cg::OperatorNodeBase* apply_on_var_node_remote_send( | |||
| const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& send = def.cast_final_safe<RemoteSend>(); | |||
| auto group_client = std::make_shared<GroupClientProxy>( | |||
| auto group_client = std::make_shared<opr::GroupClientProxy>( | |||
| ssprintf("%s:%d", send.addr.data(), send.port)); | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| @@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
| auto&& recv = def.cast_final_safe<RemoteRecv>(); | |||
| OperatorNodeConfig config{recv.cn}; | |||
| config.name(recv.make_name()); | |||
| auto group_client = std::make_shared<GroupClientProxy>( | |||
| auto group_client = std::make_shared<opr::GroupClientProxy>( | |||
| ssprintf("%s:%d", recv.addr.data(), recv.port)); | |||
| auto&& graph = inputs[0]->owner_graph(); | |||
| return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | |||
| @@ -27,8 +27,10 @@ public: | |||
| m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||
| } | |||
| bool connect(std::string ip, size_t port, std::string password) { | |||
| m_client.auth(password); | |||
| void connect(std::string ip, size_t port, std::optional<std::string> password) { | |||
| if (password) { | |||
| m_client.auth(*password); | |||
| } | |||
| m_client.connect( | |||
| ip, port, | |||
| [](const std::string& host, std::size_t port, | |||
| @@ -40,16 +42,32 @@ public: | |||
| } | |||
| }, | |||
| std::uint32_t(200)); | |||
| if (!m_client.is_connected()) { | |||
| return false; | |||
| } | |||
| mgb_assert(m_client.is_connected(), "connect failed"); | |||
| auto flag = m_client.get("mgb-cache-flag"); | |||
| sync(); | |||
| return flag.get().ok(); | |||
| auto is_valid = [](const cpp_redis::reply& reply) { | |||
| switch (reply.get_type()) { | |||
| case cpp_redis::reply::type::error: | |||
| case cpp_redis::reply::type::null: | |||
| return false; | |||
| case cpp_redis::reply::type::integer: | |||
| return reply.as_integer() != 0; | |||
| case cpp_redis::reply::type::simple_string: | |||
| case cpp_redis::reply::type::bulk_string: | |||
| return !reply.as_string().empty(); | |||
| case cpp_redis::reply::type::array: | |||
| return !reply.as_array().empty(); | |||
| default: | |||
| mgb_assert(false, "unknown reply type %d", (int)reply.get_type()); | |||
| } | |||
| }; | |||
| mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag"); | |||
| } | |||
| bool valid() const override { return m_client.is_connected(); } | |||
| void flush() override {} | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| MGB_LOCK_GUARD(m_mtx); | |||
| auto mem_result = m_local->get(category, key); | |||
| @@ -75,7 +93,7 @@ public: | |||
| MGB_LOCK_GUARD(m_mtx); | |||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
| std::string redis_key_str; | |||
| encode(category + '@' + key_str, redis_key_str); | |||
| encode(category + '@' + key_str, redis_key_str, 24); | |||
| std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||
| std::string redis_value_str; | |||
| encode(value_str, redis_value_str); | |||
| @@ -118,18 +136,16 @@ private: | |||
| class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||
| private: | |||
| std::string m_path; | |||
| std::optional<std::string> m_path; | |||
| std::unique_ptr<mgb::InFilePersistentCache> m_impl; | |||
| public: | |||
| ExtendedInFilePersistentCache() = default; | |||
| bool open(std::string path) { | |||
| void open(std::string path) { | |||
| std::fstream file; | |||
| file.open(path, std::ios::in | std::ios::binary); | |||
| if (!file.is_open()) { | |||
| return false; | |||
| } | |||
| mgb_assert(file.is_open(), "can't open file in %s", path.c_str()); | |||
| std::vector<char> bytes = { | |||
| std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | |||
| if (bytes.size()) { | |||
| @@ -139,14 +155,11 @@ public: | |||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
| } | |||
| m_path = path; | |||
| return true; | |||
| } | |||
| ~ExtendedInFilePersistentCache() { | |||
| if (m_impl) { | |||
| m_impl->dump_cache(m_path.c_str()); | |||
| } | |||
| } | |||
| void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); } | |||
| ~ExtendedInFilePersistentCache() { flush(); } | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| return m_impl->get(category, key); | |||
| @@ -157,29 +170,64 @@ public: | |||
| } | |||
| std::optional<size_t> clear() override { | |||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
| m_impl->dump_cache(m_path.c_str()); | |||
| if (m_impl) { | |||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
| if (m_path) { | |||
| m_impl->dump_cache(m_path->c_str()); | |||
| } | |||
| } | |||
| return {}; | |||
| } | |||
| bool valid() const override { return m_impl != nullptr; } | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||
| auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
| if (!cache->connect(ip, port, password)) { | |||
| return nullptr; | |||
| void flush() override { | |||
| if (m_impl && m_path) { | |||
| m_impl->dump_cache(m_path->c_str()); | |||
| } | |||
| } | |||
| return cache; | |||
| } | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||
| auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
| if (!cache->open(path)) { | |||
| return nullptr; | |||
| std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config( | |||
| std::string type, std::unordered_map<std::string, std::string> args, | |||
| std::string& err_msg) { | |||
| try { | |||
| if (type == "redis") { | |||
| std::string prefix = args.at("prefix"); | |||
| std::optional<std::string> password = args.count("password") | |||
| ? args.at("password") | |||
| : std::optional<std::string>(); | |||
| auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
| if (args.count("unixsocket")) { | |||
| std::string unixsocket = args.at("unixsocket"); | |||
| cache->connect(unixsocket, 0, password); | |||
| } else { | |||
| std::string ip = args.at("hostname"); | |||
| int port = atoi(args.at("port").c_str()); | |||
| std::optional<std::string> password = | |||
| args.count("password") ? args.at("password") | |||
| : std::optional<std::string>(); | |||
| cache->connect(ip, port, password); | |||
| } | |||
| return cache; | |||
| } else if (type == "in-file") { | |||
| std::string path = args.at("path"); | |||
| auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
| cache->open(path); | |||
| return cache; | |||
| } else if (type == "in-memory") { | |||
| auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
| cache->open(); | |||
| return cache; | |||
| } else { | |||
| mgb_assert(false, "persistent cache type %s unsupported", type.c_str()); | |||
| } | |||
| } catch (const std::exception& exc) { | |||
| err_msg = exc.what(); | |||
| } catch (...) { | |||
| err_msg = "unknown exception"; | |||
| } | |||
| return cache; | |||
| return nullptr; | |||
| } | |||
| } // namespace mgb::imperative::persistent_cache | |||
| @@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache { | |||
| public: | |||
| virtual bool valid() const = 0; | |||
| virtual std::optional<size_t> clear() = 0; | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix); | |||
| virtual void flush() = 0; | |||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||
| static std::shared_ptr<ExtendedPersistentCache> make_from_config( | |||
| std::string type, std::unordered_map<std::string, std::string> args, | |||
| std::string& err_msg); | |||
| }; | |||
| } // namespace mgb::imperative::persistent_cache | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) { | |||
| REQUIRE_GPU(2); | |||
| const char* server_addr = "127.0.0.1"; | |||
| uint32_t port = 3456; | |||
| mgb_assert(create_zmqrpc_server(server_addr, port) > 0); | |||
| mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); | |||
| HostTensorGenerator<> gen; | |||
| CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | |||
| @@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) { | |||
| REQUIRE_GPU(2); | |||
| const char* server_addr = "127.0.0.1"; | |||
| uint32_t port = 4567; | |||
| mgb_assert(create_zmqrpc_server(server_addr, port) > 0); | |||
| mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); | |||
| HostTensorGenerator<> gen; | |||
| CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); | |||
| @@ -17,6 +17,9 @@ | |||
| #include "megbrain/opr/zmq_rpc.h" | |||
| #include "mm_handler.pb.h" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| /* ======================== GroupServerProxy ========================== */ | |||
| /*! | |||
| * A proxy that receives zmqrpc call, direct call to NCCL Manager | |||
| @@ -213,7 +216,7 @@ struct ServerInfo { | |||
| std::unique_ptr<ZmqRpc::ZmqRpcServer> server; | |||
| }; | |||
| int create_zmqrpc_server(const std::string& server_addr, int port) { | |||
| int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) { | |||
| static std::unordered_map<std::string, ServerInfo> addr2server; | |||
| static std::mutex mtx; | |||
| MGB_LOCK_GUARD(mtx); | |||
| @@ -16,8 +16,8 @@ | |||
| #include "megbrain/opr/collective_comm.h" | |||
| #include "megbrain/opr/group_manager.h" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| namespace mgb { | |||
| namespace opr { | |||
| /*! | |||
| * Comm MM Client Proxy. | |||
| @@ -56,6 +56,9 @@ private: | |||
| int create_zmqrpc_server(const std::string& server_addr, int port); | |||
| } // namespace opr | |||
| } // namespace mgb | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -13,8 +13,6 @@ global: | |||
| base_exceptions*; | |||
| }; | |||
| megcore*; | |||
| *GroupClientProxy*; | |||
| *create_zmqrpc_server*; | |||
| *custom*; | |||