
interpreter_impl.h

/**
 * \file imperative/src/impl/interpreter/interpreter_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <deque>
#include <future>
#include <list>
#include <thread>
#include <unordered_set>
#include <variant>

#include "megbrain/comp_node.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"

#include "./commands.h"
#include "./tensor_info.h"
#include "./option_manager.h"

#include "../profiler/events.h"

namespace mgb::imperative::interpreter::intl {

using Handle = Interpreter::Handle;

struct InterpreterImpl : Interpreter {
    std::unique_ptr<Channel> create_channel() override;
};
struct ChannelImpl : Interpreter::Channel {
    ChannelImpl();
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value, bool no_cache) override;
    Handle put(const DeviceTensorND& value) override;

    void del(Handle) override;
    void swap_in(Handle) override;
    void swap_out(Handle) override;
    void drop(Handle) override;

    SmallVector<Handle> apply_op(
            std::shared_ptr<OpDef> op,
            const SmallVector<Handle>& inputs) override;

    HostTensorND get_value(Handle) override;
    TensorShape get_shape(Handle) override;
    DType get_dtype(Handle) override;
    CompNode get_device(Handle) override;
    DeviceTensorND get_dev_tensor(Handle) override;

    void sync() override;
    void close() override;

    size_t get_option(std::string name) override;
    void set_option(std::string name, size_t value) override;

    void start_profile() override;
    void stop_profile() override;

    void push_scope(std::string) override;
    void pop_scope(std::string) override;
private:
    struct WorkQueue;
    struct State;

    TensorInfo* alloc();
    void init(TensorInfo*, LogicalTensorDesc desc);
    void free(TensorInfo*);
    void real_free(TensorInfo*);
    void recursive_free(TensorInfo*);
    void do_drop(TensorInfo*, bool);
    void detach_users(TensorInfo*);

    TensorInfo* put_impl(const HostTensorND& value, bool no_cache);

    TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
    void notify_tensor_unsafe(TensorInfo* info);

    void process_one_task(IdentifiedCommand&);

    void check_worker_exc_unsafe();

    void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
    void release_tensor(TensorInfo* dest);

    void regenerate(TensorInfo* dest);
    void recompute(TensorInfo::ComputePath* path);
    void do_apply_op(const ApplyOp& cmd);

    void dispatch_default_cpu(
            std::shared_ptr<OpDef> op,
            const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);
    void dispatch_kernel(
            std::shared_ptr<OpDef> op,
            const SmallVector<TensorInfo*>& input_infos,
            const SmallVector<LogicalTensorDesc>& input_descs,
            SmallVector<Handle>* outputs);

    bool check_available();

    void push_scope(std::string, State&);
    void pop_scope(std::string, State&);

    void assert_in_channel();
    void assert_in_worker();
    std::thread::id get_worker_tid();

    template <typename TCommand>
    void enqueue_command(TCommand&& cmd) {
        m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
    }

    void sample_on_device(CompNode device, bool force);

    // valid => status != Deleted
    std::unordered_set<TensorInfo*> collect_valid_tensors();

    std::mutex m_mutex;
    std::condition_variable m_cv;
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    uint64_t m_waitee_id = 0;
    std::exception_ptr m_worker_exc;
    std::function<void(std::string, std::string)> m_profile_dump_callback;
    bool m_closed = false;
    struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
        // Set max_spin=0 so the queue does not fetch tasks in a busy-wait manner.
        // This does not hurt throughput while the Python interpreter is sending
        // enough tasks, but it saves significant CPU time while waiting for tasks,
        // e.g. while waiting for data input. Pending tasks are limited to 1000000.
        WorkQueue(ChannelImpl* owner)
                : AsyncQueueSC<IdentifiedCommand, WorkQueue>(0, 1000000), m_owner(owner) {
            sys::set_thread_name("interpreter");
        }
        void process_one_task(IdentifiedCommand& icmd) {
            m_owner->process_one_task(icmd);
        }
        void on_async_queue_worker_thread_start() override {
            sys::set_thread_name("worker");
            m_owner->m_worker_state.tid = std::this_thread::get_id();
        }

    private:
        ChannelImpl* m_owner;
    } m_worker;
    /**
     * Buffer a window of commands so that later commands can be fused into them.
     * Example:
     * ---------------------------------------------------------------------
     * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1}  |
     * ---------------------------------------------------------------------
     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
     * ---------------------------------------------------------------------
     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ...       |
     * ---------------------------------------------------------------------
     * The fused Apply may then be invoked in place; see ChannelImpl::process_one_task.
     * (A standalone sketch of this fusion follows the listing below.)
     */
    struct CommandBuffer {
        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
        void enqueue(Command cmd);
        bool empty() const {
            return m_commands.empty();
        }
        void flush();

    private:
        ChannelImpl* m_owner;
        std::deque<Command> m_commands;

        using Handle = decltype(m_commands)::iterator;
        // [begin, end)
        using Range = std::array<Handle, 2>;

        // Launch commands in the range [m_commands.begin(), pos)
        void flush(Handle pos);
        // Select the flush position for an incoming cmd
        Handle flush_pos_for(const Command& cmd);
        // Fuse a Del command into a suitable ApplyOp
        bool fuse_del(const Del& cmd);
        // Returns the last position at which dest is used within range; if dest is not used, returns range[1]
        Handle find_last_usage(TensorInfo* dest, Range range);
        // Returns the position at which dest is produced; if not found, returns range[1]
        Handle find_produce(TensorInfo* dest, Range range);
    } m_buffer;
    //! Configures how eagerly errors are raised when an op is invoked.
    //! level 2: both device-side and user-side errors are async;
    //! level 1: user-side errors are sync;
    //! level 0: both are sync.
    int m_async_level = 2;
    struct Scope {
        std::string name;
        std::unordered_map<std::string, std::unique_ptr<Scope>> children;
        size_t version = 0;
        size_t parent_version = 0;
        size_t tensor_count = 0;
        Scope* active_child = nullptr;
        Scope* parent = nullptr;

        Scope* enter(std::string name) {
            auto& child = children[name];
            if (!child) {
                child = std::make_unique<Scope>();
                child->name = name;
                child->parent = this;
            }
            if (version != child->parent_version) {
                child->version = 0;
                child->parent_version = version;
            } else {
                child->version++;
            }
            child->tensor_count = 0;
            return active_child = child.get();
        }

        Scope* exit(std::string name) {
            mgb_assert(this->name == name, "scope name mismatch");
            parent->active_child = nullptr;
            return parent;
        }
    };

    class ScopeManager {
    private:
        Scope m_root;
        Scope* m_current_scope = &m_root;

    public:
        // RAII guard: pushes a scope on construction and pops it on destruction
        // (a standalone sketch of this pattern follows the listing below).
        class ScopeGuard {
        private:
            ScopeManager* m_manager;
            std::string m_name;

        public:
            ScopeGuard(ScopeManager* manager, std::string name)
                    : m_manager{manager}, m_name{name} {
                m_manager->push(m_name);
            }
            ~ScopeGuard() {
                m_manager->pop(m_name);
            }
        };

        void push(std::string name) {
            m_current_scope = m_current_scope->enter(name);
        }
        void pop(std::string name) {
            m_current_scope = m_current_scope->exit(name);
        }

        std::string next_tensor_name() {
            std::string builder;
            Scope* scope = &m_root;
            while (true) {
                builder.append(scope->name);
                if (scope->version != 0) {
                    builder.append(ssprintf("(%ld)", scope->version));
                }
                if (scope != &m_root) {
                    builder.append(".");
                }
                if (scope->active_child == nullptr) {
                    builder.append(ssprintf(":%%%ld", scope->tensor_count++));
                    break;
                } else {
                    scope = scope->active_child;
                }
            }
            return builder;
        }
    };
    struct State {
        std::thread::id tid;
        OptionManager options;
    };

    struct ChannelState : State {
        ScopeManager scopes;
    };

    struct WorkerState : State {};

    ChannelState m_channel_state;
    WorkerState m_worker_state;
    /*!
     * \brief A framework for dynamic sublinear memory optimization
     *
     * Note: The main idea is that during training, if memory usage exceeds
     * the threshold, some tensors are selected and evicted until memory
     * usage drops below the threshold.
     * (Sketches of this greedy eviction loop and of the DSU bookkeeping
     * follow the listing below.)
     */
    struct DynamicSublinear {
        /*!
         * \brief find an available tensor with the largest evaluation function
         *
         * Note: An available tensor must satisfy: (1) it has a computing path,
         * (2) it is in memory, (3) it is not pinned. Evaluation function refers to:
         * @see: TensorInfo::eval_func.
         *
         * \return the pointer to the best tensor; nullptr is returned if no
         * available tensor is found
         */
        TensorInfo* find_best_tensor();

        /*!
         * \brief estimate the cost of recomputing tensor ptr
         *
         * Note: We define the cost as the sum of the costs of the evicted
         * components in which the neighbors of ptr are located.
         */
        double estimate_neighbor_cost(TensorInfo* ptr);

        /*!
         * \brief update the last used time of the tensor ptr
         */
        void update_used_time(TensorInfo* ptr);

        /*!
         * \brief merge the two specified sets (the set containing element x
         * and the set containing element y)
         */
        void merge(std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y);

        /*!
         * \brief return the representative of the set that contains the
         * element x
         */
        std::shared_ptr<DsuNode> find_father(std::shared_ptr<DsuNode>& x);

        /*!
         * \brief update the DSU after recomputing tensor ptr
         *
         * Delete ptr from the set where ptr is located. Since the DSU does not
         * support this operation, we instead reset the DSU father of ptr and
         * subtract the recomputation cost of ptr from the cost of the original
         * set.
         */
        void update_dsu_after_recompute(TensorInfo* ptr);

        /*!
         * \brief update the DSU after evicting tensor ptr
         *
         * Check the neighbors of ptr, that is, its input and output tensors,
         * and if they are evicted, merge their respective sets.
         */
        void update_dsu_after_evict(TensorInfo* ptr);

        /*!
         * \brief pin the tensors in vec
         */
        void pin(const SmallVector<TensorInfo*>& vec);

        /*!
         * \brief unpin the tensors in vec
         */
        void unpin(const SmallVector<TensorInfo*>& vec);

        /*!
         * \brief add the tensor to the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * this does nothing.
         */
        void insert_candidate(TensorInfo* ptr);

        /*!
         * \brief erase the tensor from the candidate set
         *
         * If the size of the tensor does not exceed the minimum threshold,
         * this does nothing.
         */
        void erase_candidate(TensorInfo* ptr);

        //! estimated current time, used to reduce the overhead of the timer
        double estimate_timestamp = 0;

        //! the comp node on which dynamic sublinear memory optimization works
        CompNode comp_node;

        //! stores all tensors that may be evicted
        std::unordered_set<TensorInfo*> candidates;

        //! whether the warning message has been printed
        bool warn_printed = false;

        bool is_bad_op(std::string op_name) {
            return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) !=
                   op_blacklist.end();
        }

        std::vector<std::string> op_blacklist = {
                "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
                "GaussianRNG"};
    } m_dtr;
    //! automatically evict an optimal tensor
    void auto_evict();

    // assert the thread id when calling get_xxx_state to avoid misuse
    ChannelState& get_channel_state();
    WorkerState& get_worker_state();
};

}  // namespace mgb::imperative::interpreter::intl
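
The Del fusion described in the CommandBuffer comment can be illustrated with a small standalone sketch. The Apply and Del structs and the fuse_del helper below are simplified stand-ins written only for this example, not the real Command variants or CommandBuffer::fuse_del; they only show the idea of folding a pending Del into the latest buffered Apply that consumed the handle.

    #include <algorithm>
    #include <cstdio>
    #include <deque>
    #include <vector>

    // Simplified stand-in for an ApplyOp command in the buffered window.
    struct Apply {
        std::vector<int> inputs;   // handles read by the op
        std::vector<int> outputs;  // handles produced by the op
        std::vector<int> dels;     // handles that may be freed right after execution
    };

    // Simplified stand-in for a Del command.
    struct Del {
        int handle;
    };

    // Fold the Del into the most recent Apply that consumes the handle, so the
    // worker can release (or reuse in place) the input storage right after that
    // Apply runs. Returns false if no suitable Apply is buffered.
    bool fuse_del(std::deque<Apply>& window, const Del& del) {
        for (auto it = window.rbegin(); it != window.rend(); ++it) {
            const auto& in = it->inputs;
            if (std::find(in.begin(), in.end(), del.handle) != in.end()) {
                it->dels.push_back(del.handle);
                return true;
            }
        }
        return false;  // caller would keep the Del as a separate command
    }

    int main() {
        std::deque<Apply> window;
        window.push_back(Apply{{0, 1}, {2, 3}, {}});
        fuse_del(window, Del{0});
        fuse_del(window, Del{1});
        std::printf("dels fused into Apply: %zu\n", window.front().dels.size());  // prints 2
    }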
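push_scope and pop_scope must stay balanced even across early returns and exceptions, which is what ScopeManager::ScopeGuard provides via RAII. Below is a minimal standalone sketch of the same pattern; MiniScopeManager is a made-up stand-in for the real ScopeManager.

    #include <cstdio>
    #include <cstdlib>
    #include <string>
    #include <vector>

    // Made-up stand-in for ScopeManager: keeps a stack of scope names.
    struct MiniScopeManager {
        std::vector<std::string> stack;
        void push(const std::string& name) { stack.push_back(name); }
        void pop(const std::string& name) {
            // mirrors the "scope name mismatch" assertion in Scope::exit
            if (stack.empty() || stack.back() != name) std::abort();
            stack.pop_back();
        }
    };

    // RAII guard: pushes on construction, pops on destruction, so the scope is
    // popped even if the guarded code returns early or throws.
    class ScopeGuard {
        MiniScopeManager* m_manager;
        std::string m_name;

    public:
        ScopeGuard(MiniScopeManager* manager, std::string name)
                : m_manager{manager}, m_name{std::move(name)} {
            m_manager->push(m_name);
        }
        ~ScopeGuard() { m_manager->pop(m_name); }
    };

    int main() {
        MiniScopeManager scopes;
        {
            ScopeGuard outer{&scopes, "forward"};
            ScopeGuard inner{&scopes, "conv1"};
            std::printf("depth inside: %zu\n", scopes.stack.size());  // 2
        }
        std::printf("depth outside: %zu\n", scopes.stack.size());  // 0
    }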
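auto_evict together with DynamicSublinear::find_best_tensor realizes the policy stated in the DynamicSublinear comment: while memory usage is over the threshold, keep evicting the best candidate. The loop below is a hedged sketch of that greedy policy; the Candidate type and the eval score are invented for illustration and do not reproduce the real TensorInfo::eval_func or ChannelImpl::auto_evict.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Made-up candidate record: a tensor that is in memory, has a compute path
    // to regenerate it, and is not pinned.
    struct Candidate {
        std::size_t bytes;
        double recompute_cost;
        double last_used;
        bool evicted = false;
    };

    // Made-up evaluation function: prefer large, cheap-to-recompute, cold tensors.
    double eval(const Candidate& c, double now) {
        return static_cast<double>(c.bytes) * (now - c.last_used) / (c.recompute_cost + 1.0);
    }

    // Greedy eviction: drop the highest-scoring candidate until usage fits the budget.
    void auto_evict(std::vector<Candidate>& candidates, std::size_t& used,
                    std::size_t budget, double now) {
        while (used > budget) {
            Candidate* best = nullptr;
            for (auto& c : candidates) {
                if (c.evicted) continue;
                if (!best || eval(c, now) > eval(*best, now)) best = &c;
            }
            if (!best) break;  // nothing left to evict
            best->evicted = true;
            used -= best->bytes;
        }
    }

    int main() {
        std::vector<Candidate> candidates{{512, 1.0, 0.0}, {256, 8.0, 2.0}, {1024, 2.0, 1.0}};
        std::size_t used = 512 + 256 + 1024;
        auto_evict(candidates, used, 900, /*now=*/10.0);
        std::printf("used after eviction: %zu\n", used);  // 768 once the 1024-byte tensor is evicted
    }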
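The update_dsu_* comments describe a disjoint-set union (DSU) over evicted tensors whose sets accumulate recomputation cost. DsuNode itself is not defined in this header, so the sketch below uses a hypothetical shared_ptr-based node with path compression and cost accumulation, purely to illustrate what merge and find_father are expected to do.

    #include <cstdio>
    #include <memory>

    // Hypothetical DSU node: each set representative accumulates the total
    // recomputation cost of the evicted tensors it covers.
    struct DsuNode {
        std::shared_ptr<DsuNode> parent;  // null parent => this node is the representative
        double cost = 0;
    };

    // Find the representative of x's set, compressing the path along the way.
    std::shared_ptr<DsuNode> find_father(const std::shared_ptr<DsuNode>& x) {
        if (!x->parent) return x;
        auto root = find_father(x->parent);
        x->parent = root;  // path compression
        return root;
    }

    // Merge the set containing x into the set containing y, summing their costs.
    void merge(const std::shared_ptr<DsuNode>& x, const std::shared_ptr<DsuNode>& y) {
        auto fx = find_father(x), fy = find_father(y);
        if (fx == fy) return;
        fy->cost += fx->cost;
        fx->parent = fy;
    }

    int main() {
        auto a = std::make_shared<DsuNode>();
        auto b = std::make_shared<DsuNode>();
        a->cost = 2.5;
        b->cost = 1.5;
        merge(a, b);
        std::printf("merged set cost: %.1f\n", find_father(a)->cost);  // 4.0
    }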

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU release. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.