
interpreter_impl.h

/**
 * \file imperative/src/impl/interpreter/interpreter_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <deque>
#include <future>
#include <list>
#include <stack>
#include <thread>
#include <unordered_set>
#include <variant>

#include "megbrain/comp_node.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"

#include "./commands.h"
#include "./tensor_info.h"
#include "./option_manager.h"
#include "./stack_manager.h"

#include "../profiler/events.h"

namespace mgb::imperative::interpreter::intl {
using Handle = Interpreter::Handle;

struct InterpreterImpl : Interpreter {
    std::unique_ptr<Channel> create_channel() override;
};

struct ChannelImpl : Interpreter::Channel {
    ChannelImpl();
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value, bool no_cache) override;
    Handle put(const DeviceTensorND& value, const HostTensorND& hvalue) override;

    void del(Handle) override;
    void swap_in(Handle) override;
    void swap_out(Handle) override;
    void drop(Handle) override;

    SmallVector<Handle> apply_op(
            std::shared_ptr<OpDef> op,
            const SmallVector<Handle>& inputs) override;

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;
    DeviceTensorND get_dev_tensor(Handle) override;
    bool check_available() override;
    void sync() override;
    void close() override;

    size_t get_option(std::string name) override;
    void set_option(std::string name, size_t value) override;

    void start_profile() override;
    void stop_profile() override;

    void push_scope(std::string) override;
    void pop_scope(std::string) override;
private:
    struct WorkQueue;
    struct State;

    TensorInfo* alloc();
    void init(TensorInfo*, LogicalTensorDesc desc);
    void free(TensorInfo*);
    void real_free(TensorInfo*);
    void recursive_free(TensorInfo*);
    void do_drop(TensorInfo*, bool);
    void detach_users(TensorInfo*);

    TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
    TensorInfo* put_impl(const DeviceTensorND& value, const HostTensorND& hvalue);
    void del_impl(Handle);
    void sync_impl();
    SmallVector<Handle> apply_op_impl(
            std::shared_ptr<OpDef> op,
            const SmallVector<Handle>& inputs);
    TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
    void notify_tensor_unsafe(TensorInfo* info);

    void process_one_task(Command&);

    void check_worker_exc_unsafe();

    void produce_tensor(TensorInfo* dest, TensorPtr ptr);
    void release_tensor(TensorInfo* dest);

    void regenerate(TensorInfo* dest);
    void flush_apply_stack();
    void do_apply_op(const ApplyOp& cmd, std::string reason);

    std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
    init_output_and_workspace(
            const OpDef& def,
            SmallVector<TensorPtr> inputs,
            SmallVector<MemoryDesc> inputs_mem_desc);

    void dispatch_default_cpu(
            std::shared_ptr<OpDef> op,
            const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
    void dispatch_kernel(
            std::shared_ptr<OpDef> op,
            const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);

    void push_scope(std::string, State&);
    void pop_scope(std::string, State&);

    void assert_in_channel();
    void assert_in_worker();
    std::thread::id get_worker_tid();

    // template <typename TCommand>
    // void enqueue_command(TCommand&& cmd) {
    //     m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
    // }

    void sample_on_device(CompNode device, bool force);

    // valid => status != Deleted
    std::unordered_set<TensorInfo*> collect_valid_tensors();

    std::mutex m_mutex;
    Spinlock m_spin;
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    uint64_t m_waitee_id = 0;
    std::exception_ptr m_worker_exc;
    std::function<void(std::string, std::string)> m_profile_dump_callback;
    size_t m_storage_id = 0;

    // TODO: use explicit struct
    std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
    bool m_applying = false;
    bool m_closed = false;
    struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
        // Set max_spin=0 to keep the queue from fetching tasks in a busy-wait loop.
        // This does not hurt throughput when the Python interpreter sends enough tasks,
        // but it saves significant CPU time while waiting for tasks, e.g. waiting for data input.
        // Pending tasks are limited to 10000.
        WorkQueue(ChannelImpl* owner)
                : AsyncQueueSC<Command, WorkQueue>(0, 10000), m_owner(owner) {
            sys::set_thread_name("interpreter");
            if (const char* env_val = MGB_GETENV("MEGENGINE_ASYNC_QUEUE_SIZE")) {
                int len = strlen(env_val);
                for (int i = 0; i < len; i++) {
                    mgb_assert(
                            env_val[i] >= '0' && env_val[i] <= '9',
                            "async queue size should be an integer");
                }
                size_t val;
                sscanf(env_val, "%zu", &val);
                update_max_items(val);
            }
        }
        void process_one_task(Command& icmd) {
            m_owner->process_one_task(icmd);
        }
        void on_async_queue_worker_thread_start() override;

    private:
        ChannelImpl* m_owner;
    } m_worker;
    /**
     * Buffers a window of commands so that later commands can be fused into earlier ones.
     * Example:
     * ---------------------------------------------------------------------
     * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1}   |
     * ---------------------------------------------------------------------
     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1}  |
     * ---------------------------------------------------------------------
     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ...        |
     * ---------------------------------------------------------------------
     * The fused Apply may then be invoked in place. See: ChannelImpl::process_one_task
     */
    struct CommandBuffer {
        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
        void enqueue(CommandData cmd);
        bool empty() const {
            return m_commands.empty();
        }
        void flush();

    private:
        ChannelImpl* m_owner;
        std::deque<Command> m_commands;

        using Handle = decltype(m_commands)::iterator;
        // [begin, end)
        using Range = std::array<Handle, 2>;

        // Launch commands in range [m_commands.begin(), pos)
        void flush(Handle pos);
        // Select flush position for incoming cmd
        Handle flush_pos_for(const Command& cmd);
        // Fuse del command into a suitable ApplyOp
        bool fuse_del(const Del& cmd);
        // Returns the last handle where dest is used within range. If dest is not used, returns range[1]
        Handle find_last_usage(TensorInfo* dest, Range range);
        // Returns the produce position of dest. If not found, returns range[1]
        Handle find_produce(TensorInfo* dest, Range range);
    } m_buffer;
    //! configures whether errors are raised exactly at the point an op is invoked.
    //! level 2: both device-side and user-side errors are async;
    //! level 1: user-side errors are sync;
    //! level 0: both are sync.
    int m_async_level = 2;

    struct State {
        std::thread::id tid;
        OptionManager options;
    };

    struct ChannelState : State {
        StackManager stack_manager;
    };

    struct WorkerState : State {};

    ChannelState m_channel_state;
    WorkerState m_worker_state;
    /*!
     * \brief A framework for dynamic sublinear memory optimization
     *
     * Note: The main idea is that during training, if memory usage exceeds
     * the threshold, some tensors are selected and evicted until the memory
     * usage falls below the threshold.
     */
    struct DynamicSublinear {
        /*!
         * \brief find an available tensor with the largest evaluation function
         *
         * Note: An available tensor must (1) have a computing path,
         * (2) be in memory, and (3) not be pinned. Evaluation function refers to:
         * @see: TensorInfo::eval_func.
         *
         * \return a pointer to the best tensor; nullptr is returned if no
         * available tensor is found
         */
        TensorInfo* find_best_tensor(bool);

        /*!
         * \brief estimate the cost of recomputing tensor ptr
         *
         * Note: We define the cost as the sum of the costs of the evicted
         * components in which the neighbors of ptr are located.
         */
        double estimate_neighbor_cost(TensorInfo* ptr);

        /*!
         * \brief update the last used time of the tensor ptr
         */
        void update_used_time(TensorInfo* ptr);

        /*!
         * \brief merge the two specified sets (the set containing element x
         * and the set containing element y)
         */
        void merge(std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y);

        /*!
         * \brief return the representative of the set that contains the
         * element x
         */
        std::shared_ptr<DsuNode> find_father(std::shared_ptr<DsuNode>& x);

        /*!
         * \brief update the DSU after recomputing tensor ptr
         *
         * Delete ptr from the set where ptr is located. Since a DSU does not
         * support this operation directly, we instead reset the DSU father of
         * ptr and subtract the recomputation cost of ptr from the cost of the
         * original set.
         */
        void update_dsu_after_recompute(TensorInfo* ptr);

        /*!
         * \brief update the DSU after evicting tensor ptr
         *
         * Check the neighbors of ptr, i.e. its input and output tensors, and
         * if they are evicted, merge their respective sets.
         */
        void update_dsu_after_evict(TensorInfo* ptr);

        /*!
         * \brief pin the tensors in vec
         */
        void pin(const SmallVector<TensorInfo*>& vec);

        /*!
         * \brief unpin the tensors in vec
         */
        void unpin(const SmallVector<TensorInfo*>& vec);

        /*!
         * \brief add the tensor to the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * this does nothing.
         */
        void insert_candidate(TensorInfo* ptr);

        /*!
         * \brief erase the tensor from the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * this does nothing.
         */
        void erase_candidate(TensorInfo* ptr);

        //! estimated current time, maintained to reduce the overhead of the timer
        double estimate_timestamp = 0;

        //! the comp node on which dynamic sublinear memory optimization works
        CompNode comp_node;

        //! store all tensors that may be evicted
        std::unordered_set<TensorInfo*> candidates;

        bool is_bad_op(std::string op_name) {
            return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) !=
                   op_blacklist.end();
        }

        std::vector<std::string> op_blacklist = {
                "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
                "GaussianRNG", "UniformRNG", "GammaRNG", "PermutationRNG",
                "PoissonRNG", "BetaRNG"};
    } m_dtr;
    //! automatically evict an optimal tensor
    bool auto_evict(size_t);

    void alloc_tensor_with_evict(Blob*);

    // assert the thread id when calling get_xxx_state to avoid misuse
    ChannelState& get_channel_state();
    WorkerState& get_worker_state();
};

}  // namespace mgb::imperative::interpreter::intl
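The CommandBuffer comment above describes how trailing Del commands are folded into the Apply that consumed those tensors, so the fused Apply may run in place. Below is a minimal, self-contained sketch of that folding idea; MiniApply, MiniDel and MiniBuffer are illustrative stand-ins and do not mirror the real Command/CommandBuffer types in this header.

// Simplified sketch of Del-into-Apply fusion; all types here are invented for illustration.
#include <algorithm>
#include <cstdio>
#include <deque>
#include <variant>
#include <vector>

struct MiniApply {
    std::vector<int> inputs;   // tensor ids read by the op
    std::vector<int> outputs;  // tensor ids produced by the op
    std::vector<int> dels;     // inputs that may be released (enabling in-place reuse)
};
struct MiniDel {
    int tensor;
};
using MiniCommand = std::variant<MiniApply, MiniDel>;

struct MiniBuffer {
    std::deque<MiniCommand> commands;

    // Try to fold a Del into the latest buffered Apply that consumes the tensor.
    // Returns false if no such Apply exists; the Del is then kept as a separate command.
    bool fuse_del(const MiniDel& del) {
        for (auto it = commands.rbegin(); it != commands.rend(); ++it) {
            if (auto* apply = std::get_if<MiniApply>(&*it)) {
                auto& in = apply->inputs;
                if (std::find(in.begin(), in.end(), del.tensor) != in.end()) {
                    apply->dels.push_back(del.tensor);
                    return true;
                }
            }
        }
        return false;
    }

    void enqueue(MiniCommand cmd) {
        if (auto* del = std::get_if<MiniDel>(&cmd)) {
            if (fuse_del(*del)) return;
        }
        commands.push_back(std::move(cmd));
    }
};

int main() {
    MiniBuffer buf;
    buf.enqueue(MiniApply{{0, 1}, {2, 3}, {}});
    buf.enqueue(MiniDel{0});
    buf.enqueue(MiniDel{1});
    // Both Dels were folded into the Apply, which may now be executed in place.
    const auto& apply = std::get<MiniApply>(buf.commands.front());
    std::printf("buffered commands: %zu, fused dels: %zu\n",
                buf.commands.size(), apply.dels.size());
    return 0;
}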

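The DynamicSublinear note and auto_evict() above imply an eviction loop: while estimated memory usage stays above a threshold, pick the candidate with the highest eviction score and release it. The following is a hedged sketch of that loop under simplified assumptions; FakeTensor, its score field and evict_until_below are invented for illustration and are not part of MegEngine.

// Hedged sketch of a threshold-driven eviction loop; names are illustrative only.
#include <cstdio>
#include <unordered_set>

struct FakeTensor {
    size_t size = 0;
    double score = 0;    // higher score => better candidate to evict
    bool in_memory = true;
};

size_t evict_until_below(std::unordered_set<FakeTensor*>& candidates,
                         size_t used_bytes, size_t threshold_bytes) {
    while (used_bytes > threshold_bytes) {
        FakeTensor* best = nullptr;
        for (auto* t : candidates)
            if (t->in_memory && (!best || t->score > best->score)) best = t;
        if (!best) break;            // nothing left to evict
        best->in_memory = false;     // a real implementation would drop the device blob
        used_bytes -= best->size;
    }
    return used_bytes;
}

int main() {
    FakeTensor a{100, 0.3}, b{200, 0.9}, c{50, 0.1};
    std::unordered_set<FakeTensor*> cand{&a, &b, &c};
    std::printf("remaining bytes: %zu\n", evict_until_below(cand, 350, 200));
    return 0;
}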
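find_father, merge and the update_dsu_* notes above describe a disjoint-set union (DSU) that tracks a per-set recomputation cost. The sketch below shows the path-compressed find and cost-accumulating merge that those comments suggest; the DsuNode layout used here (parent, cost) is an assumption and may differ from MegEngine's real DsuNode.

// Minimal DSU sketch with path compression and a per-set cost; fields are assumptions.
#include <cstdio>
#include <memory>

struct DsuNode {
    std::shared_ptr<DsuNode> parent;  // null parent => this node is the set representative
    double cost = 0;                  // recomputation cost of the whole set (valid at the root)
};

std::shared_ptr<DsuNode> find_father(std::shared_ptr<DsuNode>& x) {
    if (!x->parent) return x;
    auto root = find_father(x->parent);
    x->parent = root;  // path compression
    return root;
}

void merge(std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
    auto rx = find_father(x), ry = find_father(y);
    if (rx == ry) return;
    rx->cost += ry->cost;  // the merged set's cost is the sum of both parts
    ry->parent = rx;
}

int main() {
    auto a = std::make_shared<DsuNode>();
    auto b = std::make_shared<DsuNode>();
    a->cost = 1.5;
    b->cost = 2.5;
    merge(a, b);
    std::printf("merged set cost: %.1f\n", find_father(b)->cost);
    return 0;
}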