/**
 * \file imperative/src/impl/interpreter/interpreter_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./interpreter_impl.h"

#include "range/v3/all.hpp"

#include "megbrain/common.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../event_pool.h"
#include "../op_trait.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
using namespace interpreter::intl;

#define RECORD_EVENT(type, ...) \
    if (Profiler::is_profiling()) { \
        Profiler::record<type>(type{__VA_ARGS__}); \
    }

namespace {
    auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
        SmallVector<uint64_t> tid;
        for (auto* ptinfo: tinfo) {
            tid.push_back(ptinfo->id);
        }
        return tid;
    };
}

namespace mgb {
    using namespace profiler;
}

#ifdef __GNUG__
namespace mgb {
/**
 * USAGE
 *
 * header:
 *   namespace mgb { bool imperative_log_profile(const char* message); }
 *
 * code:
 *   mgb::imperative_log_profile("MY MESSAGE");
 *
 **/
__attribute__((visibility("default")))
void imperative_log_profile_begin(const char* message) {
    RECORD_EVENT(CustomEvent, std::string{message});
}

__attribute__((visibility("default")))
void imperative_log_profile_end(const char* message) {
    RECORD_EVENT(CustomFinishEvent, std::string{message});
}

__attribute__((visibility("default")))
void imperative_log_profile(const char* message){
    imperative_log_profile_begin(message);
    imperative_log_profile_end(message);
}
}
#endif
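
// Example (hypothetical caller code, not part of this file): the begin/end
// helpers above bracket a custom profiler scope, e.g.
//     mgb::imperative_log_profile_begin("prepare inputs");
//     /* ... host-side work ... */
//     mgb::imperative_log_profile_end("prepare inputs");
// while imperative_log_profile() records a zero-length marker event.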

std::thread::id ChannelImpl::get_worker_tid() {
    return m_worker_state.tid;
}

ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
    assert_in_channel();
    return m_channel_state;
}

ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
    assert_in_worker();
    return m_worker_state;
}

// Do not use m_xxx_state directly
#define m_channel_state
#define m_worker_state

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    return std::make_unique<ChannelImpl>();
}

Interpreter& Interpreter::inst() {
    static InterpreterImpl inst_;
    return inst_;
}

Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.push("Put");
    auto info = put_impl(value, no_cache);
    state.scopes.pop("Put");
    return info;
}

TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
    auto info = alloc();
    init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()});
    info->h_value = value;
    m_buffer.enqueue(Put{info, value, no_cache});
    if (m_async_level == 0) {
        sync();
        info->desc.comp_node.sync();
    }
    return info;
}

Handle ChannelImpl::put(const DeviceTensorND& data) {
    auto& state = get_channel_state();
    mgb_assert(check_available(), "Channel already closed");
    state.scopes.push("Put");
    auto info = alloc();
    RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandEvent::Put);
    init(info, {data.layout(), data.comp_node()});
    info->ptr = Tensor::make(data);
    RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, data.raw_ptr());
    info->status = TensorInfo::Produced;
    RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandFinishEvent::Put);
    state.scopes.pop("Put");
    return info;
}

void ChannelImpl::del(Handle handle) {
    if (!check_available()){
        return;
    }
    mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
    auto* info = reinterpret_cast<TensorInfo*>(handle);
    m_valid_handle.erase(handle);
    m_buffer.enqueue(Del{info});
}

void ChannelImpl::swap_in(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapIn{info});
    }
}

void ChannelImpl::swap_out(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_swap) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(SwapOut{info});
    }
}

void ChannelImpl::drop(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    if (state.options.enable_drop) {
        mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                "invalid handle: %p", handle);
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        m_buffer.enqueue(Drop{info});
    }
}
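
// Runs a host-computable op eagerly on the default CPU comp node, using input
// values that are already available on the host (fetched device value or cached
// h_value), and registers the outputs through put_impl instead of enqueueing a
// kernel on the worker.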
void ChannelImpl::dispatch_default_cpu(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    RECORD_EVENT(ShapeInferEvent, validated);

    SmallVector<DeviceTensorND> input_tensornds;
    input_tensornds.reserve(input_descs.size());
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
                output_cn = input_cn;
            } else {
                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
            }
            if (info->ptr && info->ptr->try_get_value()) {
                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
            } else {
                // It's OK for SwapOut. We assign h_value before dropping ptr
                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
    }

    outputs->reserve(output_descs.size());
    SmallVector<DeviceTensorND> output_tensornds;
    output_tensornds.reserve(output_descs.size());
    for (auto&& desc : output_descs) {
        // TODO: may conflict with condtake, which needs to alloc inside
        mgb_assert(!desc.layout.is_empty());
        // use HostTensorND alloc_host for cuda pinned memory
        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
    }

    uint64_t op_id = Profiler::next_id();

    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);

    SmallVector<TensorInfo*> output_infos;
    output_infos.reserve(output_descs.size());
    for (auto&& tensornd : output_tensornds) {
        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
                .proxy_to_comp_node(output_cn);
        // use `put` for consistency
        auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
        mgb_assert(info->desc.layout.ndim != 0);
        output_infos.push_back(info);
        outputs->push_back(info);
    }

    auto op_info_getter = [op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    RECORD_EVENT(OpDispatchEvent, op_id, op->trait()->name, op_info_getter, tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
}
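
// Allocates TensorInfos for the (possibly not yet validated) output descs and
// enqueues an ApplyOp command for asynchronous execution on the worker thread.
// async_level == 1 syncs only when shape inference was not validated;
// async_level == 0 always syncs and additionally waits on each output comp node
// to surface device errors.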
void ChannelImpl::dispatch_kernel(
        std::shared_ptr<OpDef> op,
        const SmallVector<TensorInfo*>& input_infos,
        const SmallVector<LogicalTensorDesc>& input_descs,
        SmallVector<Handle>* outputs) {
    auto& state = get_channel_state();
    auto& options = state.options;
    auto name = op->trait()->make_name(*op);
    state.scopes.push(name);

    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
    RECORD_EVENT(ShapeInferEvent, validated);

    ApplyOp cmd{Profiler::next_id(), std::move(op)};
    cmd.inputs = std::move(input_infos);
    cmd.outputs.reserve(output_descs.size());
    outputs->reserve(output_descs.size());
    for (int i = 0; i < output_descs.size(); ++i) {
        auto&& desc = output_descs[i];
        auto info = alloc();
        init(info, desc);
        // make sure desc's value is consistent with h_value
        if (!info->desc.value.empty()) {
            info->h_value = HostTensorND::make_proxy(desc.value)
                    .proxy_to_comp_node(desc.comp_node);
        }
        cmd.outputs.push_back(info);
        outputs->push_back(info);
    }
    auto op_info_getter = [op=cmd.op]{
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value]: props) {
            op_info[key] = value;
        }
        return op_info;
    };
    RECORD_EVENT(OpDispatchEvent, cmd.id, cmd.op->trait()->name, op_info_getter, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
    m_buffer.enqueue(std::move(cmd));
    if (!validated && options.async_level == 1) {
        sync();
    } else if (options.async_level == 0) {
        sync();
        // check device error
        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
    state.scopes.pop(name);
}

SmallVector<Handle> ChannelImpl::apply_op(
        std::shared_ptr<OpDef> op,
        const SmallVector<Handle>& inputs) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    for (auto i : inputs) {
        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                "invalid handle: %p", i);
    }
    SmallVector<TensorInfo*> input_infos;
    input_infos.reserve(inputs.size());
    SmallVector<LogicalTensorDesc> input_descs;
    input_descs.reserve(inputs.size());
    {
        MGB_LOCK_GUARD(m_mutex);
        for (auto i : inputs) {
            auto info = reinterpret_cast<TensorInfo*>(i);
            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
            input_infos.push_back(info);
            input_descs.push_back(info->desc);
        }
    }

    SmallVector<Handle> outputs;
    DispatchMode dispatch_mode = state.options.enable_host_compute
            ? OpDef::decide_dispatch_mode(*op, input_descs)
            : DispatchMode::KERNEL;
    switch (dispatch_mode) {
        case DEFAULT_CPU: {
            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
            break;
        }
        case KERNEL: {
            dispatch_kernel(op, input_infos, input_descs, &outputs);
            break;
        }
    }
    return outputs;
}

HostTensorND ChannelImpl::get_value(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    // do not use info->value_fetched, it's unsafe
    mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
    return wait_tensor(info, TensorProp::HostValue)->get_value();
}

TensorShape ChannelImpl::get_shape(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    if (info->desc.layout.ndim != 0) {
        return info->desc.layout;
    }
    TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
    mgb_assert(ret.ndim != 0);
    return ret;
}

DType ChannelImpl::get_dtype(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
    auto ret = info->desc.layout.dtype;
    mgb_assert(ret.valid());
    return ret;
}

CompNode ChannelImpl::get_device(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
    auto ret = info->desc.comp_node;
    mgb_assert(ret.valid());
    return ret;
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
            "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
    return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
}

void ChannelImpl::sync() {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    m_buffer.flush();
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}

void ChannelImpl::close() {
    if (!check_available()) {
        return;
    }
    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
    for (auto* handle: valid_handles) {
        del(handle);
    }
    mgb_assert(m_valid_handle.empty());
    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
    sync();
    m_closed = true;
}

size_t ChannelImpl::get_option(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    return state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.options.set_option(name, value);
    m_buffer.enqueue(SetOption{name, value});
}

TensorInfo* ChannelImpl::alloc() {
    auto& state = get_channel_state();
    auto info = [this]{
        MGB_LOCK_GUARD(m_mutex);
        return m_pool.alloc();
    }();
    info->id = Profiler::next_id();
    if (Profiler::is_profiling()) {
        info->name = state.scopes.next_tensor_name();
    }
    return info;
}

void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
    m_valid_handle.insert(info);
    RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
    info->status = TensorInfo::Allocated;
    info->desc = std::move(desc);
}
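
// Drops the device value of a recomputable tensor: the blob is released, but
// the producer ComputePath is kept so regenerate() can rebuild the value later.
// A tensor without a producer (e.g. created by put) cannot be dropped.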
void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
    if (!ptr->producer) {
        if (user) {
            mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", ptr);
        }
        return;
    }
    if (ptr->evict_type != EvictType::NONE) {
        return;
    }
    ptr->evict_type = EvictType::DROP;
    ptr->status = TensorInfo::Dropped;
    release_tensor(ptr);
}

void ChannelImpl::free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    if (state.options.enable_dtr_auto_drop) {
        // Evicting a tensor, rather than freeing it, can avoid pinning
        // potentially exploding amounts of memory and allow us to save
        // more memory.
        ptr->allow_delete = true;
        if (!ptr->ref_cnt) {
            recursive_free(ptr);
        } else {
            do_drop(ptr);
        }
    } else {
        real_free(ptr);
    }
}

void ChannelImpl::recursive_free(TensorInfo* ptr) {
    RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandEvent::RecFree);
    SmallVector<TensorInfo*> inps;
    if (ptr->producer) {
        for (auto i : ptr->producer->inputs) {
            if (i && --i->ref_cnt == 0) {
                inps.push_back(i);
            }
        }
    }
    real_free(ptr);
    for (auto i : inps) {
        if (i->allow_delete) {
            recursive_free(i);
        }
    }
    RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandFinishEvent::RecFree);
}

void ChannelImpl::real_free(TensorInfo* ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(ptr);
    }
    detach_users(ptr);
    ptr->detach_producer();
    bool has_value = ptr->ptr != nullptr;
    if (has_value) {
        RECORD_EVENT(TensorReleaseEvent, ptr->id);
    }
    RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
    ptr->status = TensorInfo::Deleted;
    m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}

ChannelImpl::~ChannelImpl() {
    close();
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
    auto& state = get_worker_state();
    std::unique_lock<std::mutex> lock{m_mutex, std::defer_lock};
    if (notice) {
        lock.lock();
    }
    m_dtr.update_used_time(dest);
    RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->dev_tensor().raw_ptr());
    // update tensor desc for static infer
    dest->desc.layout = ptr->layout();
    dest->desc.comp_node = ptr->comp_node();
    dest->memory = ptr->blob()->size();
    dest->ptr = std::move(ptr);
    dest->evict_type = EvictType::NONE;
    dest->status = TensorInfo::Produced;
    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.insert_candidate(dest);
    }
    if (notice) {
        notify_tensor_unsafe(dest);
    }
}

void ChannelImpl::release_tensor(TensorInfo* dest) {
    RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    dest->ptr.reset();
}

void ChannelImpl::regenerate(TensorInfo* dest) {
    RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
    if (dest->evict_type == EvictType::DROP) {
        recompute(dest->producer);
    } else if (dest->evict_type == EvictType::SWAP) {
        produce_tensor(dest, Tensor::make(dest->h_value));
    }
    RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
}
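
// Worker-side execution of an ApplyOp command: pins and regenerates dropped
// inputs, optionally triggers DTR auto-eviction, runs the kernel via
// OpDef::apply_on_physical_tensor, fills the output TensorInfos and records
// profiling events around every step.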
void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
    using namespace ranges;
    using namespace ranges::views;

    auto& state = get_worker_state();
    bool profiling_device = Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
    uint64_t apply_id = cmd.id;
    SmallVector<TensorPtr> tensor_inputs;
    if (state.options.enable_dtr_auto_drop) {
        m_dtr.pin(cmd.inputs);
    }
    for (auto i : cmd.inputs) {
        if (!i->ptr && i->evict_type != EvictType::NONE) {
            regenerate(i);
        }
        m_dtr.update_used_time(i);
    }
    tensor_inputs.reserve(cmd.inputs.size());
    // refcnt == 1, owners: [TensorInfo::ptr]
    for (auto i : cmd.inputs) {
        mgb_assert(i->ptr, "Invalid input tensor ptr!");
        // refcnt ++, owners: [i->ptr, tensor_inputs]
        tensor_inputs.push_back(i->ptr);
    }
    RECORD_EVENT(OpExecuteEvent, apply_id);
    // Begin profiling operator
    SmallVector<std::pair<CompNode, uint64_t>> kernels;
    if (profiling_device) {
        // Collecting devices
        SmallVector<CompNode> devices;
        for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
            if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
                devices.push_back(i->desc.comp_node);
                kernels.push_back({i->desc.comp_node, Profiler::next_id()});
            }
        }
    }
    for (auto* input: cmd.inputs) {
        auto input_id = input->id;
        RECORD_EVENT(OpInputEvent, input_id);
        RECORD_EVENT(TensorUsageEvent, input_id);
        RECORD_EVENT(OpInputFinishEvent, input_id);
    }
    // Fused by command buffer. @see: CommandBuffer::fuse_del
    // Now if dest is inplacable, its refcnt would be decreased to 1 and owned by tensor_inputs after Del.
    // Note: for exprs like 'y = x op x', inplace is unsupported yet, but Del would also be fused.
    for (auto* del : cmd.dels) {
        // refcnt --, owners: [tensor_inputs]
        // if it's decreased to 1, it would be detected at @see: proxy_graph_detail::apply_on_physical_tensor
        uint64_t del_id = del->id;
        RECORD_EVENT(OpDelEvent, del_id);
        free(del);
        RECORD_EVENT(OpDelFinishEvent, del_id);
    }
    // Before wait
    // TODO: split operator wait and execute so that OpWait could be correctly recorded.
    // Before execute
    for (auto&& [device, kernel_id]: kernels) {
        RECORD_EVENT(KernelExecuteEvent, apply_id, kernel_id, Timer::record_event(device));
    }
    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
        auto_evict();
    }
    // Apply op
    // Here std::move is REQUIRED for removing duplicated references.
    auto tensor_outputs = OpDef::apply_on_physical_tensor(
        *cmd.op, std::move(tensor_inputs));
    // After execute
    for (auto&& [device, kernel_id]: kernels) {
        RECORD_EVENT(KernelExecuteFinishEvent, apply_id, kernel_id, Timer::record_event(device));
    }
    // End profiling operator
    mgb_assert(tensor_outputs.size() == cmd.outputs.size());
    for (size_t i = 0; i < tensor_outputs.size(); ++i) {
        auto output = cmd.outputs[i];
        if (output == nullptr) {
            RECORD_EVENT(OpOutputEvent, 0);
            RECORD_EVENT(OpOutputFinishEvent, 0);
        } else if (output->ptr != nullptr) {
            RECORD_EVENT(OpOutputEvent, output->id);
            RECORD_EVENT(OpOutputFinishEvent, output->id);
        } else {
            RECORD_EVENT(OpOutputEvent, output->id);
            produce_tensor(output, tensor_outputs[i]);
            RECORD_EVENT(OpOutputFinishEvent, output->id);
            sample_on_device(output->desc.comp_node, false);
        }
    }
    if (state.options.enable_dtr_auto_drop) {
        double estimate_compute_time = 0;
        for (auto i : cmd.inputs) {
            estimate_compute_time += i->memory;
        }
        for (auto i : tensor_outputs) {
            estimate_compute_time += i->blob()->size();
        }
        m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
        for (auto i : cmd.outputs) {
            if (i != nullptr) {
                i->compute_time = estimate_compute_time;
            }
        }
        m_dtr.unpin(cmd.inputs);
    }
    RECORD_EVENT(OpExecuteFinishEvent, apply_id);
    // End profiling operator
}

void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
    auto& state = get_worker_state();
    do_apply_op(ApplyOp{path->id, path->op, path->inputs, path->outputs, {}});
    for (size_t i = 0; i < path->outputs.size(); i++) {
        auto&& o = path->outputs[i];
        if (o) {
            o->recompute_times++;
            if (!o->ptr) {
                if (state.options.enable_dtr_auto_drop) {
                    m_dtr.update_dsu_after_recompute(o);
                }
            }
        }
    }
}
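
// DTR eviction loop: while used memory on the tracked comp node exceeds
// dtr_eviction_threshold, pick the cheapest evictable candidate via
// find_best_tensor() and drop it; warns once if nothing can be evicted.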
void ChannelImpl::auto_evict() {
    auto& state = get_worker_state();
    if (!m_dtr.comp_node.valid()) {
        return;
    }
    size_t current_memory = m_dtr.comp_node.get_used_memory();
    while (current_memory > state.options.dtr_eviction_threshold) {
        RECORD_EVENT(AutoEvictEvent);
        sample_on_device(m_dtr.comp_node, false);
        auto best = m_dtr.find_best_tensor();
        if (!best) {
            if (!m_dtr.warn_printed) {
                m_dtr.warn_printed = true;
                mgb_log_warn("No tensors on %s can be evicted automatically "
                             "when memory usage is %.0lfMB. Maybe memory "
                             "budget is too small.",
                             m_dtr.comp_node.to_string().c_str(),
                             current_memory / 1024.0 / 1024.0);
            }
            break;
        }
        if (best->ptr.unique() && best->ptr->blob().unique()) {
            current_memory -= best->memory;
        }
        do_drop(best);
        if (best->evict_type == EvictType::DROP) {
            m_dtr.update_dsu_after_evict(best);
        }
        sample_on_device(m_dtr.comp_node, false);
        RECORD_EVENT(AutoEvictFinishEvent);
    }
}

void ChannelImpl::detach_users(TensorInfo* dest) {
    SmallVector<TensorInfo::ComputePath*> users = dest->users;
    for (auto* user: users) {
        SmallVector<TensorInfo*> outputs = user->outputs;
        SmallVector<TensorInfo*> inputs = user->inputs;
        for (auto* output: outputs) {
            // When a `ComputePath` is detached from its input,
            // there is no need to reserve it,
            // so we detach all outputs of this path
            // to decrease their `ref_cnt` to zero.
            if (output == nullptr) {
                continue;
            }
            regenerate(output);
            output->detach_producer();
            for (auto* input: inputs) {
                input->ref_cnt--;
            }
        }
        // now user is dead
    }
    mgb_assert(dest->users.empty(), "ComputePath leaking");
}

bool ChannelImpl::check_available() {
    return !m_closed;
}
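
// Blocks the caller (channel thread) until the worker has produced `info`.
// For HostValue requests a GetValue command is enqueued once and the wait
// completes when the value has been fetched to the host; worker exceptions
// are re-thrown via check_worker_exc_unsafe().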
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    m_buffer.flush();
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
    bool require_host = prop == TensorProp::HostValue;
    bool value_fetching = false;
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        if (require_host) {
            if (info->ptr && info->ptr->value_fetched()) {
                return true;
            }
            if (!value_fetching) {
                m_buffer.enqueue(GetValue{info});
                value_fetching = true;
            }
            return false;
        } else {
            return static_cast<bool>(info->ptr);
        }
    });
    RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, m_waitee == nullptr);
    if (m_waitee != nullptr) {
        mgb_assert(m_waitee == info, "waitee mismatch");
        m_waitee = nullptr;
    }
    return info->ptr;
}

void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        m_waitee = nullptr;
        RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
    }
}

std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
    std::unordered_set<TensorInfo*> valid_tensors;
    for (auto* handle: m_valid_handle) {
        auto* info = reinterpret_cast<TensorInfo*>(handle);
        valid_tensors.insert(info);
        // TODO: valid_tensors.insert({info, info->status});
    }
    return valid_tensors;
}

void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
    using namespace ranges;
    using namespace ranges::views;

    auto& state = get_worker_state();
    auto& options = state.options;
    // TODO: remove std::visit to support OSX 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, Put>) {
            RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Put);
            auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value);
            produce_tensor(cmd.dest, std::move(value));
            RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, ApplyOp>) {
            do_apply_op(cmd);
            for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                auto output = cmd.outputs[i];
                if (output == nullptr) {
                    continue;
                }
                if (state.options.enable_dtr_auto_drop) {
                    output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
                }
            }
            if (state.options.enable_drop && state.options.record_computing_path) {
                auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
                    auto& input = std::get<0>(tuple2);
                    auto& output = std::get<1>(tuple2);
                    if (!input->ptr || !output->ptr) {
                        return false;
                    }
                    return input->ptr->blob()->storage() == output->ptr->blob()->storage();
                };
                // FIXME: do not use opname as identifier
                auto get_name = [](const OpDef& opdef) {
                    if (auto attr = opdef.try_cast_final<OprAttr>()) {
                        return attr->type.c_str();
                    }
                    return opdef.dyn_typeinfo()->name;
                };
                auto is_cross_cn = [comp_node=m_dtr.comp_node](TensorInfo* info){
                    return info->desc.comp_node != comp_node;
                };
                bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
                bool inplace = any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);
                if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
                    TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs);
                    size_t detach_cnt = 0;
                    for (auto output : cmd.outputs) {
                        if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                            output->detach_producer();
                            detach_cnt++;
                        }
                    }
                    for (auto input : cmd.inputs) {
                        input->ref_cnt -= detach_cnt;
                    }
                }
            }
        } else if constexpr (std::is_same_v<T, Del>) {
            RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Del);
            CompNode device = cmd.dest->desc.comp_node;
            uint64_t tensor_id = cmd.dest->id;
            free(cmd.dest);
            RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandFinishEvent::Del);
            sample_on_device(device, false);
        } else if constexpr (std::is_same_v<T, GetValue>) {
            imperative_log_profile_begin("GetValue");
            if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
                regenerate(cmd.dest);
            }
            mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
        } else if constexpr (std::is_same_v<T, SwapIn>) {
            RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapIn);
            produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
            RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapIn);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, SwapOut>) {
            RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapOut);
            cmd.dest->h_value = cmd.dest->ptr->get_value();
            if (cmd.dest->evict_type == EvictType::NONE) {
                cmd.dest->evict_type = EvictType::SWAP;
                cmd.dest->status = TensorInfo::Swapped;
                release_tensor(cmd.dest);
            }
            RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapOut);
            sample_on_device(cmd.dest->desc.comp_node, false);
        } else if constexpr (std::is_same_v<T, Drop>) {
            RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Drop);
            do_drop(cmd.dest, true);
            RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Drop);
        } else if constexpr (std::is_same_v<T, SetOption>) {
            options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            RECORD_EVENT(StartProfileEvent);
            CompNode::sync_all();
            for (auto* info: cmd.capture_tensors) {
                RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
                if (info->status == TensorInfo::Produced) {
                    // TODO: handle swap/drop
                    RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
                }
            }
            CompNode::foreach([&](CompNode device){
                if (Profiler::get_option("sample_rate", 0)) {
                    sample_on_device(device, true);
                }
            });
            RECORD_EVENT(StartProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
            RECORD_EVENT(StopProfileEvent);
            for (auto* info: cmd.escape_tensors) {
                bool has_value = info->status == TensorInfo::Produced;
                if (has_value) {
                    RECORD_EVENT(TensorReleaseEvent, info->id);
                }
                RECORD_EVENT(TensorEraseEvent, info->id);
            }
            CompNode::foreach([&](CompNode device){
                if (Profiler::get_option("sample_rate", 0)) {
                    sample_on_device(device, true);
                }
            });
            RECORD_EVENT(StopProfileFinishEvent);
        } else if constexpr (std::is_same_v<T, PushScope>) {
            RECORD_EVENT(ScopeEvent, cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
            RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
        } else {
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
        if (!options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
        try {
            cmd_visitor(cmd);
        } catch (...) {
            MGB_LOCK_GUARD(m_mutex);
            if constexpr (std::is_same_v<T, ApplyOp>) {
                for (auto oup : cmd.outputs) {
                    oup->invalid = true;
                }
            } else if constexpr (std::is_same_v<T, Put>) {
                cmd.dest->invalid = true;
            }
            m_worker_exc = std::current_exception();
            RECORD_EVENT(WorkerExceptionEvent);
            if (m_waitee) {
                notify_tensor_unsafe(m_waitee);
            }
        }
    }, icmd.second);
}

void ChannelImpl::check_worker_exc_unsafe() {
    if (m_worker_exc) {
        // for reusing interpreter_for_py after some exception tests
        m_waitee = nullptr;
        std::exception_ptr exc;
        std::swap(exc, m_worker_exc);
        std::rethrow_exception(exc);
    }
}

void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
        return;
    }
    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
    m_commands.push_back(std::move(cmd));
    auto flush_pos = flush_pos_for(m_commands.back());
    flush(flush_pos);
}

void ChannelImpl::CommandBuffer::flush() {
    flush(m_commands.end());
}

void ChannelImpl::CommandBuffer::flush(Handle pos) {
    auto& state = m_owner->get_channel_state();
    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
        if (Profiler::is_profiling()) {
            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
        }
        m_owner->m_worker.add_task(IdentifiedCommand{Profiler::next_id(), std::move(*iter)});
    }
    m_commands.erase(m_commands.begin(), pos);
}

auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
    auto& state = m_owner->get_channel_state();
    return std::visit([this, &state](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
        if constexpr (std::is_same_v<T, ApplyOp>) {
            auto* op_type = cmd.op->dyn_typeinfo();
            if (op_type == RemoteRecv::typeinfo() ||
                op_type == RemoteSend::typeinfo() ||
                op_type == CollectiveComm::typeinfo() ||
                op_type == opr::InputCallback::typeinfo() ||
                op_type == opr::OutputCallback::typeinfo()) {
                return m_commands.end();
            }
        } else if constexpr (std::is_same_v<T, GetValue>) {
            return m_commands.end();
        }
        size_t buffer_length = state.options.buffer_length;
        if (m_commands.size() > buffer_length) {
            return m_commands.begin() + (m_commands.size() - buffer_length);
        }
        return m_commands.begin();
    }, cmd);
}

/**
 * 1. Find ApplyOp(dest) in the buffered commands
 * 2. Check whether dest is used by any command after that ApplyOp; if so, return false
 * 3. Otherwise fuse the Del into the ApplyOp and return true
 */
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
    auto* dest = cmd.dest;
    // TODO: eliminate Puts
    auto begin = m_commands.begin(), end = m_commands.end();
    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){
        if (auto* apply = std::get_if<ApplyOp>(&cmd)) {
            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
        }
        return false;
    });
    if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
        return false;
    }
    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
    std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
    return true;
}

auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
        -> Handle {
    auto found = range[1];
    for (auto iter = range[0]; iter != range[1]; ++iter) {
        std::visit([&](const auto& cmd) {
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                if (std::count(cmd.inputs.begin(), cmd.inputs.end(),
                               dest) > 0) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, GetValue>) {
                if (cmd.dest == dest) {
                    found = iter;
                }
            } else if constexpr (std::is_same_v<T, SwapIn> ||
                                 std::is_same_v<T, SwapOut> ||
                                 std::is_same_v<T, Drop>) {
                // TODO: ignore swap-like commands, just remove them from buffer
                if (cmd.dest == dest) {
                    found = iter;
                }
            }
        }, *iter);
    };
    return found;
}

auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
        -> Handle {
    return std::find_if(range[0], range[1], [dest](auto& cmd) {
        return std::visit([dest](const auto& cmd){
            using T = std::decay_t<decltype(cmd)>;
            if constexpr (std::is_same_v<T, ApplyOp>) {
                return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0;
            } else if constexpr (std::is_same_v<T, Put>) {
                return cmd.dest == dest;
            }
            return false;
        }, cmd);
    });
}

void ChannelImpl::start_profile() {
    mgb_assert(check_available(), "Channel already closed");
    auto capture_tensors = collect_valid_tensors();
    if (capture_tensors.size() > 0) {
        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
    }
}

void ChannelImpl::stop_profile() {
    mgb_assert(check_available(), "Channel already closed");
    m_buffer.flush();
    auto escape_tensors = collect_valid_tensors();
    if (escape_tensors.size() > 0) {
        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
    }
}

void ChannelImpl::push_scope(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.push(name);
    RECORD_EVENT(ScopeEvent, name);
    m_buffer.enqueue(PushScope{name});
}

void ChannelImpl::pop_scope(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& state = get_channel_state();
    state.scopes.pop(name);
    RECORD_EVENT(ScopeFinishEvent, name);
    m_buffer.enqueue(PopScope{name});
}

void ChannelImpl::assert_in_channel() {
    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
}

void ChannelImpl::assert_in_worker() {
    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
}

void ChannelImpl::sample_on_device(CompNode device, bool force) {
    if (!force) {
        thread_local int last_sample_id = 0;
        int sample_rate = Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
        if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
            return;
        }
    }
    RECORD_EVENT(SampleDeviceEvent, device);
    auto [total, free] = device.get_mem_status_bytes();
    RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
}

void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->pin();
    }
}

void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) {
    for (auto i : vec) {
        i->unpin();
    }
}

void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
    auto&& dsu_fa = find_father(ptr->dsu_ptr);
    dsu_fa->t -= ptr->compute_time;
    ptr->dsu_ptr->parent.reset();
    ptr->dsu_ptr->t = ptr->compute_time;
}

void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            merge(i->dsu_ptr, ptr->dsu_ptr);
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            merge(ptr->dsu_ptr, i->dsu_ptr);
        }
    }
}

double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
    double cost = 0;
    for (auto i : ptr->producer->inputs) {
        if (i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    for (auto i : ptr->producer->outputs) {
        if (i && i->evict_type == EvictType::DROP) {
            double t = find_father(i->dsu_ptr)->t;
            if (t < i->compute_time) {
                t = i->compute_time;
            }
            cost += t;
        }
    }
    return cost;
}
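
// Chooses the DTR eviction victim: among produced, unpinned candidates that
// still have a recompute path, minimize the eval_func score computed from the
// estimated recompute cost of the dropped neighborhood, the free memory
// adjacent to the tensor's blob, and the current estimated timestamp.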
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() {
    double min_msps = -1;
    TensorInfo* best = nullptr;
    for (auto i : candidates) {
        if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
            double neighbor_cost = estimate_neighbor_cost(i);
            size_t begin_ptr = reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
            auto side_info = i->ptr->comp_node().get_free_left_and_right(begin_ptr, begin_ptr + i->ptr->blob()->size());
            double free_mem = side_info.first + side_info.second;
            double msps = i->eval_func(neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
            if (min_msps < 0 || msps < min_msps) {
                min_msps = msps;
                best = i;
            }
        }
    }
    return best;
}

void ChannelImpl::DynamicSublinear::merge(std::shared_ptr<DsuNode> &x, std::shared_ptr<DsuNode> &y) {
    auto&& f_x = find_father(x);
    auto&& f_y = find_father(y);
    if (f_x.get() == f_y.get()) {
        return;
    }
    f_y->t += f_x->t;
    f_x->parent = f_y;
}

std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(std::shared_ptr<DsuNode>& x) {
    if (x->is_root()) {
        return x;
    } else {
        auto&& fa = find_father(x->parent);
        return x->parent = fa;
    }
}

void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
    candidates.insert(ptr);
    if (!comp_node.valid()) {
        comp_node = ptr->ptr->comp_node();
    }
}

void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
    candidates.erase(ptr);
}

void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
    ptr->last_used_time = estimate_timestamp;
}
