GitOrigin-RevId: 9af7fa5c97
tags/v1.7.2.m1
| @@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | |||||
| include(cmake/Halide.cmake) | include(cmake/Halide.cmake) | ||||
| endif() | endif() | ||||
| include(cmake/cpp_redis.cmake) | |||||
| # Thread | # Thread | ||||
| IF(APPLE) | IF(APPLE) | ||||
| set(CMAKE_THREAD_LIBS_INIT "-lpthread") | set(CMAKE_THREAD_LIBS_INIT "-lpthread") | ||||
| @@ -0,0 +1,2 @@ | |||||
| file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp) | |||||
| set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes) | |||||
| @@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||||
| set(MODULE_NAME _imperative_rt) | set(MODULE_NAME _imperative_rt) | ||||
| set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | ||||
| file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | ||||
| set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | ||||
| @@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||||
| add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | ||||
| target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | ||||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
| target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | ||||
| target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | ||||
| if(CXX_SUPPORT_WCLASS_MEMACCESS) | if(CXX_SUPPORT_WCLASS_MEMACCESS) | ||||
| @@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func( | |||||
| os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | ||||
| ) | ) | ||||
| _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||||
| _persistent_cache_impl_ins.reg() | |||||
| atexit.register(_close) | atexit.register(_close) | ||||
| del _set_fork_exec_path_for_timed_func | del _set_fork_exec_path_for_timed_func | ||||
| @@ -135,3 +132,5 @@ import megengine.quantization | |||||
| import megengine.random | import megengine.random | ||||
| import megengine.utils | import megengine.utils | ||||
| import megengine.traced_module | import megengine.traced_module | ||||
| persistent_cache.get_manager() | |||||
| @@ -9,108 +9,54 @@ | |||||
| import argparse | import argparse | ||||
| import getpass | import getpass | ||||
| import json | |||||
| import os | import os | ||||
| import shelve | |||||
| import sys | |||||
| from ..core._imperative_rt import PersistentCache as _PersistentCache | |||||
| from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||||
| from ..logger import get_logger | from ..logger import get_logger | ||||
| from ..version import __version__, git_version | from ..version import __version__, git_version | ||||
| class _FakeRedisConn: | |||||
| _cache_dir = None | |||||
| _is_shelve = False | |||||
| _dict = {} | |||||
| class PersistentCacheManager(_PersistentCacheManager): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | |||||
| if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | ||||
| self._dict = {} | |||||
| self._is_shelve = False | |||||
| get_logger().info("fastrun use in-memory cache") | get_logger().info("fastrun use in-memory cache") | ||||
| self.open_memory() | |||||
| else: | else: | ||||
| try: | |||||
| self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||||
| if not self._cache_dir: | |||||
| from ..hub.hub import _get_megengine_home | |||||
| self._cache_dir = os.path.expanduser( | |||||
| os.path.join(_get_megengine_home(), "persistent_cache") | |||||
| ) | |||||
| os.makedirs(self._cache_dir, exist_ok=True) | |||||
| cache_file = os.path.join(self._cache_dir, "cache") | |||||
| self._dict = shelve.open(cache_file) | |||||
| self._is_shelve = True | |||||
| get_logger().info( | |||||
| "fastrun use in-file cache in {}".format(self._cache_dir) | |||||
| ) | |||||
| except Exception as exc: | |||||
| self._dict = {} | |||||
| self._is_shelve = False | |||||
| get_logger().error( | |||||
| "failed to create cache file in {} {!r}; fallback to " | |||||
| "in-memory cache".format(self._cache_dir, exc) | |||||
| ) | |||||
| def get(self, key): | |||||
| if self._is_shelve and isinstance(key, bytes): | |||||
| key = key.decode("utf-8") | |||||
| return self._dict.get(key) | |||||
| def set(self, key, val): | |||||
| if self._is_shelve and isinstance(key, bytes): | |||||
| key = key.decode("utf-8") | |||||
| self._dict[key] = val | |||||
| def clear(self): | |||||
| print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir)) | |||||
| self._dict.clear() | |||||
| self.open_file() | |||||
| def __del__(self): | |||||
| if self._is_shelve: | |||||
| self._dict.close() | |||||
| def open_memory(self): | |||||
| pass | |||||
| def open_file(self): | |||||
| cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||||
| try: | |||||
| if not cache_dir: | |||||
| from ..hub.hub import _get_megengine_home | |||||
| class PersistentCacheOnServer(_PersistentCache): | |||||
| _cached_conn = None | |||||
| _prefix = None | |||||
| _prev_get_refkeep = None | |||||
| @property | |||||
| def _conn(self): | |||||
| """get redis connection""" | |||||
| if self._cached_conn is None: | |||||
| self._cached_conn = _FakeRedisConn() | |||||
| self._prefix = self.make_user_prefix() | |||||
| return self._cached_conn | |||||
| @classmethod | |||||
| def make_user_prefix(cls): | |||||
| return "mgbcache:{}".format(getpass.getuser()) | |||||
| def _make_key(self, category, key): | |||||
| prefix_with_version = "{}:MGB{}:GIT:{}".format( | |||||
| self._prefix, __version__, git_version | |||||
| ) | |||||
| return b"@".join( | |||||
| (prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||||
| ) | |||||
| def put(self, category, key, value): | |||||
| conn = self._conn | |||||
| key = self._make_key(category, key) | |||||
| conn.set(key, value) | |||||
| def get(self, category, key): | |||||
| conn = self._conn | |||||
| key = self._make_key(category, key) | |||||
| self._prev_get_refkeep = conn.get(key) | |||||
| return self._prev_get_refkeep | |||||
| def clean(self): | |||||
| conn = self._conn | |||||
| if isinstance(conn, _FakeRedisConn): | |||||
| conn.clear() | |||||
| cache_dir = os.path.expanduser( | |||||
| os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||||
| ) | |||||
| os.makedirs(cache_dir, exist_ok=True) | |||||
| cache_file = os.path.join(cache_dir, "cache") | |||||
| with open(cache_file, "a"): | |||||
| pass | |||||
| assert self.try_open_file(cache_file), "cannot create file" | |||||
| get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||||
| except Exception as exc: | |||||
| get_logger().error( | |||||
| "failed to create cache file in {} {!r}; fallback to " | |||||
| "in-memory cache".format(cache_dir, exc) | |||||
| ) | |||||
| self.open_memory() | |||||
| _manager = None | |||||
| def get_manager(): | |||||
| global _manager | |||||
| if _manager is None: | |||||
| _manager = PersistentCacheManager() | |||||
| return _manager | |||||
| @@ -23,6 +23,7 @@ | |||||
| #include "megbrain/common.h" | #include "megbrain/common.h" | ||||
| #include "megbrain/comp_node.h" | #include "megbrain/comp_node.h" | ||||
| #include "megbrain/imperative/blob_manager.h" | #include "megbrain/imperative/blob_manager.h" | ||||
| #include "megbrain/imperative/persistent_cache.h" | |||||
| #include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
| #include "megbrain/imperative/tensor_sanity_check.h" | #include "megbrain/imperative/tensor_sanity_check.h" | ||||
| #include "megbrain/serialization/helper.h" | #include "megbrain/serialization/helper.h" | ||||
| @@ -229,83 +230,55 @@ void init_utils(py::module m) { | |||||
| mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | ||||
| }); | }); | ||||
| using mgb::PersistentCache; | |||||
| class PyPersistentCache : public mgb::PersistentCache { | |||||
| private: | |||||
| using KeyPair = std::pair<std::string, std::string>; | |||||
| using BlobPtr = std::unique_ptr<Blob, void (*)(Blob*)>; | |||||
| using PersistentCache = mgb::PersistentCache; | |||||
| using ExtendedPersistentCache = | |||||
| mgb::imperative::persistent_cache::ExtendedPersistentCache; | |||||
| std::shared_mutex m_mutex; | |||||
| std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; | |||||
| struct PersistentCacheManager { | |||||
| std::shared_ptr<ExtendedPersistentCache> instance; | |||||
| static size_t hash_key_pair(const KeyPair& kp) { | |||||
| std::hash<std::string> hasher; | |||||
| return hasher(kp.first) ^ hasher(kp.second); | |||||
| bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||||
| if (cache) { | |||||
| instance = cache; | |||||
| PersistentCache::set_impl(cache); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| std::string blob_to_str(const Blob& key) { | |||||
| return std::string(reinterpret_cast<const char*>(key.ptr), key.size); | |||||
| bool open_redis( | |||||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||||
| return try_reg(mgb::imperative::persistent_cache::make_redis( | |||||
| ip, port, password, prefix)); | |||||
| } | } | ||||
| BlobPtr copy_blob(const Blob& blob) { | |||||
| auto blob_deleter = [](Blob* blob) { | |||||
| if (blob) { | |||||
| std::free(const_cast<void*>(blob->ptr)); | |||||
| delete blob; | |||||
| } | |||||
| }; | |||||
| auto blob_ptr = BlobPtr{new Blob(), blob_deleter}; | |||||
| blob_ptr->ptr = std::malloc(blob.size); | |||||
| std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size); | |||||
| blob_ptr->size = blob.size; | |||||
| return blob_ptr; | |||||
| bool open_file(std::string path) { | |||||
| return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||||
| } | } | ||||
| BlobPtr str_to_blob(const std::string& str) { | |||||
| auto blob = Blob{str.data(), str.size()}; | |||||
| return copy_blob(blob); | |||||
| std::optional<size_t> clean() { | |||||
| if (instance) { | |||||
| return instance->clear(); | |||||
| } | |||||
| return {}; | |||||
| } | } | ||||
| std::unique_ptr<Blob, void (*)(Blob*)> empty_blob() { | |||||
| return BlobPtr{nullptr, [](Blob* blob) {}}; | |||||
| void put(std::string category, std::string key, std::string value) { | |||||
| PersistentCache::inst().put( | |||||
| category, {key.data(), key.size()}, {value.data(), value.size()}); | |||||
| } | } | ||||
| public: | |||||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
| auto py_get = [this](const std::string& category, | |||||
| const Blob& key) -> mgb::Maybe<Blob> { | |||||
| PYBIND11_OVERLOAD_PURE( | |||||
| mgb::Maybe<Blob>, PersistentCache, get, category, key); | |||||
| }; | |||||
| KeyPair kp = {category, blob_to_str(key)}; | |||||
| std::shared_lock<decltype(m_mutex)> rlock; | |||||
| auto iter = m_local_cache.find(kp); | |||||
| if (iter == m_local_cache.end()) { | |||||
| auto py_ret = py_get(category, key); | |||||
| if (!py_ret.valid()) { | |||||
| iter = m_local_cache.insert({kp, empty_blob()}).first; | |||||
| } else { | |||||
| iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; | |||||
| } | |||||
| } | |||||
| if (iter->second) { | |||||
| return *iter->second; | |||||
| py::object get(std::string category, std::string key) { | |||||
| auto value = | |||||
| PersistentCache::inst().get(category, {key.data(), key.size()}); | |||||
| if (value.valid()) { | |||||
| return py::bytes(std::string((const char*)value->ptr, value->size)); | |||||
| } else { | } else { | ||||
| return {}; | |||||
| return py::none(); | |||||
| } | } | ||||
| } | } | ||||
| void put(const std::string& category, const Blob& key, const Blob& value) | |||||
| override { | |||||
| KeyPair kp = {category, blob_to_str(key)}; | |||||
| std::unique_lock<decltype(m_mutex)> wlock; | |||||
| m_local_cache.insert_or_assign(kp, copy_blob(value)); | |||||
| PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); | |||||
| } | |||||
| }; | }; | ||||
| py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>( | |||||
| m, "PersistentCache") | |||||
| py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||||
| .def(py::init<>()) | .def(py::init<>()) | ||||
| .def("get", &PersistentCache::get) | |||||
| .def("put", &PersistentCache::put) | |||||
| .def("reg", &PersistentCache::set_impl); | |||||
| .def("try_open_redis", &PersistentCacheManager::open_redis) | |||||
| .def("try_open_file", &PersistentCacheManager::open_file) | |||||
| .def("clean", &PersistentCacheManager::clean) | |||||
| .def("put", &PersistentCacheManager::put) | |||||
| .def("get", &PersistentCacheManager::get); | |||||
| } | } | ||||
| @@ -1,12 +1,11 @@ | |||||
| import pytest | import pytest | ||||
| import megengine | |||||
| from megengine.utils.persistent_cache import PersistentCacheOnServer | |||||
| from megengine.utils.persistent_cache import _manager | |||||
| @pytest.mark.skip(reason="fixme: github ci failed") | @pytest.mark.skip(reason="fixme: github ci failed") | ||||
| def test_persistent_cache(): | def test_persistent_cache(): | ||||
| pc = PersistentCacheOnServer() | |||||
| pc = _manager | |||||
| k0 = b"\x00\x00" | k0 = b"\x00\x00" | ||||
| k1 = b"\x00\x01" | k1 = b"\x00\x01" | ||||
| cat = "test" | cat = "test" | ||||
| @@ -0,0 +1,186 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/persistent_cache.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include <fstream> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "cpp_redis/cpp_redis" | |||||
| #include "megbrain/imperative/persistent_cache.h" | |||||
| #include "megbrain/imperative/utils/base64.h" | |||||
| #include "megbrain/utils/infile_persistent_cache.h" | |||||
| namespace mgb::imperative::persistent_cache { | |||||
| class RedisCache final : public ExtendedPersistentCache { | |||||
| public: | |||||
| RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||||
| m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||||
| } | |||||
| bool connect(std::string ip, size_t port, std::string password) { | |||||
| m_client.auth(password); | |||||
| m_client.connect( | |||||
| ip, port, | |||||
| [](const std::string& host, std::size_t port, | |||||
| cpp_redis::connect_state status) { | |||||
| if (status == cpp_redis::connect_state::dropped) { | |||||
| mgb_log("client disconnected from %s.", host.c_str()); | |||||
| mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||||
| port); | |||||
| } | |||||
| }, | |||||
| std::uint32_t(200)); | |||||
| if (!m_client.is_connected()) { | |||||
| return false; | |||||
| } | |||||
| auto flag = m_client.get("mgb-cache-flag"); | |||||
| sync(); | |||||
| return flag.get().ok(); | |||||
| } | |||||
| bool valid() const override { return m_client.is_connected(); } | |||||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| auto mem_result = m_local->get(category, key); | |||||
| if (mem_result.valid()) | |||||
| return mem_result; | |||||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||||
| std::string redis_key_str; | |||||
| encode(category + '@' + key_str, redis_key_str, 24); | |||||
| auto result = m_client.get(redis_key_str); | |||||
| sync(); | |||||
| auto content = result.get(); | |||||
| if (content.is_null()) | |||||
| return mgb::None; | |||||
| std::string decode_content; | |||||
| decode(content.as_string(), decode_content); | |||||
| m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||||
| return m_local->get(category, key); | |||||
| } | |||||
| void put(const std::string& category, const Blob& key, const Blob& value) override { | |||||
| MGB_LOCK_GUARD(m_mtx); | |||||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||||
| std::string redis_key_str; | |||||
| encode(category + '@' + key_str, redis_key_str); | |||||
| std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||||
| std::string redis_value_str; | |||||
| encode(value_str, redis_value_str); | |||||
| auto result = m_client.set(redis_key_str, redis_value_str); | |||||
| m_local->put(category, key, value); | |||||
| sync(); | |||||
| } | |||||
| std::optional<size_t> clear() override { | |||||
| size_t cursor = 0, nr_deleted = 0; | |||||
| std::string pattern = m_prefix + "@*"; | |||||
| do { | |||||
| auto reply = m_client.scan(cursor, pattern).share(); | |||||
| sync(); | |||||
| auto keys = reply.get().as_array(); | |||||
| std::vector<std::string> string_keys; | |||||
| for (auto&& key : keys) { | |||||
| string_keys.push_back(key.as_string()); | |||||
| } | |||||
| m_client.del(string_keys); | |||||
| nr_deleted += string_keys.size(); | |||||
| cursor = reply.get().as_array()[0].as_integer(); | |||||
| } while (cursor != 0); | |||||
| return nr_deleted; | |||||
| } | |||||
| private: | |||||
| std::shared_ptr<mgb::PersistentCache> m_local; | |||||
| std::mutex m_mtx; | |||||
| cpp_redis::client m_client; | |||||
| std::string m_prefix; | |||||
| uint64_t m_timeout; | |||||
| void sync() { | |||||
| m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||||
| mgb_assert(valid()); | |||||
| } | |||||
| }; | |||||
| class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||||
| private: | |||||
| std::string m_path; | |||||
| std::unique_ptr<mgb::InFilePersistentCache> m_impl; | |||||
| public: | |||||
| ExtendedInFilePersistentCache() = default; | |||||
| bool open(std::string path) { | |||||
| std::fstream file; | |||||
| file.open(path, std::ios::in | std::ios::binary); | |||||
| if (!file.is_open()) { | |||||
| return false; | |||||
| } | |||||
| std::vector<char> bytes = { | |||||
| std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | |||||
| if (bytes.size()) { | |||||
| m_impl = std::make_unique<mgb::InFilePersistentCache>( | |||||
| (const uint8_t*)bytes.data(), bytes.size()); | |||||
| } else { | |||||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||||
| } | |||||
| m_path = path; | |||||
| return true; | |||||
| } | |||||
| ~ExtendedInFilePersistentCache() { | |||||
| if (m_impl) { | |||||
| m_impl->dump_cache(m_path.c_str()); | |||||
| } | |||||
| } | |||||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||||
| return m_impl->get(category, key); | |||||
| } | |||||
| void put(const std::string& category, const Blob& key, const Blob& value) override { | |||||
| return m_impl->put(category, key, value); | |||||
| } | |||||
| std::optional<size_t> clear() override { | |||||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||||
| m_impl->dump_cache(m_path.c_str()); | |||||
| return {}; | |||||
| } | |||||
| bool valid() const override { return m_impl != nullptr; } | |||||
| }; | |||||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||||
| auto cache = std::make_shared<RedisCache>(prefix, 100); | |||||
| if (!cache->connect(ip, port, password)) { | |||||
| return nullptr; | |||||
| } | |||||
| return cache; | |||||
| } | |||||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||||
| auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||||
| if (!cache->open(path)) { | |||||
| return nullptr; | |||||
| } | |||||
| return cache; | |||||
| } | |||||
| } // namespace mgb::imperative::persistent_cache | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,172 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/base64.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/imperative/utils/base64.h" | |||||
| namespace mgb::imperative { | |||||
| namespace { | |||||
| /* | |||||
| ** Translation Table as described in RFC1113 | |||||
| */ | |||||
| const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | |||||
| /* | |||||
| ** Translation Table to decode: | |||||
| *https://github.com/dgiardini/imgcalkap/blob/master/base64.c | |||||
| */ | |||||
| const char cd64[] = | |||||
| "|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`" | |||||
| "abcdefghijklmnopq"; | |||||
| /* | |||||
| ** encodeblock | |||||
| ** | |||||
| ** encode 3 8-bit binary bytes as 4 '6-bit' characters | |||||
| */ | |||||
| void encodeblock(unsigned char in[3], unsigned char out[4], int len) { | |||||
| out[0] = cb64[in[0] >> 2]; | |||||
| out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)]; | |||||
| out[2] = | |||||
| (unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '='); | |||||
| out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '='); | |||||
| } | |||||
| /* | |||||
| ** decodeblock | |||||
| ** | |||||
| ** decode 4 '6-bit' characters into 3 8-bit binary bytes | |||||
| */ | |||||
| void decodeblock(unsigned char in[4], unsigned char out[3]) { | |||||
| out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4); | |||||
| out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2); | |||||
| out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]); | |||||
| } | |||||
| } // namespace | |||||
| /** | |||||
| * Encode string to base64 string | |||||
| * @param input - source string | |||||
| * @param outdata - target base64 string | |||||
| * @param linesize - max size of line | |||||
| */ | |||||
| void encode( | |||||
| const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||||
| int linesize) { | |||||
| outdata.clear(); | |||||
| unsigned char in[3], out[4]; | |||||
| int i, len, blocksout = 0; | |||||
| size_t j = 0; | |||||
| auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||||
| unsigned int insize = input.size(); | |||||
| while (j <= insize) { | |||||
| len = 0; | |||||
| for (i = 0; i < 3; i++) { | |||||
| in[i] = (unsigned char)indata[j]; | |||||
| j++; | |||||
| if (j <= insize) { | |||||
| len++; | |||||
| } else { | |||||
| in[i] = 0; | |||||
| } | |||||
| } | |||||
| if (len) { | |||||
| encodeblock(in, out, len); | |||||
| for (i = 0; i < 4; i++) { | |||||
| outdata.push_back(out[i]); | |||||
| } | |||||
| blocksout++; | |||||
| } | |||||
| if (blocksout >= (linesize / 4) || (j == insize)) { | |||||
| if (blocksout) { | |||||
| outdata.push_back('\r'); | |||||
| outdata.push_back('\n'); | |||||
| } | |||||
| blocksout = 0; | |||||
| } | |||||
| } | |||||
| } | |||||
| /** | |||||
| * Decode base64 string ot source | |||||
| * @param input - base64 string | |||||
| * @param outdata - source string | |||||
| */ | |||||
| void decode( | |||||
| const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata) { | |||||
| outdata.clear(); | |||||
| unsigned char in[4], out[3], v; | |||||
| int i, len; | |||||
| size_t j = 0; | |||||
| auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||||
| unsigned int insize = input.size(); | |||||
| while (j <= insize) { | |||||
| for (len = 0, i = 0; i < 4 && (j <= insize); i++) { | |||||
| v = 0; | |||||
| while ((j <= insize) && v == 0) { | |||||
| v = (unsigned char)indata[j++]; | |||||
| v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]); | |||||
| if (v) { | |||||
| v = (unsigned char)((v == '$') ? 0 : v - 61); | |||||
| } | |||||
| } | |||||
| if (j <= insize) { | |||||
| len++; | |||||
| if (v) { | |||||
| in[i] = (unsigned char)(v - 1); | |||||
| } | |||||
| } else { | |||||
| in[i] = 0; | |||||
| } | |||||
| } | |||||
| if (len) { | |||||
| decodeblock(in, out); | |||||
| for (i = 0; i < len - 1; i++) { | |||||
| outdata.push_back(out[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| /** | |||||
| * Encode binary data to base64 buffer | |||||
| * @param input - source data | |||||
| * @param outdata - target base64 buffer | |||||
| * @param linesize | |||||
| */ | |||||
| void encode(const std::string& input, std::string& outdata, int linesize) { | |||||
| std::vector<std::uint8_t> out; | |||||
| std::vector<std::uint8_t> in(input.begin(), input.end()); | |||||
| encode(in, out, linesize); | |||||
| outdata = std::string(out.begin(), out.end()); | |||||
| } | |||||
| /** | |||||
| * Decode base64 buffer to source binary data | |||||
| * @param input - base64 buffer | |||||
| * @param outdata - source binary data | |||||
| */ | |||||
| void decode(const std::string& input, std::string& outdata) { | |||||
| std::vector<std::uint8_t> in(input.begin(), input.end()); | |||||
| std::vector<std::uint8_t> out; | |||||
| decode(in, out); | |||||
| outdata = std::string(out.begin(), out.end()); | |||||
| } | |||||
| } // namespace mgb::imperative | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/persistent_cache.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <memory> | |||||
| #include "megbrain/utils/persistent_cache.h" | |||||
| namespace mgb::imperative::persistent_cache { | |||||
| class ExtendedPersistentCache : public mgb::PersistentCache { | |||||
| public: | |||||
| virtual bool valid() const = 0; | |||||
| virtual std::optional<size_t> clear() = 0; | |||||
| }; | |||||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||||
| std::string ip, size_t port, std::string password, std::string prefix); | |||||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||||
| } // namespace mgb::imperative::persistent_cache | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/utils/base64.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megbrain/common.h" | |||||
| namespace mgb::imperative { | |||||
| /** | |||||
| * Encode string to base64 string | |||||
| * @param input - source string | |||||
| * @param outdata - target base64 string | |||||
| * @param linesize - max size of line | |||||
| */ | |||||
| void encode( | |||||
| const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||||
| int linesize = 76); | |||||
| /** | |||||
| * Decode base64 string ot source | |||||
| * @param input - base64 string | |||||
| * @param outdata - source string | |||||
| */ | |||||
| void decode(const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata); | |||||
| /** | |||||
| * Encode binary data to base64 buffer | |||||
| * @param input - source data | |||||
| * @param outdata - target base64 buffer | |||||
| * @param linesize | |||||
| */ | |||||
| void encode(const std::string& input, std::string& outdata, int linesize = 76); | |||||
| /** | |||||
| * Decode base64 buffer to source binary data | |||||
| * @param input - base64 buffer | |||||
| * @param outdata - source binary data | |||||
| */ | |||||
| void decode(const std::string& input, std::string& outdata); | |||||
| } // namespace mgb::imperative | |||||
| @@ -12,7 +12,7 @@ endif() | |||||
| # TODO: turn python binding into a static/object library | # TODO: turn python binding into a static/object library | ||||
| add_executable(imperative_test ${SOURCES} ${SRCS}) | add_executable(imperative_test ${SOURCES} ${SRCS}) | ||||
| add_dependencies(imperative_test mgb_opdef) | add_dependencies(imperative_test mgb_opdef) | ||||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) | |||||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||||
| # Python binding | # Python binding | ||||
| target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | ||||
| @@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache { | |||||
| }; | }; | ||||
| }; | }; | ||||
| Maybe<Blob> get(const std::string& category, const Blob& key) override; | |||||
| void put(const std::string& category, const Blob& key, const Blob& value) override; | |||||
| MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | |||||
| const std::string& category, const Blob& key) override; | |||||
| MGE_WIN_DECLSPEC_FUC void put( | |||||
| const std::string& category, const Blob& key, const Blob& value) override; | |||||
| std::unordered_map< | std::unordered_map< | ||||
| std::string, | std::string, | ||||
| std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | ||||
| m_cache; | m_cache; | ||||
| MGB_MUTEX m_mtx; | MGB_MUTEX m_mtx; | ||||
| public: | |||||
| MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; | |||||
| }; | }; | ||||
| /*! | /*! | ||||