| @@ -78,13 +78,18 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
| from .serialization import load, save | |||
| from .tensor import Parameter, Tensor, tensor | |||
| from .version import __version__ | |||
| from .utils import persistent_cache, comp_graph_tools as cgtools | |||
| _set_fork_exec_path_for_timed_func( | |||
| sys.executable, | |||
| os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), | |||
| ) | |||
| _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() | |||
| _persistent_cache_impl_ins.reg() | |||
| atexit.register(sync) | |||
| del sync | |||
| del _set_fork_exec_path_for_timed_func | |||
| del _persistent_cache_impl_ins | |||
| @@ -0,0 +1,90 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import argparse | |||
| import getpass | |||
| import json | |||
| import os | |||
| import shelve | |||
| from ..core._imperative_rt import PersistentCache as _PersistentCache | |||
| from ..logger import get_logger | |||
| from ..version import __version__ | |||
| class _FakeRedisConn: | |||
| def __init__(self): | |||
| try: | |||
| from ..hub.hub import _get_megengine_home | |||
| cache_dir = os.path.expanduser( | |||
| os.path.join(_get_megengine_home(), "persistent_cache") | |||
| ) | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| cache_file = os.path.join(cache_dir, "cache") | |||
| self._dict = shelve.open(cache_file) | |||
| self._is_shelve = True | |||
| except: | |||
| self._dict = {} | |||
| self._is_shelve = False | |||
| def get(self, key): | |||
| if self._is_shelve and isinstance(key, bytes): | |||
| key = key.decode("utf-8") | |||
| return self._dict.get(key) | |||
| def set(self, key, val): | |||
| if self._is_shelve and isinstance(key, bytes): | |||
| key = key.decode("utf-8") | |||
| self._dict[key] = val | |||
| def __del__(self): | |||
| if self._is_shelve: | |||
| self._dict.close() | |||
| class PersistentCacheOnServer(_PersistentCache): | |||
| _cached_conn = None | |||
| _prefix = None | |||
| _prev_get_refkeep = None | |||
| @property | |||
| def _conn(self): | |||
| """get redis connection""" | |||
| if self._cached_conn is None: | |||
| self._cached_conn = _FakeRedisConn() | |||
| self._prefix = self.make_user_prefix() | |||
| return self._cached_conn | |||
| @classmethod | |||
| def make_user_prefix(cls): | |||
| return "mgbcache:{}".format(getpass.getuser()) | |||
| def _make_key(self, category, key): | |||
| prefix_with_version = "{}:MGB{}".format(self._prefix, __version__) | |||
| return b"@".join( | |||
| (prefix_with_version.encode("ascii"), category.encode("ascii"), key) | |||
| ) | |||
| def put(self, category, key, value): | |||
| conn = self._conn | |||
| key = self._make_key(category, key) | |||
| conn.set(key, value) | |||
| def get(self, category, key): | |||
| conn = self._conn | |||
| key = self._make_key(category, key) | |||
| self._prev_get_refkeep = conn.get(key) | |||
| return self._prev_get_refkeep | |||
| @@ -12,6 +12,7 @@ | |||
| #pragma once | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/utils/persistent_cache.h" | |||
| #include <Python.h> | |||
| #include <string> | |||
| @@ -328,6 +329,49 @@ namespace detail { | |||
| template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {}; | |||
| template <> struct type_caster<mgb::PersistentCache::Blob> { | |||
| PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob")); | |||
| public: | |||
| bool load(handle src, bool convert) { | |||
| if (!isinstance<bytes>(src)) { | |||
| return false; | |||
| } | |||
| value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr()); | |||
| value.size = PYBIND11_BYTES_SIZE(src.ptr()); | |||
| return true; | |||
| } | |||
| static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) { | |||
| return bytes((const char*)blob.ptr, blob.size); | |||
| } | |||
| }; | |||
| template <typename T> struct type_caster<mgb::Maybe<T>> { | |||
| using value_conv = make_caster<T>; | |||
| PYBIND11_TYPE_CASTER(mgb::Maybe<T>, _("Optional[") + value_conv::name + _("]")); | |||
| public: | |||
| bool load(handle src, bool convert) { | |||
| if(!src) { | |||
| return false; | |||
| } | |||
| if (src.is_none()) { | |||
| return true; | |||
| } | |||
| value_conv inner_caster; | |||
| if (!inner_caster.load(src, convert)) { | |||
| return false; | |||
| } | |||
| value.emplace(cast_op<T&&>(std::move(inner_caster))); | |||
| return true; | |||
| } | |||
| static handle cast(mgb::Maybe<T> src, return_value_policy policy, handle parent) { | |||
| if(!src.valid()) { | |||
| return none().inc_ref(); | |||
| } | |||
| return pybind11::cast(src.val(), policy, parent); | |||
| } | |||
| }; | |||
| } // detail | |||
| } // PYBIND11_NAMESPACE | |||
| @@ -25,6 +25,7 @@ | |||
| #include "megbrain/imperative/profiler.h" | |||
| #include "megbrain/imperative/tensor_sanity_check.h" | |||
| #include "megbrain/serialization/helper.h" | |||
| #include "megbrain/utils/persistent_cache.h" | |||
| #if MGB_ENABLE_OPR_MM | |||
| #include "megbrain/opr/mm_handler.h" | |||
| @@ -262,4 +263,20 @@ void init_utils(py::module m) { | |||
| m.def("_timed_func_exec_cb", [](const std::string& user_data){ | |||
| mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); | |||
| }); | |||
| using mgb::PersistentCache; | |||
| class PyPersistentCache: public mgb::PersistentCache{ | |||
| public: | |||
| mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { | |||
| PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key); | |||
| } | |||
| void put(const std::string& category, const Blob& key, const Blob& value) override { | |||
| PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); | |||
| } | |||
| }; | |||
| py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(m, "PersistentCache") | |||
| .def(py::init<>()) | |||
| .def("get", &PersistentCache::get) | |||
| .def("put", &PersistentCache::put) | |||
| .def("reg", &PersistentCache::set_impl); | |||
| } | |||
| @@ -0,0 +1,16 @@ | |||
| import pytest | |||
| import megengine | |||
| from megengine.utils.persistent_cache import PersistentCacheOnServer | |||
| def test_persistent_cache(): | |||
| pc = PersistentCacheOnServer() | |||
| k0 = b"\x00\x00" | |||
| k1 = b"\x00\x01" | |||
| cat = "test" | |||
| pc.put(cat, k0, k1) | |||
| pc.put(cat, k1, k0) | |||
| assert k1 == pc.get(cat, k0) | |||
| assert k0 == pc.get(cat, k1) | |||
| assert pc.get("test1", k0) == None | |||