| @@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||
| set(MODULE_NAME _imperative_rt) | |||
| set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | |||
| file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | |||
| set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
| @@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||
| add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | |||
| target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | |||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
| target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
| target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
| if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
| @@ -11,6 +11,7 @@ import argparse | |||
| import getpass | |||
| import os | |||
| import sys | |||
| import urllib.parse | |||
| from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||
| from ..logger import get_logger | |||
| @@ -23,8 +24,10 @@ class PersistentCacheManager(_PersistentCacheManager): | |||
| if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||
| get_logger().info("fastrun use in-memory cache") | |||
| self.open_memory() | |||
| else: | |||
| elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": | |||
| self.open_file() | |||
| else: | |||
| self.open_redis() | |||
| def open_memory(self): | |||
| pass | |||
| @@ -51,6 +54,28 @@ class PersistentCacheManager(_PersistentCacheManager): | |||
| ) | |||
| self.open_memory() | |||
| def open_redis(self): | |||
| prefix = "mgbcache:{}:MGB{}:GIT:{}".format( | |||
| getpass.getuser(), __version__, git_version | |||
| ) | |||
| url = os.getenv("MGE_FASTRUN_CACHE_URL") | |||
| if url is None: | |||
| self.open_file() | |||
| try: | |||
| assert sys.platform != "win32", "redis cache on windows not tested" | |||
| parse_result = urllib.parse.urlparse(url, scheme="redis") | |||
| assert parse_result.scheme == "redis", "unsupported scheme" | |||
| assert not parse_result.username, "redis conn with username unsupported" | |||
| assert self.try_open_redis( | |||
| parse_result.hostname, parse_result.port, parse_result.password, prefix | |||
| ), "connect failed" | |||
| except Exception as exc: | |||
| get_logger().error( | |||
| "failed to connect to cache server {!r}; try fallback to " | |||
| "in-file cache".format(exc) | |||
| ) | |||
| self.open_file() | |||
| _manager = None | |||
| @@ -60,3 +85,23 @@ def get_manager(): | |||
| if _manager is None: | |||
| _manager = PersistentCacheManager() | |||
| return _manager | |||
| def _clean(): | |||
| nr_del = get_manager().clean() | |||
| if nr_del is not None: | |||
| print("{} cache entries deleted".format(nr_del)) | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="manage persistent cache") | |||
| subp = parser.add_subparsers(description="action to be performed", dest="cmd") | |||
| subp.required = True | |||
| subp_clean = subp.add_parser("clean", help="clean all the cache of current user") | |||
| subp_clean.set_defaults(action=_clean) | |||
| args = parser.parse_args() | |||
| args.action() | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -245,6 +245,11 @@ void init_utils(py::module m) { | |||
| } | |||
| return false; | |||
| } | |||
| bool open_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||
| return try_reg(mgb::imperative::persistent_cache::make_redis( | |||
| ip, port, password, prefix)); | |||
| } | |||
| bool open_file(std::string path) { | |||
| return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||
| } | |||
| @@ -271,6 +276,7 @@ void init_utils(py::module m) { | |||
| py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||
| .def(py::init<>()) | |||
| .def("try_open_redis", &PersistentCacheManager::open_redis) | |||
| .def("try_open_file", &PersistentCacheManager::open_file) | |||
| .def("clean", &PersistentCacheManager::clean) | |||
| .def("put", &PersistentCacheManager::put) | |||
| @@ -13,12 +13,109 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "cpp_redis/cpp_redis" | |||
| #include "megbrain/imperative/persistent_cache.h" | |||
| #include "megbrain/imperative/utils/base64.h" | |||
| #include "megbrain/utils/infile_persistent_cache.h" | |||
| namespace mgb::imperative::persistent_cache { | |||
| class RedisCache final : public ExtendedPersistentCache { | |||
| public: | |||
| RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||
| m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||
| } | |||
| bool connect(std::string ip, size_t port, std::string password) { | |||
| m_client.auth(password); | |||
| m_client.connect( | |||
| ip, port, | |||
| [](const std::string& host, std::size_t port, | |||
| cpp_redis::connect_state status) { | |||
| if (status == cpp_redis::connect_state::dropped) { | |||
| mgb_log("client disconnected from %s.", host.c_str()); | |||
| mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||
| port); | |||
| } | |||
| }, | |||
| std::uint32_t(200)); | |||
| if (!m_client.is_connected()) { | |||
| return false; | |||
| } | |||
| auto flag = m_client.get("mgb-cache-flag"); | |||
| sync(); | |||
| return flag.get().ok(); | |||
| } | |||
| bool valid() const override { return m_client.is_connected(); } | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| MGB_LOCK_GUARD(m_mtx); | |||
| auto mem_result = m_local->get(category, key); | |||
| if (mem_result.valid()) | |||
| return mem_result; | |||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
| std::string redis_key_str; | |||
| encode(category + '@' + key_str, redis_key_str, 24); | |||
| auto result = m_client.get(redis_key_str); | |||
| sync(); | |||
| auto content = result.get(); | |||
| if (content.is_null()) | |||
| return mgb::None; | |||
| std::string decode_content; | |||
| decode(content.as_string(), decode_content); | |||
| m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||
| return m_local->get(category, key); | |||
| } | |||
| void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
| MGB_LOCK_GUARD(m_mtx); | |||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
| std::string redis_key_str; | |||
| encode(category + '@' + key_str, redis_key_str); | |||
| std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||
| std::string redis_value_str; | |||
| encode(value_str, redis_value_str); | |||
| auto result = m_client.set(redis_key_str, redis_value_str); | |||
| m_local->put(category, key, value); | |||
| sync(); | |||
| } | |||
| std::optional<size_t> clear() override { | |||
| size_t cursor = 0, nr_deleted = 0; | |||
| std::string pattern = m_prefix + "@*"; | |||
| do { | |||
| auto reply = m_client.scan(cursor, pattern).share(); | |||
| sync(); | |||
| auto keys = reply.get().as_array(); | |||
| std::vector<std::string> string_keys; | |||
| for (auto&& key : keys) { | |||
| string_keys.push_back(key.as_string()); | |||
| } | |||
| m_client.del(string_keys); | |||
| nr_deleted += string_keys.size(); | |||
| cursor = reply.get().as_array()[0].as_integer(); | |||
| } while (cursor != 0); | |||
| return nr_deleted; | |||
| } | |||
| private: | |||
| std::shared_ptr<mgb::PersistentCache> m_local; | |||
| std::mutex m_mtx; | |||
| cpp_redis::client m_client; | |||
| std::string m_prefix; | |||
| uint64_t m_timeout; | |||
| void sync() { | |||
| m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||
| mgb_assert(valid()); | |||
| } | |||
| }; | |||
| class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||
| private: | |||
| std::string m_path; | |||
| @@ -68,6 +165,15 @@ public: | |||
| bool valid() const override { return m_impl != nullptr; } | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||
| auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
| if (!cache->connect(ip, port, password)) { | |||
| return nullptr; | |||
| } | |||
| return cache; | |||
| } | |||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||
| auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
| if (!cache->open(path)) { | |||
| @@ -22,6 +22,9 @@ public: | |||
| virtual std::optional<size_t> clear() = 0; | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix); | |||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||
| } // namespace mgb::imperative::persistent_cache | |||
| @@ -12,7 +12,7 @@ endif() | |||
| # TODO: turn python binding into a static/object library | |||
| add_executable(imperative_test ${SOURCES} ${SRCS}) | |||
| add_dependencies(imperative_test mgb_opdef) | |||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
| # Python binding | |||
| target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
| @@ -35,6 +35,10 @@ configure_file(src/lite_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/l | |||
| install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include) | |||
| # begin config lite | |||
| if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) | |||
| # FXIME third_party cpp redis do not support build with clang-cl | |||
| list(APPEND SOURCES_LITE ${CPP_REDIS_SRCS}) | |||
| endif() | |||
| add_library(lite_static STATIC ${SOURCES_LITE}) | |||
| add_dependencies(lite_static lite_fbs_generate) | |||
| include_directories($<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/genfiles>) | |||
| @@ -106,6 +110,14 @@ endif() | |||
| if(LITE_BUILD_WITH_MGE) | |||
| target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) | |||
| endif() | |||
| if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) | |||
| # FXIME third_party cpp redis do not support build with clang-cl | |||
| target_include_directories(lite_static PRIVATE ${CPP_REDIS_INCLUDES}) | |||
| target_include_directories(lite_shared PRIVATE ${CPP_REDIS_INCLUDES}) | |||
| target_include_directories(lite_shared_whl PRIVATE ${CPP_REDIS_INCLUDES}) | |||
| target_include_directories(lite_static_all_in_one PRIVATE ${CPP_REDIS_INCLUDES}) | |||
| endif() | |||
| set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script") | |||
| add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT}) | |||
| if(NOT MSVC AND NOT WIN32) | |||