You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

api_cache.h 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. /**
  2. * \file dnn/src/common/api_cache.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
#include <cstring>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include "megdnn/thin/function.h"
  18. namespace megdnn {
  19. template <typename TSignature>
  20. class FunctionCache;
  21. template <typename TRet, typename... TArgs>
  22. class FunctionCache<TRet(TArgs...)> {
  23. public:
  24. using key_t = std::string;
  25. using value_t = TRet;
  26. using key_mapper_t = thin_function<key_t(TArgs...)>;
  27. using value_mapper_t = thin_function<value_t(TArgs...)>;
  28. using storage_t = std::unordered_map<key_t, value_t>;
  29. public:
  30. storage_t storage;
  31. key_mapper_t key_mapper;
  32. value_mapper_t value_mapper;
  33. public:
  34. TRet operator()(TArgs... args) {
  35. key_t key = key_mapper(args...);
  36. if (storage.count(key) == 0) {
  37. storage[key] = value_mapper(std::forward<TArgs>(args)...);
  38. }
  39. return storage[key];
  40. }
  41. };
  42. // FIFO
  43. class StringSerializer {
  44. private:
  45. std::string m_buffer;
  46. size_t m_cursor = 0;
  47. public:
  48. template <typename T>
  49. T read_plain() {
  50. T result;
  51. std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
  52. m_cursor += sizeof(T);
  53. return result;
  54. }
  55. template <typename T>
  56. void write_plain(T value) {
  57. m_buffer.resize(m_buffer.size() + sizeof(T));
  58. std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T));
  59. }
  60. std::string take() {
  61. std::string result;
  62. m_buffer.erase(0, m_cursor);
  63. return std::move(m_buffer);
  64. }
  65. void set(std::string new_buf) {
  66. m_cursor = 0;
  67. m_buffer = new_buf;
  68. }
  69. };
  70. struct Empty {};
  71. template <typename... TParams>
  72. class ParamBundle {
  73. private:
  74. template<std::size_t N, std::size_t... Seq>
  75. static std::index_sequence<N + Seq ...> add_all(std::index_sequence<Seq...>){
  76. return {};
  77. }
  78. template<std::size_t Min, std::size_t Max>
  79. using make_index_range = decltype(add_all<Min>(std::make_index_sequence<Max-Min>()));
  80. using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
  81. storage_t m_storage;
  82. template <typename TFunctor, size_t... Indices>
  83. auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
  84. return functor(std::get<Indices>(m_storage).value...);
  85. }
  86. template <size_t Index, size_t... Indices, typename TPrev>
  87. auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
  88. return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
  89. }
  90. template <typename TPrev>
  91. auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
  92. template <size_t Index, size_t... Indices, typename TPrev>
  93. auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
  94. return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
  95. }
  96. template <typename TPrev>
  97. auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
  98. template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
  99. void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
  100. std::get<Index>(m_storage).value = arg;
  101. set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
  102. }
  103. template <size_t... Indices>
  104. void set_values_helper(std::index_sequence<Indices...>) {
  105. static_assert(sizeof...(Indices) == 0, "redundant indices");
  106. }
  107. public:
  108. template <typename TFunctor>
  109. auto call_by(TFunctor&& functor) {
  110. return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
  111. }
  112. template <size_t NBegin, size_t NEnd>
  113. void serialize_params(StringSerializer& ser) {
  114. static_assert(NEnd >= NBegin, "invalid range");
  115. serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
  116. }
  117. template <size_t NBegin, size_t NEnd>
  118. void deserialize_params(StringSerializer& ser) {
  119. static_assert(NEnd >= NBegin, "invalid range");
  120. deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
  121. }
  122. template <size_t NBegin, size_t NEnd, typename... TArgs>
  123. void set_values(TArgs&&... args) {
  124. set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
  125. }
  126. };
  127. template <typename T>
  128. class RetParam {
  129. public:
  130. T value;
  131. Empty serialize(StringSerializer& ser, Empty) {
  132. ser.write_plain(value);
  133. return Empty{};
  134. }
  135. Empty deserialize(StringSerializer& ser, Empty) {
  136. value = ser.read_plain<T>();
  137. return Empty{};
  138. }
  139. };
  140. template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
  141. class FunctionCacheBuilder {
  142. private:
  143. static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
  144. template <size_t... Indices>
  145. static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
  146. static auto declfunction() {
  147. return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
  148. }
  149. template <size_t... Indices>
  150. static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
  151. static auto declbundle() {
  152. return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
  153. }
  154. using function_t = decltype(declfunction());
  155. using bundle_t = decltype(declbundle());
  156. public:
  157. template <typename TNewRet>
  158. auto ret() {
  159. static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
  160. return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
  161. }
  162. template <typename TNewInput>
  163. auto input() {
  164. using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
  165. return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
  166. }
  167. template <typename TNewOutput>
  168. auto output() {
  169. using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
  170. return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
  171. }
  172. template <typename TFunctor>
  173. function_t build(TFunctor func) {
  174. FunctionCache<std::string(bundle_t)> cache;
  175. cache.key_mapper = [](bundle_t bundle) {
  176. StringSerializer ser;
  177. bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
  178. return ser.take();
  179. };
  180. cache.value_mapper = [=](bundle_t bundle) {
  181. StringSerializer ser;
  182. TRet ret;
  183. ret.value = bundle.call_by(func);
  184. ret.serialize(ser, Empty{});
  185. bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
  186. return ser.take();
  187. };
  188. return [=](auto&&... args) mutable {
  189. bundle_t bundle;
  190. TRet ret;
  191. StringSerializer ser;
  192. static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
  193. "arg count mismatch");
  194. bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
  195. ser.set(cache(bundle));
  196. ret.deserialize(ser, Empty{});
  197. constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
  198. constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
  199. bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
  200. return ret.value;
  201. };
  202. }
  203. };
  204. template <typename T>
  205. class PlainParam {
  206. public:
  207. T value;
  208. Empty serialize(StringSerializer& ser, Empty) {
  209. ser.write_plain(value);
  210. return Empty{};
  211. }
  212. Empty deserialize(StringSerializer& ser, Empty) {
  213. value = ser.read_plain<T>();
  214. return Empty{};
  215. }
  216. };
  217. template <typename T>
  218. class RefParam {
  219. public:
  220. T* value;
  221. Empty serialize(StringSerializer& ser, Empty) {
  222. ser.write_plain(*value);
  223. return Empty{};
  224. }
  225. Empty deserialize(StringSerializer& ser, Empty) {
  226. *value = ser.read_plain<T>();
  227. return Empty{};
  228. }
  229. };
  230. template <typename T>
  231. class RefArraySizeParam {
  232. public:
  233. T* value;
  234. T serialize(StringSerializer& ser, Empty) {
  235. ser.write_plain(*value);
  236. return *value;
  237. }
  238. T deserialize(StringSerializer& ser, Empty) {
  239. return *value = ser.read_plain<T>();
  240. }
  241. };
  242. template <typename TSize, typename TItem>
  243. class ArrayParam {
  244. public:
  245. TItem* value;
  246. Empty serialize(StringSerializer& ser, TSize size) {
  247. for (TSize i = 0; i < size; ++i) {
  248. ser.write_plain(value[i]);
  249. }
  250. return Empty{};
  251. }
  252. Empty deserialize(StringSerializer& ser, TSize size) {
  253. for (TSize i = 0; i < size; ++i) {
  254. value[i] = ser.read_plain<TItem>();
  255. }
  256. return Empty{};
  257. }
  258. };
  259. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台