GitOrigin-RevId: 9af7fa5c97
tags/v1.7.2.m1
| @@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | |||
| include(cmake/Halide.cmake) | |||
| endif() | |||
| include(cmake/cpp_redis.cmake) | |||
| # Thread | |||
| IF(APPLE) | |||
| set(CMAKE_THREAD_LIBS_INIT "-lpthread") | |||
| @@ -0,0 +1,2 @@ | |||
| file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp) | |||
| set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes) | |||
| @@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) | |||
| set(MODULE_NAME _imperative_rt) | |||
| set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) | |||
| file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) | |||
| set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") | |||
| @@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) | |||
| add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) | |||
| target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) | |||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) | |||
| target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
| target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) | |||
| target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) | |||
| if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
| @@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func( | |||
| os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||
| ) | |||
| _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||
| _persistent_cache_impl_ins.reg() | |||
| atexit.register(_close) | |||
| del _set_fork_exec_path_for_timed_func | |||
| @@ -135,3 +132,5 @@ import megengine.quantization | |||
| import megengine.random | |||
| import megengine.utils | |||
| import megengine.traced_module | |||
| persistent_cache.get_manager() | |||
| @@ -9,108 +9,54 @@ | |||
| import argparse | |||
| import getpass | |||
| import json | |||
| import os | |||
| import shelve | |||
| import sys | |||
| from ..core._imperative_rt import PersistentCache as _PersistentCache | |||
| from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager | |||
| from ..logger import get_logger | |||
| from ..version import __version__, git_version | |||
| class _FakeRedisConn: | |||
| _cache_dir = None | |||
| _is_shelve = False | |||
| _dict = {} | |||
| class PersistentCacheManager(_PersistentCacheManager): | |||
| def __init__(self): | |||
| super().__init__() | |||
| if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": | |||
| self._dict = {} | |||
| self._is_shelve = False | |||
| get_logger().info("fastrun use in-memory cache") | |||
| self.open_memory() | |||
| else: | |||
| try: | |||
| self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||
| if not self._cache_dir: | |||
| from ..hub.hub import _get_megengine_home | |||
| self._cache_dir = os.path.expanduser( | |||
| os.path.join(_get_megengine_home(), "persistent_cache") | |||
| ) | |||
| os.makedirs(self._cache_dir, exist_ok=True) | |||
| cache_file = os.path.join(self._cache_dir, "cache") | |||
| self._dict = shelve.open(cache_file) | |||
| self._is_shelve = True | |||
| get_logger().info( | |||
| "fastrun use in-file cache in {}".format(self._cache_dir) | |||
| ) | |||
| except Exception as exc: | |||
| self._dict = {} | |||
| self._is_shelve = False | |||
| get_logger().error( | |||
| "failed to create cache file in {} {!r}; fallback to " | |||
| "in-memory cache".format(self._cache_dir, exc) | |||
| ) | |||
| def get(self, key): | |||
| if self._is_shelve and isinstance(key, bytes): | |||
| key = key.decode("utf-8") | |||
| return self._dict.get(key) | |||
| def set(self, key, val): | |||
| if self._is_shelve and isinstance(key, bytes): | |||
| key = key.decode("utf-8") | |||
| self._dict[key] = val | |||
| def clear(self): | |||
| print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir)) | |||
| self._dict.clear() | |||
| self.open_file() | |||
| def __del__(self): | |||
| if self._is_shelve: | |||
| self._dict.close() | |||
| def open_memory(self): | |||
| pass | |||
| def open_file(self): | |||
| cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") | |||
| try: | |||
| if not cache_dir: | |||
| from ..hub.hub import _get_megengine_home | |||
| class PersistentCacheOnServer(_PersistentCache): | |||
| _cached_conn = None | |||
| _prefix = None | |||
| _prev_get_refkeep = None | |||
| @property | |||
| def _conn(self): | |||
| """get redis connection""" | |||
| if self._cached_conn is None: | |||
| self._cached_conn = _FakeRedisConn() | |||
| self._prefix = self.make_user_prefix() | |||
| return self._cached_conn | |||
| @classmethod | |||
| def make_user_prefix(cls): | |||
| return "mgbcache:{}".format(getpass.getuser()) | |||
| def _make_key(self, category, key): | |||
| prefix_with_version = "{}:MGB{}:GIT:{}".format( | |||
| self._prefix, __version__, git_version | |||
| ) | |||
| return b"@".join( | |||
| (prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||
| ) | |||
| def put(self, category, key, value): | |||
| conn = self._conn | |||
| key = self._make_key(category, key) | |||
| conn.set(key, value) | |||
| def get(self, category, key): | |||
| conn = self._conn | |||
| key = self._make_key(category, key) | |||
| self._prev_get_refkeep = conn.get(key) | |||
| return self._prev_get_refkeep | |||
| def clean(self): | |||
| conn = self._conn | |||
| if isinstance(conn, _FakeRedisConn): | |||
| conn.clear() | |||
| cache_dir = os.path.expanduser( | |||
| os.path.join(_get_megengine_home(), "persistent_cache.bin") | |||
| ) | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| cache_file = os.path.join(cache_dir, "cache") | |||
| with open(cache_file, "a"): | |||
| pass | |||
| assert self.try_open_file(cache_file), "cannot create file" | |||
| get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) | |||
| except Exception as exc: | |||
| get_logger().error( | |||
| "failed to create cache file in {} {!r}; fallback to " | |||
| "in-memory cache".format(cache_dir, exc) | |||
| ) | |||
| self.open_memory() | |||
| _manager = None | |||
| def get_manager(): | |||
| global _manager | |||
| if _manager is None: | |||
| _manager = PersistentCacheManager() | |||
| return _manager | |||
| @@ -23,6 +23,7 @@ | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/comp_node.h" | |||
| #include "megbrain/imperative/blob_manager.h" | |||
| #include "megbrain/imperative/persistent_cache.h" | |||
| #include "megbrain/imperative/profiler.h" | |||
| #include "megbrain/imperative/tensor_sanity_check.h" | |||
| #include "megbrain/serialization/helper.h" | |||
| @@ -229,83 +230,55 @@ void init_utils(py::module m) { | |||
| mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | |||
| }); | |||
| using mgb::PersistentCache; | |||
| class PyPersistentCache : public mgb::PersistentCache { | |||
| private: | |||
| using KeyPair = std::pair<std::string, std::string>; | |||
| using BlobPtr = std::unique_ptr<Blob, void (*)(Blob*)>; | |||
| using PersistentCache = mgb::PersistentCache; | |||
| using ExtendedPersistentCache = | |||
| mgb::imperative::persistent_cache::ExtendedPersistentCache; | |||
| std::shared_mutex m_mutex; | |||
| std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; | |||
| struct PersistentCacheManager { | |||
| std::shared_ptr<ExtendedPersistentCache> instance; | |||
| static size_t hash_key_pair(const KeyPair& kp) { | |||
| std::hash<std::string> hasher; | |||
| return hasher(kp.first) ^ hasher(kp.second); | |||
| bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) { | |||
| if (cache) { | |||
| instance = cache; | |||
| PersistentCache::set_impl(cache); | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| std::string blob_to_str(const Blob& key) { | |||
| return std::string(reinterpret_cast<const char*>(key.ptr), key.size); | |||
| bool open_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||
| return try_reg(mgb::imperative::persistent_cache::make_redis( | |||
| ip, port, password, prefix)); | |||
| } | |||
| BlobPtr copy_blob(const Blob& blob) { | |||
| auto blob_deleter = [](Blob* blob) { | |||
| if (blob) { | |||
| std::free(const_cast<void*>(blob->ptr)); | |||
| delete blob; | |||
| } | |||
| }; | |||
| auto blob_ptr = BlobPtr{new Blob(), blob_deleter}; | |||
| blob_ptr->ptr = std::malloc(blob.size); | |||
| std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size); | |||
| blob_ptr->size = blob.size; | |||
| return blob_ptr; | |||
| bool open_file(std::string path) { | |||
| return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); | |||
| } | |||
| BlobPtr str_to_blob(const std::string& str) { | |||
| auto blob = Blob{str.data(), str.size()}; | |||
| return copy_blob(blob); | |||
| std::optional<size_t> clean() { | |||
| if (instance) { | |||
| return instance->clear(); | |||
| } | |||
| return {}; | |||
| } | |||
| std::unique_ptr<Blob, void (*)(Blob*)> empty_blob() { | |||
| return BlobPtr{nullptr, [](Blob* blob) {}}; | |||
| void put(std::string category, std::string key, std::string value) { | |||
| PersistentCache::inst().put( | |||
| category, {key.data(), key.size()}, {value.data(), value.size()}); | |||
| } | |||
| public: | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| auto py_get = [this](const std::string& category, | |||
| const Blob& key) -> mgb::Maybe<Blob> { | |||
| PYBIND11_OVERLOAD_PURE( | |||
| mgb::Maybe<Blob>, PersistentCache, get, category, key); | |||
| }; | |||
| KeyPair kp = {category, blob_to_str(key)}; | |||
| std::shared_lock<decltype(m_mutex)> rlock; | |||
| auto iter = m_local_cache.find(kp); | |||
| if (iter == m_local_cache.end()) { | |||
| auto py_ret = py_get(category, key); | |||
| if (!py_ret.valid()) { | |||
| iter = m_local_cache.insert({kp, empty_blob()}).first; | |||
| } else { | |||
| iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; | |||
| } | |||
| } | |||
| if (iter->second) { | |||
| return *iter->second; | |||
| py::object get(std::string category, std::string key) { | |||
| auto value = | |||
| PersistentCache::inst().get(category, {key.data(), key.size()}); | |||
| if (value.valid()) { | |||
| return py::bytes(std::string((const char*)value->ptr, value->size)); | |||
| } else { | |||
| return {}; | |||
| return py::none(); | |||
| } | |||
| } | |||
| void put(const std::string& category, const Blob& key, const Blob& value) | |||
| override { | |||
| KeyPair kp = {category, blob_to_str(key)}; | |||
| std::unique_lock<decltype(m_mutex)> wlock; | |||
| m_local_cache.insert_or_assign(kp, copy_blob(value)); | |||
| PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); | |||
| } | |||
| }; | |||
| py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>( | |||
| m, "PersistentCache") | |||
| py::class_<PersistentCacheManager>(m, "PersistentCacheManager") | |||
| .def(py::init<>()) | |||
| .def("get", &PersistentCache::get) | |||
| .def("put", &PersistentCache::put) | |||
| .def("reg", &PersistentCache::set_impl); | |||
| .def("try_open_redis", &PersistentCacheManager::open_redis) | |||
| .def("try_open_file", &PersistentCacheManager::open_file) | |||
| .def("clean", &PersistentCacheManager::clean) | |||
| .def("put", &PersistentCacheManager::put) | |||
| .def("get", &PersistentCacheManager::get); | |||
| } | |||
| @@ -1,12 +1,11 @@ | |||
| import pytest | |||
| import megengine | |||
| from megengine.utils.persistent_cache import PersistentCacheOnServer | |||
| from megengine.utils.persistent_cache import _manager | |||
| @pytest.mark.skip(reason="fixme: github ci failed") | |||
| def test_persistent_cache(): | |||
| pc = PersistentCacheOnServer() | |||
| pc = _manager | |||
| k0 = b"\x00\x00" | |||
| k1 = b"\x00\x01" | |||
| cat = "test" | |||
| @@ -0,0 +1,186 @@ | |||
| /** | |||
| * \file imperative/src/impl/persistent_cache.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include <fstream> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "cpp_redis/cpp_redis" | |||
| #include "megbrain/imperative/persistent_cache.h" | |||
| #include "megbrain/imperative/utils/base64.h" | |||
| #include "megbrain/utils/infile_persistent_cache.h" | |||
| namespace mgb::imperative::persistent_cache { | |||
| class RedisCache final : public ExtendedPersistentCache { | |||
| public: | |||
| RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { | |||
| m_local = std::make_shared<mgb::InMemoryPersistentCache>(); | |||
| } | |||
| bool connect(std::string ip, size_t port, std::string password) { | |||
| m_client.auth(password); | |||
| m_client.connect( | |||
| ip, port, | |||
| [](const std::string& host, std::size_t port, | |||
| cpp_redis::connect_state status) { | |||
| if (status == cpp_redis::connect_state::dropped) { | |||
| mgb_log("client disconnected from %s.", host.c_str()); | |||
| mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), | |||
| port); | |||
| } | |||
| }, | |||
| std::uint32_t(200)); | |||
| if (!m_client.is_connected()) { | |||
| return false; | |||
| } | |||
| auto flag = m_client.get("mgb-cache-flag"); | |||
| sync(); | |||
| return flag.get().ok(); | |||
| } | |||
| bool valid() const override { return m_client.is_connected(); } | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| MGB_LOCK_GUARD(m_mtx); | |||
| auto mem_result = m_local->get(category, key); | |||
| if (mem_result.valid()) | |||
| return mem_result; | |||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
| std::string redis_key_str; | |||
| encode(category + '@' + key_str, redis_key_str, 24); | |||
| auto result = m_client.get(redis_key_str); | |||
| sync(); | |||
| auto content = result.get(); | |||
| if (content.is_null()) | |||
| return mgb::None; | |||
| std::string decode_content; | |||
| decode(content.as_string(), decode_content); | |||
| m_local->put(category, key, {decode_content.data(), decode_content.length()}); | |||
| return m_local->get(category, key); | |||
| } | |||
| void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
| MGB_LOCK_GUARD(m_mtx); | |||
| std::string key_str(static_cast<const char*>(key.ptr), key.size); | |||
| std::string redis_key_str; | |||
| encode(category + '@' + key_str, redis_key_str); | |||
| std::string value_str(static_cast<const char*>(value.ptr), value.size); | |||
| std::string redis_value_str; | |||
| encode(value_str, redis_value_str); | |||
| auto result = m_client.set(redis_key_str, redis_value_str); | |||
| m_local->put(category, key, value); | |||
| sync(); | |||
| } | |||
| std::optional<size_t> clear() override { | |||
| size_t cursor = 0, nr_deleted = 0; | |||
| std::string pattern = m_prefix + "@*"; | |||
| do { | |||
| auto reply = m_client.scan(cursor, pattern).share(); | |||
| sync(); | |||
| auto keys = reply.get().as_array(); | |||
| std::vector<std::string> string_keys; | |||
| for (auto&& key : keys) { | |||
| string_keys.push_back(key.as_string()); | |||
| } | |||
| m_client.del(string_keys); | |||
| nr_deleted += string_keys.size(); | |||
| cursor = reply.get().as_array()[0].as_integer(); | |||
| } while (cursor != 0); | |||
| return nr_deleted; | |||
| } | |||
| private: | |||
| std::shared_ptr<mgb::PersistentCache> m_local; | |||
| std::mutex m_mtx; | |||
| cpp_redis::client m_client; | |||
| std::string m_prefix; | |||
| uint64_t m_timeout; | |||
| void sync() { | |||
| m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout)); | |||
| mgb_assert(valid()); | |||
| } | |||
| }; | |||
| class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { | |||
| private: | |||
| std::string m_path; | |||
| std::unique_ptr<mgb::InFilePersistentCache> m_impl; | |||
| public: | |||
| ExtendedInFilePersistentCache() = default; | |||
| bool open(std::string path) { | |||
| std::fstream file; | |||
| file.open(path, std::ios::in | std::ios::binary); | |||
| if (!file.is_open()) { | |||
| return false; | |||
| } | |||
| std::vector<char> bytes = { | |||
| std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; | |||
| if (bytes.size()) { | |||
| m_impl = std::make_unique<mgb::InFilePersistentCache>( | |||
| (const uint8_t*)bytes.data(), bytes.size()); | |||
| } else { | |||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
| } | |||
| m_path = path; | |||
| return true; | |||
| } | |||
| ~ExtendedInFilePersistentCache() { | |||
| if (m_impl) { | |||
| m_impl->dump_cache(m_path.c_str()); | |||
| } | |||
| } | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| return m_impl->get(category, key); | |||
| } | |||
| void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
| return m_impl->put(category, key, value); | |||
| } | |||
| std::optional<size_t> clear() override { | |||
| m_impl = std::make_unique<mgb::InFilePersistentCache>(); | |||
| m_impl->dump_cache(m_path.c_str()); | |||
| return {}; | |||
| } | |||
| bool valid() const override { return m_impl != nullptr; } | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix) { | |||
| auto cache = std::make_shared<RedisCache>(prefix, 100); | |||
| if (!cache->connect(ip, port, password)) { | |||
| return nullptr; | |||
| } | |||
| return cache; | |||
| } | |||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) { | |||
| auto cache = std::make_shared<ExtendedInFilePersistentCache>(); | |||
| if (!cache->open(path)) { | |||
| return nullptr; | |||
| } | |||
| return cache; | |||
| } | |||
| } // namespace mgb::imperative::persistent_cache | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,172 @@ | |||
| /** | |||
| * \file imperative/src/impl/base64.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/utils/base64.h" | |||
| namespace mgb::imperative { | |||
| namespace { | |||
| /* | |||
| ** Translation Table as described in RFC1113 | |||
| */ | |||
| const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | |||
| /* | |||
| ** Translation Table to decode: | |||
| *https://github.com/dgiardini/imgcalkap/blob/master/base64.c | |||
| */ | |||
| const char cd64[] = | |||
| "|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`" | |||
| "abcdefghijklmnopq"; | |||
| /* | |||
| ** encodeblock | |||
| ** | |||
| ** encode 3 8-bit binary bytes as 4 '6-bit' characters | |||
| */ | |||
| void encodeblock(unsigned char in[3], unsigned char out[4], int len) { | |||
| out[0] = cb64[in[0] >> 2]; | |||
| out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)]; | |||
| out[2] = | |||
| (unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '='); | |||
| out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '='); | |||
| } | |||
| /* | |||
| ** decodeblock | |||
| ** | |||
| ** decode 4 '6-bit' characters into 3 8-bit binary bytes | |||
| */ | |||
| void decodeblock(unsigned char in[4], unsigned char out[3]) { | |||
| out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4); | |||
| out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2); | |||
| out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]); | |||
| } | |||
| } // namespace | |||
| /** | |||
| * Encode string to base64 string | |||
| * @param input - source string | |||
| * @param outdata - target base64 string | |||
| * @param linesize - max size of line | |||
| */ | |||
| void encode( | |||
| const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||
| int linesize) { | |||
| outdata.clear(); | |||
| unsigned char in[3], out[4]; | |||
| int i, len, blocksout = 0; | |||
| size_t j = 0; | |||
| auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||
| unsigned int insize = input.size(); | |||
| while (j <= insize) { | |||
| len = 0; | |||
| for (i = 0; i < 3; i++) { | |||
| in[i] = (unsigned char)indata[j]; | |||
| j++; | |||
| if (j <= insize) { | |||
| len++; | |||
| } else { | |||
| in[i] = 0; | |||
| } | |||
| } | |||
| if (len) { | |||
| encodeblock(in, out, len); | |||
| for (i = 0; i < 4; i++) { | |||
| outdata.push_back(out[i]); | |||
| } | |||
| blocksout++; | |||
| } | |||
| if (blocksout >= (linesize / 4) || (j == insize)) { | |||
| if (blocksout) { | |||
| outdata.push_back('\r'); | |||
| outdata.push_back('\n'); | |||
| } | |||
| blocksout = 0; | |||
| } | |||
| } | |||
| } | |||
| /** | |||
| * Decode base64 string ot source | |||
| * @param input - base64 string | |||
| * @param outdata - source string | |||
| */ | |||
| void decode( | |||
| const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata) { | |||
| outdata.clear(); | |||
| unsigned char in[4], out[3], v; | |||
| int i, len; | |||
| size_t j = 0; | |||
| auto* indata = reinterpret_cast<const unsigned char*>(input.data()); | |||
| unsigned int insize = input.size(); | |||
| while (j <= insize) { | |||
| for (len = 0, i = 0; i < 4 && (j <= insize); i++) { | |||
| v = 0; | |||
| while ((j <= insize) && v == 0) { | |||
| v = (unsigned char)indata[j++]; | |||
| v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]); | |||
| if (v) { | |||
| v = (unsigned char)((v == '$') ? 0 : v - 61); | |||
| } | |||
| } | |||
| if (j <= insize) { | |||
| len++; | |||
| if (v) { | |||
| in[i] = (unsigned char)(v - 1); | |||
| } | |||
| } else { | |||
| in[i] = 0; | |||
| } | |||
| } | |||
| if (len) { | |||
| decodeblock(in, out); | |||
| for (i = 0; i < len - 1; i++) { | |||
| outdata.push_back(out[i]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| /** | |||
| * Encode binary data to base64 buffer | |||
| * @param input - source data | |||
| * @param outdata - target base64 buffer | |||
| * @param linesize | |||
| */ | |||
| void encode(const std::string& input, std::string& outdata, int linesize) { | |||
| std::vector<std::uint8_t> out; | |||
| std::vector<std::uint8_t> in(input.begin(), input.end()); | |||
| encode(in, out, linesize); | |||
| outdata = std::string(out.begin(), out.end()); | |||
| } | |||
| /** | |||
| * Decode base64 buffer to source binary data | |||
| * @param input - base64 buffer | |||
| * @param outdata - source binary data | |||
| */ | |||
| void decode(const std::string& input, std::string& outdata) { | |||
| std::vector<std::uint8_t> in(input.begin(), input.end()); | |||
| std::vector<std::uint8_t> out; | |||
| decode(in, out); | |||
| outdata = std::string(out.begin(), out.end()); | |||
| } | |||
| } // namespace mgb::imperative | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/persistent_cache.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <memory> | |||
| #include "megbrain/utils/persistent_cache.h" | |||
| namespace mgb::imperative::persistent_cache { | |||
| class ExtendedPersistentCache : public mgb::PersistentCache { | |||
| public: | |||
| virtual bool valid() const = 0; | |||
| virtual std::optional<size_t> clear() = 0; | |||
| }; | |||
| std::shared_ptr<ExtendedPersistentCache> make_redis( | |||
| std::string ip, size_t port, std::string password, std::string prefix); | |||
| std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path); | |||
| } // namespace mgb::imperative::persistent_cache | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/utils/base64.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/common.h" | |||
| namespace mgb::imperative { | |||
| /** | |||
| * Encode string to base64 string | |||
| * @param input - source string | |||
| * @param outdata - target base64 string | |||
| * @param linesize - max size of line | |||
| */ | |||
| void encode( | |||
| const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata, | |||
| int linesize = 76); | |||
| /** | |||
| * Decode base64 string ot source | |||
| * @param input - base64 string | |||
| * @param outdata - source string | |||
| */ | |||
| void decode(const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata); | |||
| /** | |||
| * Encode binary data to base64 buffer | |||
| * @param input - source data | |||
| * @param outdata - target base64 buffer | |||
| * @param linesize | |||
| */ | |||
| void encode(const std::string& input, std::string& outdata, int linesize = 76); | |||
| /** | |||
| * Decode base64 buffer to source binary data | |||
| * @param input - base64 buffer | |||
| * @param outdata - source binary data | |||
| */ | |||
| void decode(const std::string& input, std::string& outdata); | |||
| } // namespace mgb::imperative | |||
| @@ -12,7 +12,7 @@ endif() | |||
| # TODO: turn python binding into a static/object library | |||
| add_executable(imperative_test ${SOURCES} ${SRCS}) | |||
| add_dependencies(imperative_test mgb_opdef) | |||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) | |||
| target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) | |||
| # Python binding | |||
| target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) | |||
| @@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache { | |||
| }; | |||
| }; | |||
| Maybe<Blob> get(const std::string& category, const Blob& key) override; | |||
| void put(const std::string& category, const Blob& key, const Blob& value) override; | |||
| MGE_WIN_DECLSPEC_FUC Maybe<Blob> get( | |||
| const std::string& category, const Blob& key) override; | |||
| MGE_WIN_DECLSPEC_FUC void put( | |||
| const std::string& category, const Blob& key, const Blob& value) override; | |||
| std::unordered_map< | |||
| std::string, | |||
| std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> | |||
| m_cache; | |||
| MGB_MUTEX m_mtx; | |||
| public: | |||
| MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; | |||
| }; | |||
| /*! | |||