You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

persistent_cache.cpp 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. /**
  2. * \file imperative/src/impl/persistent_cache.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <fstream>
  12. #include <string>
  13. #include <vector>
  14. #include "cpp_redis/cpp_redis"
  15. #include "megbrain/imperative/persistent_cache.h"
  16. #include "megbrain/imperative/utils/base64.h"
  17. #include "megbrain/utils/infile_persistent_cache.h"
  18. namespace mgb::imperative::persistent_cache {
  19. class RedisCache final : public ExtendedPersistentCache {
  20. public:
  21. RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
  22. m_local = std::make_shared<mgb::InMemoryPersistentCache>();
  23. }
  24. bool connect(std::string ip, size_t port, std::string password) {
  25. m_client.auth(password);
  26. m_client.connect(
  27. ip, port,
  28. [](const std::string& host, std::size_t port,
  29. cpp_redis::connect_state status) {
  30. if (status == cpp_redis::connect_state::dropped) {
  31. mgb_log("client disconnected from %s.", host.c_str());
  32. mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
  33. port);
  34. }
  35. },
  36. std::uint32_t(200));
  37. if (!m_client.is_connected()) {
  38. return false;
  39. }
  40. auto flag = m_client.get("mgb-cache-flag");
  41. sync();
  42. return flag.get().ok();
  43. }
  44. bool valid() const override { return m_client.is_connected(); }
  45. mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
  46. MGB_LOCK_GUARD(m_mtx);
  47. auto mem_result = m_local->get(category, key);
  48. if (mem_result.valid())
  49. return mem_result;
  50. std::string key_str(static_cast<const char*>(key.ptr), key.size);
  51. std::string redis_key_str;
  52. encode(category + '@' + key_str, redis_key_str, 24);
  53. auto result = m_client.get(redis_key_str);
  54. sync();
  55. auto content = result.get();
  56. if (content.is_null())
  57. return mgb::None;
  58. std::string decode_content;
  59. decode(content.as_string(), decode_content);
  60. m_local->put(category, key, {decode_content.data(), decode_content.length()});
  61. return m_local->get(category, key);
  62. }
  63. void put(const std::string& category, const Blob& key, const Blob& value) override {
  64. MGB_LOCK_GUARD(m_mtx);
  65. std::string key_str(static_cast<const char*>(key.ptr), key.size);
  66. std::string redis_key_str;
  67. encode(category + '@' + key_str, redis_key_str);
  68. std::string value_str(static_cast<const char*>(value.ptr), value.size);
  69. std::string redis_value_str;
  70. encode(value_str, redis_value_str);
  71. auto result = m_client.set(redis_key_str, redis_value_str);
  72. m_local->put(category, key, value);
  73. sync();
  74. }
  75. std::optional<size_t> clear() override {
  76. size_t cursor = 0, nr_deleted = 0;
  77. std::string pattern = m_prefix + "@*";
  78. do {
  79. auto reply = m_client.scan(cursor, pattern).share();
  80. sync();
  81. auto keys = reply.get().as_array();
  82. std::vector<std::string> string_keys;
  83. for (auto&& key : keys) {
  84. string_keys.push_back(key.as_string());
  85. }
  86. m_client.del(string_keys);
  87. nr_deleted += string_keys.size();
  88. cursor = reply.get().as_array()[0].as_integer();
  89. } while (cursor != 0);
  90. return nr_deleted;
  91. }
  92. private:
  93. std::shared_ptr<mgb::PersistentCache> m_local;
  94. std::mutex m_mtx;
  95. cpp_redis::client m_client;
  96. std::string m_prefix;
  97. uint64_t m_timeout;
  98. void sync() {
  99. m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
  100. mgb_assert(valid());
  101. }
  102. };
  103. class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
  104. private:
  105. std::string m_path;
  106. std::unique_ptr<mgb::InFilePersistentCache> m_impl;
  107. public:
  108. ExtendedInFilePersistentCache() = default;
  109. bool open(std::string path) {
  110. std::fstream file;
  111. file.open(path, std::ios::in | std::ios::binary);
  112. if (!file.is_open()) {
  113. return false;
  114. }
  115. std::vector<char> bytes = {
  116. std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
  117. if (bytes.size()) {
  118. m_impl = std::make_unique<mgb::InFilePersistentCache>(
  119. (const uint8_t*)bytes.data(), bytes.size());
  120. } else {
  121. m_impl = std::make_unique<mgb::InFilePersistentCache>();
  122. }
  123. m_path = path;
  124. return true;
  125. }
  126. ~ExtendedInFilePersistentCache() {
  127. if (m_impl) {
  128. m_impl->dump_cache(m_path.c_str());
  129. }
  130. }
  131. mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
  132. return m_impl->get(category, key);
  133. }
  134. void put(const std::string& category, const Blob& key, const Blob& value) override {
  135. return m_impl->put(category, key, value);
  136. }
  137. std::optional<size_t> clear() override {
  138. m_impl = std::make_unique<mgb::InFilePersistentCache>();
  139. m_impl->dump_cache(m_path.c_str());
  140. return {};
  141. }
  142. bool valid() const override { return m_impl != nullptr; }
  143. };
  144. std::shared_ptr<ExtendedPersistentCache> make_redis(
  145. std::string ip, size_t port, std::string password, std::string prefix) {
  146. auto cache = std::make_shared<RedisCache>(prefix, 100);
  147. if (!cache->connect(ip, port, password)) {
  148. return nullptr;
  149. }
  150. return cache;
  151. }
  152. std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) {
  153. auto cache = std::make_shared<ExtendedInFilePersistentCache>();
  154. if (!cache->open(path)) {
  155. return nullptr;
  156. }
  157. return cache;
  158. }
  159. } // namespace mgb::imperative::persistent_cache
  160. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}