You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

persistent_cache.cpp 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. #include <fstream>
  2. #include <string>
  3. #include <vector>
  4. #include "cpp_redis/cpp_redis"
  5. #include "megbrain/imperative/persistent_cache.h"
  6. #include "megbrain/imperative/utils/base64.h"
  7. #include "megbrain/utils/infile_persistent_cache.h"
  8. namespace mgb::imperative::persistent_cache {
  9. class RedisCache final : public ExtendedPersistentCache {
  10. public:
  11. RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
  12. m_local = std::make_shared<mgb::InMemoryPersistentCache>();
  13. }
  14. void connect(std::string ip, size_t port, std::optional<std::string> password) {
  15. if (password) {
  16. m_client.auth(*password);
  17. }
  18. m_client.connect(
  19. ip, port,
  20. [](const std::string& host, std::size_t port,
  21. cpp_redis::connect_state status) {
  22. if (status == cpp_redis::connect_state::dropped) {
  23. mgb_log("client disconnected from %s.", host.c_str());
  24. mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
  25. port);
  26. }
  27. },
  28. std::uint32_t(200));
  29. mgb_assert(m_client.is_connected(), "connect failed");
  30. auto flag = m_client.get("mgb-cache-flag");
  31. sync();
  32. auto is_valid = [](const cpp_redis::reply& reply) {
  33. switch (reply.get_type()) {
  34. case cpp_redis::reply::type::error:
  35. case cpp_redis::reply::type::null:
  36. return false;
  37. case cpp_redis::reply::type::integer:
  38. return reply.as_integer() != 0;
  39. case cpp_redis::reply::type::simple_string:
  40. case cpp_redis::reply::type::bulk_string:
  41. return !reply.as_string().empty();
  42. case cpp_redis::reply::type::array:
  43. return !reply.as_array().empty();
  44. default:
  45. mgb_assert(false, "unknown reply type %d", (int)reply.get_type());
  46. }
  47. };
  48. mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag");
  49. }
  50. bool valid() const override { return m_client.is_connected(); }
  51. void flush() override {}
  52. mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
  53. MGB_LOCK_GUARD(m_mtx);
  54. auto mem_result = m_local->get(category, key);
  55. if (mem_result.valid()) {
  56. return mem_result;
  57. }
  58. std::string key_str(static_cast<const char*>(key.ptr), key.size);
  59. std::string redis_key_str;
  60. encode(category + '@' + key_str, redis_key_str, 24);
  61. auto result = m_client.get(m_prefix + redis_key_str);
  62. sync();
  63. auto content = result.get();
  64. if (content.is_null()) {
  65. return None;
  66. }
  67. std::string decode_content;
  68. mgb_assert(content.is_string());
  69. decode(content.as_string(), decode_content);
  70. m_local->put(category, key, {decode_content.data(), decode_content.length()});
  71. return m_local->get(category, key);
  72. }
  73. void put(const std::string& category, const Blob& key, const Blob& value) override {
  74. MGB_LOCK_GUARD(m_mtx);
  75. std::string key_str(static_cast<const char*>(key.ptr), key.size);
  76. std::string redis_key_str;
  77. encode(category + '@' + key_str, redis_key_str, 24);
  78. std::string value_str(static_cast<const char*>(value.ptr), value.size);
  79. std::string redis_value_str;
  80. encode(value_str, redis_value_str);
  81. auto result = m_client.set(m_prefix + redis_key_str, redis_value_str);
  82. m_local->put(category, key, value);
  83. sync();
  84. }
  85. std::optional<size_t> clear() override {
  86. size_t cursor = 0, nr_deleted = 0;
  87. std::string pattern = m_prefix + "@*";
  88. do {
  89. auto reply = m_client.scan(cursor, pattern).share();
  90. sync();
  91. auto keys = reply.get().as_array();
  92. std::vector<std::string> string_keys;
  93. for (auto&& key : keys) {
  94. string_keys.push_back(key.as_string());
  95. }
  96. m_client.del(string_keys);
  97. nr_deleted += string_keys.size();
  98. cursor = reply.get().as_array()[0].as_integer();
  99. } while (cursor != 0);
  100. return nr_deleted;
  101. }
  102. private:
  103. std::shared_ptr<mgb::PersistentCache> m_local;
  104. std::mutex m_mtx;
  105. cpp_redis::client m_client;
  106. std::string m_prefix;
  107. uint64_t m_timeout;
  108. void sync() {
  109. m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
  110. mgb_assert(valid());
  111. }
  112. };
  113. class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
  114. private:
  115. std::optional<std::string> m_path;
  116. std::unique_ptr<mgb::InFilePersistentCache> m_impl;
  117. public:
  118. ExtendedInFilePersistentCache() = default;
  119. void open(std::string path) {
  120. std::fstream file;
  121. file.open(path, std::ios::in | std::ios::binary);
  122. mgb_assert(file.is_open(), "can't open file in %s", path.c_str());
  123. std::vector<char> bytes = {
  124. std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
  125. if (bytes.size()) {
  126. m_impl = std::make_unique<mgb::InFilePersistentCache>(
  127. (const uint8_t*)bytes.data(), bytes.size());
  128. } else {
  129. m_impl = std::make_unique<mgb::InFilePersistentCache>();
  130. }
  131. m_path = path;
  132. }
  133. void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); }
  134. ~ExtendedInFilePersistentCache() { flush(); }
  135. mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
  136. return m_impl->get(category, key);
  137. }
  138. void put(const std::string& category, const Blob& key, const Blob& value) override {
  139. return m_impl->put(category, key, value);
  140. }
  141. std::optional<size_t> clear() override {
  142. if (m_impl) {
  143. m_impl = std::make_unique<mgb::InFilePersistentCache>();
  144. if (m_path) {
  145. m_impl->dump_cache(m_path->c_str());
  146. }
  147. }
  148. return {};
  149. }
  150. bool valid() const override { return m_impl != nullptr; }
  151. void flush() override {
  152. if (m_impl && m_path) {
  153. m_impl->dump_cache(m_path->c_str());
  154. }
  155. }
  156. };
  157. std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config(
  158. std::string type, std::unordered_map<std::string, std::string> args,
  159. std::string& err_msg) {
  160. try {
  161. if (type == "redis") {
  162. std::string prefix = args.at("prefix");
  163. std::optional<std::string> password = args.count("password")
  164. ? args.at("password")
  165. : std::optional<std::string>();
  166. auto cache = std::make_shared<RedisCache>(prefix, 100);
  167. if (args.count("unixsocket")) {
  168. std::string unixsocket = args.at("unixsocket");
  169. cache->connect(unixsocket, 0, password);
  170. } else {
  171. std::string ip = args.at("hostname");
  172. int port = atoi(args.at("port").c_str());
  173. std::optional<std::string> password =
  174. args.count("password") ? args.at("password")
  175. : std::optional<std::string>();
  176. cache->connect(ip, port, password);
  177. }
  178. return cache;
  179. } else if (type == "in-file") {
  180. std::string path = args.at("path");
  181. auto cache = std::make_shared<ExtendedInFilePersistentCache>();
  182. cache->open(path);
  183. return cache;
  184. } else if (type == "in-memory") {
  185. auto cache = std::make_shared<ExtendedInFilePersistentCache>();
  186. cache->open();
  187. return cache;
  188. } else {
  189. mgb_assert(false, "persistent cache type %s unsupported", type.c_str());
  190. }
  191. } catch (const std::exception& exc) {
  192. err_msg = exc.what();
  193. } catch (...) {
  194. err_msg = "unknown exception";
  195. }
  196. return nullptr;
  197. }
  198. } // namespace mgb::imperative::persistent_cache
  199. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}