
interpreter_impl.cpp 58 kB

  1. /**
  2. * \file imperative/src/impl/interpreter/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. #include "range/v3/all.hpp"
  13. #include "megbrain/common.h"
  14. #include "megbrain/imperative/opr_utility.h"
  15. #include "megbrain/imperative/ops/autogen.h"
  16. #include "megbrain/imperative/ops/backward_graph.h"
  17. #include "megbrain/imperative/ops/opr_attr.h"
  18. #include "megbrain/imperative/ops/utility.h"
  19. #include "megbrain/imperative/utils/to_string.h"
  20. #include "../blob_manager_impl.h"
  21. #include "../event_pool.h"
  22. #include "../op_trait.h"
  23. using namespace mgb;
  24. using namespace imperative;
  25. using namespace interpreter;
  26. using namespace interpreter::intl;
  27. namespace {
  28. auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
  29. SmallVector<uint64_t> tid;
  30. for (auto* ptinfo : tinfo) {
  31. tid.push_back(ptinfo->id);
  32. }
  33. return tid;
  34. };
  35. } // namespace
  36. namespace mgb {
  37. using namespace profiler;
  38. }
  39. #if defined(_WIN32) || defined(_WIN64)
  40. #define SYMBOL_EXPORT __declspec(dllexport)
  41. #else
  42. #define SYMBOL_EXPORT __attribute__((visibility("default")))
  43. #endif
  44. namespace mgb {
  45. /**
  46. * USAGE
  47. *
  48. * header:
  49. * namespace mgb { void imperative_log_profile(const char* message); }
  50. *
  51. * code:
  52. * mgb::imperative_log_profile("MY MESSAGE");
  53. *
  54. **/
  55. SYMBOL_EXPORT
  56. void imperative_log_profile_begin(const char* message) {
  57. MGB_RECORD_EVENT(CustomEvent, std::string{message});
  58. }
  59. SYMBOL_EXPORT
  60. void imperative_log_profile_end(const char* message) {
  61. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
  62. }
  63. SYMBOL_EXPORT
  64. void imperative_log_profile(const char* message) {
  65. imperative_log_profile_begin(message);
  66. imperative_log_profile_end(message);
  67. }
  68. SYMBOL_EXPORT
  69. void imperative_log_profile_begin(const char* message, const char* device) {
  70. auto comp_node = CompNode::load(device);
  71. MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
  72. MGB_RECORD_EVENT(
  73. RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  74. }
  75. SYMBOL_EXPORT
  76. void imperative_log_profile_end(const char* message, const char* device) {
  77. auto comp_node = CompNode::load(device);
  78. MGB_RECORD_EVENT(
  79. RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  80. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
  81. }
  82. } // namespace mgb
  83. std::thread::id ChannelImpl::get_worker_tid() {
  84. return m_worker_state.tid;
  85. }
  86. ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
  87. assert_in_channel();
  88. return m_channel_state;
  89. }
  90. ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
  91. assert_in_worker();
  92. return m_worker_state;
  93. }
  94. void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
  95. sys::set_thread_name("worker");
  96. m_owner->m_worker_state.tid = std::this_thread::get_id();
  97. OpDef::set_allocator([&](CompNode device, size_t size) {
  98. auto blob = Blob::make(device, size);
  99. m_owner->alloc_tensor_with_evict(blob.get());
  100. return blob->storage();
  101. });
  102. }
  103. // Do not use m_xxx_state directly; the empty #defines below poison the names so that code must go through get_channel_state()/get_worker_state(), which assert the calling thread.
  104. #define m_channel_state
  105. #define m_worker_state
  106. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  107. return std::make_unique<ChannelImpl>();
  108. }
  109. Interpreter& Interpreter::inst() {
  110. static InterpreterImpl inst_;
  111. return inst_;
  112. }
  113. Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
  114. MGB_LOCK_GUARD(m_spin);
  115. mgb_assert(check_available(), "Channel already closed");
  116. auto& state = get_channel_state();
  117. auto _ = StackManager::Guard{"Put", &state.stack_manager};
  118. auto info = put_impl(value, no_cache);
  119. return reinterpret_cast<Handle>(info);
  120. }
  121. TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
  122. if (value.empty()) {
  123. auto layout = value.layout();
  124. layout.init_contiguous_stride();
  125. const_cast<HostTensorND&>(value).reset(value.storage(), layout);
  126. }
  127. auto info = alloc();
  128. constexpr int size_threshold = TensorShape::MAX_NDIM;
  129. init(info, {value.layout(), value.comp_node()});
  130. if (value.layout().total_nr_elems() <= size_threshold) {
  131. info->h_value = value;
  132. info->desc.value = value.proxy_to_default_cpu();
  133. }
  134. info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
  135. m_buffer.enqueue(Put{info, value, no_cache});
  136. if (m_async_level == 0) {
  137. sync_impl();
  138. info->desc.comp_node.sync();
  139. auto err = info->desc.comp_node.check_async_error();
  140. mgb_assert(!err, "%s", err->what());
  141. }
  142. return info;
  143. }
  144. Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
  145. MGB_LOCK_GUARD(m_spin);
  146. mgb_assert(check_available(), "Channel already closed");
  147. return reinterpret_cast<Handle>(put_impl(data, hvalue));
  148. }
  149. TensorInfo* ChannelImpl::put_impl(
  150. const DeviceTensorND& data, const HostTensorND& hvalue) {
  151. auto& state = get_channel_state();
  152. auto _ = StackManager::Guard{"Put", &state.stack_manager};
  153. auto info = alloc();
  154. MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
  155. init(info, {data.layout(), data.comp_node()});
  156. info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
  157. info->ptr = Tensor::make(data, hvalue);
  158. MGB_RECORD_EVENT(
  159. TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
  160. data.raw_ptr());
  161. info->status = TensorInfo::Produced;
  162. MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
  163. return info;
  164. }
  165. void ChannelImpl::del(Handle handle) {
  166. MGB_LOCK_GUARD(m_spin);
  167. if (!check_available()) {
  168. return;
  169. }
  170. del_impl(handle);
  171. }
  172. void ChannelImpl::del_impl(Handle handle) {
  173. mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
  174. auto* info = reinterpret_cast<TensorInfo*>(handle);
  175. m_valid_handle.erase(handle);
  176. m_buffer.enqueue(Del{info});
  177. }
  178. void ChannelImpl::drop(Handle handle) {
  179. MGB_LOCK_GUARD(m_spin);
  180. mgb_assert(check_available(), "Channel already closed");
  181. auto& state = get_channel_state();
  182. if (state.options.enable_drop) {
  183. mgb_assert(
  184. m_valid_handle.find(handle) != m_valid_handle.end(),
  185. "invalid handle: %p", handle);
  186. auto* info = reinterpret_cast<TensorInfo*>(handle);
  187. m_buffer.enqueue(Drop{info});
  188. }
  189. }
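// Note: the next two functions are the two dispatch paths selected in apply_op_impl.
// dispatch_default_cpu evaluates the op eagerly on the calling thread using
// default-CPU proxy tensors and feeds the results back through put_impl, while
// dispatch_kernel builds an ApplyOp command and enqueues it for the worker thread.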
  190. void ChannelImpl::dispatch_default_cpu(
  191. std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
  192. const SmallVector<LogicalTensorDesc>& input_descs,
  193. SmallVector<Handle>* outputs) {
  194. auto& state = get_channel_state();
  195. auto name = op->trait()->make_name(*op);
  196. auto _ = StackManager::Guard(name, &state.stack_manager);
  197. auto [output_descs, validated] =
  198. OpDef::infer_output_attrs_fallible(*op, input_descs);
  199. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  200. SmallVector<DeviceTensorND> input_tensornds;
  201. input_tensornds.reserve(input_descs.size());
  202. CompNode output_cn;
  203. {
  204. MGB_LOCK_GUARD(m_mutex);
  205. for (auto&& info : input_infos) {
  206. auto input_cn = info->desc.comp_node;
  207. if (!output_cn.valid()) {
  208. output_cn = input_cn;
  209. } else {
  210. mgb_assert(output_cn == input_cn, "cannot decide output comp node");
  211. }
  212. if (info->ptr && info->ptr->try_get_value()) {
  213. input_tensornds.emplace_back(
  214. info->ptr->get_value().proxy_to_default_cpu());
  215. } else {
  216. // h_value was assigned before ptr was dropped, so it can be used as a fallback here
  217. mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
  218. input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
  219. }
  220. }
  221. }
  222. outputs->reserve(output_descs.size());
  223. SmallVector<DeviceTensorND> output_tensornds;
  224. output_tensornds.reserve(output_descs.size());
  225. for (auto&& desc : output_descs) {
  226. // TODO: may conflict with CondTake, which needs to allocate its outputs inside the kernel
  227. mgb_assert(!desc.layout.is_empty());
  228. // use HostTensorND alloc_host for cuda pinned memory
  229. output_tensornds.emplace_back(
  230. HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
  231. }
  232. uint64_t op_id = Profiler::next_id();
  233. OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
  234. SmallVector<TensorInfo*> output_infos;
  235. output_infos.reserve(output_descs.size());
  236. for (auto&& tensornd : output_tensornds) {
  237. HostTensorND host_tensornd =
  238. HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
  239. // use `put` for consistency
  240. auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
  241. mgb_assert(info->desc.layout.ndim != 0);
  242. output_infos.push_back(info);
  243. outputs->push_back(reinterpret_cast<Handle>(info));
  244. }
  245. auto op_info_getter = [op] {
  246. std::unordered_map<std::string, std::string> op_info;
  247. auto props = OpDef::props(*op);
  248. for (auto&& [key, value] : props) {
  249. op_info[key] = value;
  250. }
  251. return op_info;
  252. };
  253. MGB_RECORD_EVENT(
  254. OpDispatchEvent, op_id, name, op_info_getter, tinfo_to_tid(input_infos),
  255. tinfo_to_tid(output_infos), state.stack_manager.dump());
  256. }
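// Note on async_level as used below: level 0 synchronizes after every op and checks
// device async errors; level 1 synchronizes only when shape inference could not be
// fully validated; higher levels leave the command fully asynchronous.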
  257. void ChannelImpl::dispatch_kernel(
  258. std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
  259. const SmallVector<LogicalTensorDesc>& input_descs,
  260. SmallVector<Handle>* outputs) {
  261. auto& state = get_channel_state();
  262. auto& options = state.options;
  263. auto name = op->trait()->make_name(*op);
  264. auto _ = StackManager::Guard{name, &state.stack_manager};
  265. auto [output_descs, validated] =
  266. OpDef::infer_output_attrs_fallible(*op, input_descs);
  267. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  268. ApplyOp cmd{Profiler::next_id(), std::move(op)};
  269. cmd.inputs = std::move(input_infos);
  270. cmd.outputs.reserve(output_descs.size());
  271. outputs->reserve(output_descs.size());
  272. for (size_t i = 0; i < output_descs.size(); ++i) {
  273. auto&& desc = output_descs[i];
  274. auto info = alloc();
  275. init(info, desc);
  276. // make sure desc's value is consistent with h_value
  277. if (!info->desc.value.empty()) {
  278. info->h_value = HostTensorND::make_proxy(desc.value)
  279. .proxy_to_comp_node(desc.comp_node);
  280. }
  281. cmd.outputs.push_back(info);
  282. outputs->push_back(reinterpret_cast<Handle>(info));
  283. }
  284. auto op_info_getter = [op = cmd.op] {
  285. std::unordered_map<std::string, std::string> op_info;
  286. auto props = OpDef::props(*op);
  287. for (auto&& [key, value] : props) {
  288. op_info[key] = value;
  289. }
  290. return op_info;
  291. };
  292. MGB_RECORD_EVENT(
  293. OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
  294. tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
  295. m_buffer.enqueue(std::move(cmd));
  296. if (!validated && options.async_level == 1) {
  297. sync_impl();
  298. } else if (options.async_level == 0) {
  299. sync_impl();
  300. // check device error
  301. for (auto&& oup : *outputs) {
  302. auto info = reinterpret_cast<TensorInfo*>(oup);
  303. info->ptr->comp_node().sync();
  304. auto err = info->ptr->comp_node().check_async_error();
  305. mgb_assert(!err, "%s", err->what());
  306. }
  307. }
  308. }
  309. SmallVector<Handle> ChannelImpl::apply_op(
  310. std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
  311. MGB_LOCK_GUARD(m_spin);
  312. mgb_assert(check_available(), "Channel already closed");
  313. return apply_op_impl(std::move(op), inputs);
  314. }
  315. SmallVector<Handle> ChannelImpl::apply_op_impl(
  316. std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
  317. auto& state = get_channel_state();
  318. for (auto i : inputs) {
  319. mgb_assert(
  320. m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
  321. i);
  322. }
  323. SmallVector<TensorInfo*> input_infos;
  324. input_infos.reserve(inputs.size());
  325. SmallVector<LogicalTensorDesc> input_descs;
  326. input_descs.reserve(inputs.size());
  327. {
  328. MGB_LOCK_GUARD(m_mutex);
  329. for (auto i : inputs) {
  330. auto info = reinterpret_cast<TensorInfo*>(i);
  331. mgb_assert(
  332. !info->invalid,
  333. "an input tensor is unusable due to previous error");
  334. input_infos.push_back(info);
  335. input_descs.push_back(info->desc);
  336. }
  337. }
  338. SmallVector<Handle> outputs;
  339. DispatchMode dispatch_mode = state.options.enable_host_compute
  340. ? OpDef::decide_dispatch_mode(*op, input_descs)
  341. : DispatchMode::KERNEL;
  342. switch (dispatch_mode) {
  343. case DEFAULT_CPU: {
  344. dispatch_default_cpu(op, input_infos, input_descs, &outputs);
  345. break;
  346. }
  347. case KERNEL: {
  348. dispatch_kernel(op, input_infos, input_descs, &outputs);
  349. break;
  350. }
  351. }
  352. return outputs;
  353. }
  354. HostTensorND ChannelImpl::get_value(Handle handle) {
  355. MGB_LOCK_GUARD(m_spin);
  356. mgb_assert(check_available(), "Channel already closed");
  357. mgb_assert(
  358. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  359. handle);
  360. auto info = reinterpret_cast<TensorInfo*>(handle);
  361. // do not use info->value_fetched, it's unsafe
  362. mgb_assert(!info->invalid, "tensor is unusable due to previous error");
  363. return wait_tensor(info, TensorProp::HostValue)->get_value();
  364. }
  365. TensorShape ChannelImpl::get_shape(Handle handle) {
  366. MGB_LOCK_GUARD(m_spin);
  367. mgb_assert(check_available(), "Channel already closed");
  368. mgb_assert(
  369. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  370. handle);
  371. auto info = reinterpret_cast<TensorInfo*>(handle);
  372. if (info->desc.layout.ndim != 0) {
  373. return info->desc.layout;
  374. }
  375. TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
  376. mgb_assert(ret.ndim != 0);
  377. return ret;
  378. }
  379. DType ChannelImpl::get_dtype(Handle handle) {
  380. MGB_LOCK_GUARD(m_spin);
  381. mgb_assert(check_available(), "Channel already closed");
  382. mgb_assert(
  383. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  384. handle);
  385. auto info = reinterpret_cast<TensorInfo*>(handle);
  386. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
  387. auto ret = info->desc.layout.dtype;
  388. mgb_assert(ret.valid());
  389. return ret;
  390. }
  391. CompNode ChannelImpl::get_device(Handle handle) {
  392. MGB_LOCK_GUARD(m_spin);
  393. mgb_assert(check_available(), "Channel already closed");
  394. mgb_assert(
  395. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  396. handle);
  397. auto info = reinterpret_cast<TensorInfo*>(handle);
  398. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
  399. auto ret = info->desc.comp_node;
  400. mgb_assert(ret.valid());
  401. return ret;
  402. }
  403. DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
  404. MGB_LOCK_GUARD(m_spin);
  405. mgb_assert(check_available(), "Channel already closed");
  406. mgb_assert(
  407. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  408. handle);
  409. auto info = reinterpret_cast<TensorInfo*>(handle);
  410. return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
  411. }
  412. void ChannelImpl::sync() {
  413. MGB_LOCK_GUARD(m_spin);
  414. mgb_assert(check_available(), "Channel already closed");
  415. sync_impl();
  416. }
  417. void ChannelImpl::sync_impl() {
  418. m_buffer.flush();
  419. m_worker.wait_all_task_finish();
  420. MGB_LOCK_GUARD(m_mutex);
  421. check_worker_exc_unsafe();
  422. }
  423. void ChannelImpl::close() {
  424. MGB_LOCK_GUARD(m_spin);
  425. if (!check_available()) {
  426. return;
  427. }
  428. std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
  429. for (auto* handle : valid_handles) {
  430. del_impl(handle);
  431. }
  432. mgb_assert(m_valid_handle.empty());
  433. mgb_log_debug("%ld tensors exist before channel close", (long)valid_handles.size());
  434. sync_impl();
  435. m_closed = true;
  436. }
  437. size_t ChannelImpl::get_option(std::string name) {
  438. MGB_LOCK_GUARD(m_spin);
  439. mgb_assert(check_available(), "Channel already closed");
  440. auto& state = get_channel_state();
  441. return state.options.get_option(name);
  442. }
  443. void ChannelImpl::set_option(std::string name, size_t value) {
  444. MGB_LOCK_GUARD(m_spin);
  445. mgb_assert(check_available(), "Channel already closed");
  446. auto& state = get_channel_state();
  447. state.options.set_option(name, value);
  448. m_buffer.enqueue(SetOption{name, value});
  449. }
  450. void ChannelImpl::clear_candidates() {
  451. MGB_LOCK_GUARD(m_spin);
  452. mgb_assert(check_available(), "Channel already closed");
  453. m_dtr.candidates.clear();
  454. }
  455. TensorInfo* ChannelImpl::alloc() {
  456. auto& state = get_channel_state();
  457. auto info = [this] {
  458. MGB_LOCK_GUARD(m_mutex);
  459. return m_pool.alloc();
  460. }();
  461. info->id = Profiler::next_id();
  462. if (Profiler::is_profiling()) {
  463. size_t tensor_id = state.stack_manager.current()->next_id("tensor");
  464. info->name =
  465. state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
  466. }
  467. return info;
  468. }
  469. void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
  470. m_valid_handle.insert(reinterpret_cast<Handle>(info));
  471. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  472. info->status = TensorInfo::Allocated;
  473. info->desc = std::move(desc);
  474. info->mem_desc.layout = info->desc.layout;
  475. info->mem_desc.cn = info->desc.comp_node;
  476. info->mem_desc.offset = 0;
  477. }
  478. void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
  479. if (!ptr->producer) {
  480. if (user) {
  481. mgb_log_warn(
  482. "the input that produced tensor %p has been deleted, this drop "
  483. "operation will be ignored",
  484. ptr);
  485. }
  486. return;
  487. }
  488. if (ptr->evict_type != EvictType::NONE) {
  489. return;
  490. }
  491. ptr->evict_type = EvictType::DROP;
  492. ptr->status = TensorInfo::Dropped;
  493. release_tensor(ptr);
  494. }
  495. void ChannelImpl::free(TensorInfo* ptr) {
  496. auto& state = get_worker_state();
  497. if (state.options.enable_dtr_auto_drop) {
  498. // Evicting a tensor, rather than freeing it, can avoid pinning
  499. // potentially exploding amounts of memory and allow us to save
  500. // more memory.
  501. ptr->allow_delete = true;
  502. if (!ptr->ref_cnt) {
  503. recursive_free(ptr);
  504. } else {
  505. do_drop(ptr);
  506. }
  507. } else {
  508. real_free(ptr);
  509. }
  510. }
  511. void ChannelImpl::recursive_free(TensorInfo* ptr) {
  512. MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
  513. SmallVector<TensorInfo*> inps;
  514. if (ptr->producer) {
  515. for (auto i : ptr->producer->inputs) {
  516. if (i && --i->ref_cnt == 0) {
  517. inps.push_back(i);
  518. }
  519. }
  520. }
  521. real_free(ptr);
  522. for (auto i : inps) {
  523. if (i->allow_delete) {
  524. recursive_free(i);
  525. }
  526. }
  527. MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
  528. }
  529. void ChannelImpl::real_free(TensorInfo* ptr) {
  530. auto& state = get_worker_state();
  531. if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  532. m_dtr.erase_candidate(ptr);
  533. }
  534. detach_users(ptr);
  535. ptr->detach_producer();
  536. bool has_value = ptr->ptr != nullptr;
  537. if (has_value) {
  538. MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
  539. }
  540. MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
  541. ptr->status = TensorInfo::Deleted;
  542. MGB_LOCK_GUARD(m_mutex);
  543. m_pool.free(ptr);
  544. }
  545. ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {}
  546. ChannelImpl::~ChannelImpl() {
  547. close();
  548. }
  549. void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
  550. auto& state = get_worker_state();
  551. MGB_LOCK_GUARD(m_mutex);
  552. m_dtr.update_used_time(dest);
  553. MGB_RECORD_EVENT(
  554. TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
  555. ptr->dev_tensor().raw_ptr());
  556. // update tensor desc for static infer
  557. dest->desc.layout = ptr->layout();
  558. dest->desc.comp_node = ptr->comp_node();
  559. dest->memory = ptr->blob()->size();
  560. dest->ptr = std::move(ptr);
  561. dest->evict_type = EvictType::NONE;
  562. dest->status = TensorInfo::Produced;
  563. if (dest->pinned == 0 &&
  564. dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  565. m_dtr.insert_candidate(dest);
  566. }
  567. notify_tensor_unsafe(dest);
  568. }
  569. void ChannelImpl::release_tensor(TensorInfo* dest) {
  570. MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
  571. MGB_LOCK_GUARD(m_mutex);
  572. dest->ptr.reset();
  573. auto& state = get_worker_state();
  574. if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  575. m_dtr.erase_candidate(dest);
  576. }
  577. }
  578. void ChannelImpl::regenerate(TensorInfo* dest) {
  579. if (dest->evict_type == EvictType::DROP) {
  580. auto&& path = dest->producer;
  581. m_apply_stack.push(
  582. {ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest,
  583. "dtr"});
  584. if (!m_applying)
  585. flush_apply_stack();
  586. }
  587. }
  588. void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
  589. using namespace ranges;
  590. using namespace ranges::views;
  591. auto& state = get_worker_state();
  592. bool profiling_device =
  593. Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
  594. uint64_t apply_id = cmd.id;
  595. struct TensorWithDesc {
  596. TensorPtr tensor;
  597. MemoryDesc desc;
  598. };
  599. SmallVector<TensorWithDesc> inputs;
  600. inputs.reserve(cmd.inputs.size());
  601. // refcnt == 1, owners: [TensorInfo::ptr]
  602. for (auto i : cmd.inputs) {
  603. mgb_assert(i->ptr, "Invalid input tensor ptr!");
  604. // refcnt ++, owners: [i->ptr, tensor_inputs]
  605. // tensor_inputs.push_back(i->ptr);
  606. inputs.push_back({i->ptr, i->mem_desc});
  607. }
  608. if (state.options.enable_dtr_auto_drop &&
  609. state.options.dtr_eviction_threshold > 0) {
  610. auto_evict(0);
  611. }
  612. auto apply_on_physical_tensor =
  613. [&](auto&& self, const OpDef& def,
  614. SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
  615. auto apply_functor = [&](std::shared_ptr<OpDef> op,
  616. SmallVector<TensorWithDesc> inputs,
  617. size_t nr_outputs) -> SmallVector<TensorWithDesc> {
  618. auto opname = op->trait()->make_name(*op);
  619. imperative_log_profile_begin(opname.c_str());
  620. auto outputs = self(self, *op, inputs);
  621. imperative_log_profile_end(opname.c_str());
  622. return outputs;
  623. };
  624. auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
  625. return {value, MemoryDesc{
  626. value->layout(), 0, value->comp_node(),
  627. StorageIdentifier::make()}};
  628. };
  629. if (def.trait()->make_forward_graph) {
  630. // apply recursively
  631. SmallVector<LogicalTensorDesc> input_descs;
  632. for (auto&& input : inputs) {
  633. input_descs.push_back(
  634. {{{}, input.tensor->dtype()}, input.tensor->comp_node()});
  635. }
  636. auto forward_graph = OpDef::make_forward_graph(def, input_descs);
  637. auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
  638. return outputs;
  639. }
  640. SmallVector<TensorPtr> input_tensors;
  641. SmallVector<MemoryDesc> input_descs;
  642. for (auto&& input : inputs) {
  643. input_tensors.push_back(input.tensor);
  644. input_descs.push_back(input.desc);
  645. }
  646. auto [output_descs, output_tensors, workspaces] =
  647. init_output_and_workspace(def, input_tensors, input_descs);
  648. if (!output_descs.empty()) {
  649. OpDef::execute(def, input_tensors, output_tensors, workspaces);
  650. } else {
  651. output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
  652. for (auto&& output_tensor : output_tensors) {
  653. output_descs.push_back(MemoryDesc{
  654. output_tensor->layout(), 0, output_tensor->comp_node(),
  655. StorageIdentifier::make()});
  656. }
  657. }
  658. SmallVector<TensorWithDesc> outputs;
  659. for (auto&& [output_tensor, output_desc] :
  660. ranges::zip_view(output_tensors, output_descs)) {
  661. outputs.push_back({output_tensor, output_desc});
  662. }
  663. return outputs;
  664. };
  665. MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
  666. // Begin profiling operator
  667. SmallVector<std::pair<CompNode, uint64_t>> kernels;
  668. if (profiling_device) {
  669. // Collecting devices
  670. SmallVector<CompNode> devices;
  671. for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
  672. if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
  673. devices.push_back(i->desc.comp_node);
  674. kernels.push_back({i->desc.comp_node, Profiler::next_id()});
  675. }
  676. }
  677. }
  678. for (auto* input : cmd.inputs) {
  679. auto input_id = input->id;
  680. MGB_RECORD_EVENT(OpInputEvent, input_id);
  681. MGB_RECORD_EVENT(TensorUsageEvent, input_id);
  682. MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
  683. }
  684. // Fused by command buffer. @see: CommandBuffer::fuse_del
  685. // Now if dest is in-place capable, its refcnt would be decreased to 1 and owned by
  686. // tensor_inputs after Del. Note that for exprs like 'y = x op x', inplace is not
  687. // supported yet, but the Del would still be fused.
  688. for (auto* del : cmd.dels) {
  689. // refcnt --, owners: [tensor_inputs]
  690. // if it's decreased to 1, would be detected at @see:
  691. // proxy_graph_detail::apply_on_physical_tensor
  692. uint64_t del_id = del->id;
  693. MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
  694. free(del);
  695. MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
  696. }
  697. // Before wait
  698. // TODO: split operator wait and execute so that OpWait could be correctly recorded.
  699. // Before execute
  700. for (auto&& [device, kernel_id] : kernels) {
  701. MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
  702. MGB_RECORD_EVENT_IF(
  703. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  704. Timer::record_device(device));
  705. }
  706. // Apply op
  707. // Here std::move is REQUIRED for removing duplicated references.
  708. auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
  709. // After execute
  710. for (auto&& [device, kernel_id] : kernels) {
  711. MGB_RECORD_EVENT_IF(
  712. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  713. Timer::record_device(device));
  714. MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
  715. }
  716. // End profiling operator
  717. mgb_assert(outputs.size() == cmd.outputs.size());
  718. for (size_t i = 0; i < outputs.size(); ++i) {
  719. auto output = cmd.outputs[i];
  720. if (output == nullptr) {
  721. MGB_RECORD_EVENT(OpOutputEvent, 0);
  722. MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
  723. } else if (output->ptr != nullptr) {
  724. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  725. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  726. } else {
  727. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  728. produce_tensor(output, outputs[i].tensor);
  729. output->mem_desc = outputs[i].desc;
  730. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  731. sample_on_device(output->desc.comp_node, false);
  732. }
  733. }
  734. if (state.options.enable_dtr_auto_drop) {
  735. double estimate_compute_time = 0;
  736. for (auto i : cmd.inputs) {
  737. estimate_compute_time += i->memory;
  738. }
  739. for (auto i : outputs) {
  740. estimate_compute_time += i.tensor->blob()->size();
  741. }
  742. m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
  743. for (auto i : cmd.outputs) {
  744. if (i != nullptr) {
  745. i->compute_time = estimate_compute_time;
  746. }
  747. }
  748. m_dtr.unpin(cmd.inputs, state);
  749. }
  750. MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
  751. // End profiling operator
  752. }
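// Note: flush_apply_stack drives (re)computation. If an input of the command on top
// of the stack has been dropped, regenerate() pushes its producer ApplyOp and the loop
// restarts from that child; once every input is materialized, the command is popped
// and executed via do_apply_op, and DSU bookkeeping is updated for recomputed outputs.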
  753. void ChannelImpl::flush_apply_stack() {
  754. m_applying = true;
  755. auto& state = get_worker_state();
  756. while (!m_apply_stack.empty()) {
  757. auto& [cmd, idx, recomp, reason] =
  758. m_apply_stack.top(); // cmd.inputs[0..idx-1] are already in memory
  759. if (idx == 0) {
  760. if (state.options.enable_dtr_auto_drop) {
  761. m_dtr.pin(cmd.inputs);
  762. }
  763. if (recomp) {
  764. MGB_RECORD_EVENT(
  765. TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
  766. }
  767. }
  768. bool regen = false;
  769. for (size_t i = idx; i < cmd.inputs.size(); i++) {
  770. auto&& p = cmd.inputs[i];
  771. if (state.options.enable_dtr_auto_drop) {
  772. m_dtr.update_used_time(p);
  773. }
  774. if (!p->ptr && p->evict_type != EvictType::NONE) {
  775. idx = i + 1;
  776. regenerate(p); // add ApplyOp to the stack
  777. regen = true;
  778. break;
  779. }
  780. }
  781. if (regen)
  782. continue;
  783. // the required input tensors are already in memory
  784. auto [cmd_backup, recomp_backup, reason_backup] =
  785. std::make_tuple(cmd, recomp, reason);
  786. m_apply_stack.pop();
  787. do_apply_op(cmd_backup, reason_backup);
  788. if (recomp_backup) {
  789. MGB_RECORD_EVENT(
  790. TensorCommandFinishEvent, recomp_backup->id,
  791. TensorCommandKind::ReGen);
  792. for (auto o : cmd_backup.outputs) {
  793. if (o) {
  794. m_dtr.update_dsu_after_recompute(o);
  795. }
  796. }
  797. }
  798. }
  799. m_applying = false;
  800. }
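// Note: auto_evict drops DTR candidates until used memory falls below
// dtr_eviction_threshold (or force_num tensors have been evicted). The return value
// indicates whether at least one eviction actually released device memory, i.e. the
// dropped tensor owned its storage exclusively.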
  801. bool ChannelImpl::auto_evict(size_t force_num) {
  802. auto& state = get_worker_state();
  803. if (!m_dtr.comp_node.valid()) {
  804. return false;
  805. }
  806. size_t current_memory = m_dtr.comp_node.get_used_memory();
  807. bool flag = false;
  808. while ((state.options.dtr_eviction_threshold > 0 &&
  809. current_memory > state.options.dtr_eviction_threshold) ||
  810. force_num > 0) {
  811. MGB_RECORD_EVENT(AutoEvictEvent);
  812. sample_on_device(m_dtr.comp_node, false);
  813. auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
  814. if (!best) {
  815. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  816. break;
  817. }
  818. if (best->ptr.unique() && best->ptr->blob().unique()) {
  819. current_memory -= best->memory;
  820. if (force_num > 0) {
  821. force_num--;
  822. }
  823. flag = true;
  824. }
  825. do_drop(best);
  826. if (best->evict_type == EvictType::DROP) {
  827. m_dtr.update_dsu_after_evict(best);
  828. }
  829. sample_on_device(m_dtr.comp_node, false);
  830. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  831. }
  832. return flag;
  833. }
  834. void ChannelImpl::detach_users(TensorInfo* dest) {
  835. SmallVector<TensorInfo::ComputePath*> users = dest->users;
  836. for (auto* user : users) {
  837. SmallVector<TensorInfo*> outputs = user->outputs;
  838. SmallVector<TensorInfo*> inputs = user->inputs;
  839. for (auto* output : outputs) {
  840. // When a `ComputePath` is detached from its input,
  841. // there is no need to keep it around,
  842. // so we detach all outputs of this path
  843. // to decrease its `ref_cnt` to zero.
  844. if (output == nullptr) {
  845. continue;
  846. }
  847. regenerate(output);
  848. output->detach_producer();
  849. for (auto* input : inputs) {
  850. input->ref_cnt--;
  851. }
  852. }
  853. // now user is dead
  854. }
  855. mgb_assert(dest->users.empty(), "ComputePath leaking");
  856. }
  857. bool ChannelImpl::check_available() {
  858. return !m_closed;
  859. }
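// Note: wait_tensor blocks the calling (channel) thread on m_cv until the worker
// produces `info`. For HostValue, when the value is not already fetched, the lock is
// released while a GetValue command is enqueued and flushed, to avoid deadlocking
// against the worker thread.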
  860. TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
  861. m_buffer.flush();
  862. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  863. mgb_assert(!m_waitee, "duplicate waitee");
  864. m_waitee = info;
  865. m_waitee_id = Profiler::next_id();
  866. MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
  867. bool require_host = prop == TensorProp::HostValue;
  868. auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
  869. bool wait_host = false;
  870. if (require_host && !host_available()) {
  871. // avoid deadlock
  872. lock.unlock();
  873. m_buffer.enqueue(GetValue{info});
  874. m_buffer.flush();
  875. lock.lock();
  876. wait_host = true;
  877. }
  878. m_cv.wait(lock, [&]() {
  879. check_worker_exc_unsafe();
  880. return require_host ? host_available() : static_cast<bool>(info->ptr);
  881. });
  882. MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
  883. m_waitee = nullptr;
  884. if (wait_host) {
  885. auto err = info->ptr->comp_node().check_async_error();
  886. mgb_assert(!err, "%s", err->what());
  887. }
  888. return info->ptr;
  889. }
  890. void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
  891. if (info == m_waitee) {
  892. MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
  893. m_cv.notify_all();
  894. }
  895. }
  896. std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
  897. std::unordered_set<TensorInfo*> valid_tensors;
  898. for (auto* handle : m_valid_handle) {
  899. auto* info = reinterpret_cast<TensorInfo*>(handle);
  900. valid_tensors.insert(info);
  901. }
  902. return valid_tensors;
  903. }
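// Note: alloc_tensor_with_evict first pre-evicts via reserve_size() until a large
// enough free block is available, then tries a direct allocation; on MemAllocError it
// evicts one candidate at a time via auto_evict(1) and retries, and as a last resort
// defragments the blob manager before a final allocation attempt.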
  904. void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
  905. auto reserve_size = [&](size_t size) {
  906. if (!m_dtr.comp_node.valid()) {
  907. return false;
  908. }
  909. while (size > m_dtr.comp_node.get_max_block_size_available()) {
  910. bool evict_suc = auto_evict(1);
  911. if (!evict_suc)
  912. return false;
  913. }
  914. return true;
  915. };
  916. auto pre_level = set_log_level(LogLevel::NO_LOG);
  917. reserve_size(x->size());
  918. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  919. MGB_CATCH(MemAllocError&, {
  920. bool suc = false;
  921. while (!suc) {
  922. if (!auto_evict(1)) {
  923. break;
  924. }
  925. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  926. MGB_CATCH(MemAllocError&, { continue; });
  927. suc = true;
  928. }
  929. if (!suc) {
  930. set_log_level(pre_level);
  931. mgb_log_warn(
  932. "reallocating all cuda memory to alleviate fragmentation, the "
  933. "performance may be affected");
  934. set_log_level(LogLevel::NO_LOG);
  935. imperative_log_profile_begin("defrag");
  936. BlobManager::inst()->defrag(x->comp_node());
  937. imperative_log_profile_end("defrag");
  938. BlobManager::inst()->alloc_direct(x, x->size());
  939. }
  940. });
  941. set_log_level(pre_level);
  942. }
  943. std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
  944. ChannelImpl::init_output_and_workspace(
  945. const OpDef& def, SmallVector<TensorPtr> inputs,
  946. SmallVector<MemoryDesc> inputs_mem_desc) {
  947. auto [outputs_desc, workspaces_desc] =
  948. OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
  949. if (!outputs_desc.size()) {
  950. // failed to infer memplan
  951. return {{}, {}, {}};
  952. }
  953. // refine storage id to make it unique
  954. for (auto&& desc : outputs_desc) {
  955. if (desc.id->is_sys_alloc()) {
  956. // TODO: there may be some outputs sharing the same storage id
  957. desc.id->id = ++m_storage_id;
  958. }
  959. }
  960. auto& state = get_worker_state();
  961. auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
  962. SmallVector<TensorPtr> tensors;
  963. for (size_t i = 0; i < desc.size(); i++) {
  964. if (desc[i].id->is_sys_alloc()) {
  965. tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
  966. if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
  967. alloc_tensor_with_evict(tensors.back()->blob().get());
  968. }
  969. } else if (desc[i].id->is_from_other()) {
  970. for (size_t j = 0; j < inputs_mem_desc.size(); j++) {
  971. if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
  972. tensors.push_back(
  973. inputs[j]->sub(desc[i].offset, desc[i].layout));
  974. break;
  975. }
  976. }
  977. } else if (desc[i].id->is_device_ptr()) {
  978. tensors.push_back(desc[i].id->ptr);
  979. } else {
  980. mgb_assert(0, "not implemented");
  981. }
  982. }
  983. return tensors;
  984. };
  985. return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
  986. }
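// Note: process_one_task is the worker-side dispatcher over the Command variant. When
// catch_worker_execption is enabled, exceptions are stored in m_worker_exc (and the
// affected tensors are marked invalid) so the channel thread can rethrow them from
// check_worker_exc_unsafe().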
  987. void ChannelImpl::process_one_task(Command& icmd) {
  988. using namespace ranges;
  989. using namespace ranges::views;
  990. auto& state = get_worker_state();
  991. auto& options = state.options;
  992. // TODO: remove std::visit to support osx 10.12
  993. auto cmd_visitor = [&](const auto& cmd) {
  994. using T = std::decay_t<decltype(cmd)>;
  995. if constexpr (std::is_same_v<T, Put>) {
  996. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
  997. MGB_RECORD_EVENT_IF(
  998. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  999. Timer::record_device(cmd.value.comp_node()));
  1000. auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
  1001. : Tensor::make(cmd.value);
  1002. MGB_RECORD_EVENT_IF(
  1003. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1004. Timer::record_device(cmd.value.comp_node()));
  1005. produce_tensor(cmd.dest, std::move(value));
  1006. MGB_RECORD_EVENT(
  1007. TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
  1008. sample_on_device(cmd.dest->desc.comp_node, false);
  1009. } else if constexpr (std::is_same_v<T, ApplyOp>) {
  1010. for (auto& i : cmd.inputs) {
  1011. if (i->invalid) {
  1012. MGB_LOCK_GUARD(m_mutex);
  1013. for (auto& i : cmd.outputs) {
  1014. i->invalid = true;
  1015. }
  1016. return;
  1017. }
  1018. }
  1019. m_apply_stack.push({cmd, 0, nullptr, "cmd"});
  1020. flush_apply_stack();
  1021. for (size_t i = 0; i < cmd.outputs.size(); ++i) {
  1022. auto output = cmd.outputs[i];
  1023. if (output == nullptr) {
  1024. continue;
  1025. }
  1026. if (state.options.enable_dtr_auto_drop) {
  1027. output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
  1028. }
  1029. }
  1030. if (state.options.enable_drop && state.options.record_computing_path) {
  1031. auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
  1032. auto& input = std::get<0>(tuple2);
  1033. auto& output = std::get<1>(tuple2);
  1034. if (!input->ptr || !output->ptr) {
  1035. return false;
  1036. }
  1037. return input->ptr->blob()->storage() ==
  1038. output->ptr->blob()->storage();
  1039. };
  1040. // FIXME: do not use opname as identifier
  1041. auto get_name = [](const OpDef& opdef) {
  1042. if (auto attr = opdef.try_cast_final<OprAttr>()) {
  1043. return attr->type.c_str();
  1044. }
  1045. return opdef.dyn_typeinfo()->name;
  1046. };
  1047. auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
  1048. return info->desc.comp_node != comp_node;
  1049. };
  1050. bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
  1051. bool inplace =
  1052. any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);
  1053. if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
  1054. TensorInfo::ComputePath::make(
  1055. cmd.id, cmd.op, cmd.inputs, cmd.outputs);
  1056. size_t detach_cnt = 0;
  1057. if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
  1058. cmd.outputs.size() == 5) {
  1059. cmd.outputs[0]->detach_producer(); // detach running_mean
  1060. cmd.outputs[1]->detach_producer(); // detach running_var
  1061. for (auto input : cmd.inputs) {
  1062. input->ref_cnt -= 2;
  1063. }
  1064. }
  1065. for (auto output : cmd.outputs) {
  1066. if (output->producer &&
  1067. !output->size_exceeds_thd(
  1068. state.options.dtr_evictee_minimum_size)) {
  1069. output->detach_producer();
  1070. detach_cnt++;
  1071. }
  1072. }
  1073. for (auto input : cmd.inputs) {
  1074. input->ref_cnt -= detach_cnt;
  1075. }
  1076. }
  1077. }
  1078. } else if constexpr (std::is_same_v<T, Del>) {
  1079. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
  1080. CompNode device = cmd.dest->desc.comp_node;
  1081. uint64_t tensor_id = cmd.dest->id;
  1082. free(cmd.dest);
  1083. MGB_RECORD_EVENT(
  1084. TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
  1085. sample_on_device(device, false);
  1086. } else if constexpr (std::is_same_v<T, GetValue>) {
  1087. if (cmd.dest->invalid)
  1088. return;
  1089. imperative_log_profile_begin("GetValue");
  1090. if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
  1091. regenerate(cmd.dest);
  1092. }
  1093. cmd.dest->ptr->fetch_value();
  1094. MGB_LOCK_GUARD(m_mutex);
  1095. notify_tensor_unsafe(cmd.dest);
  1096. imperative_log_profile_end("GetValue");
  1097. } else if constexpr (std::is_same_v<T, Drop>) {
  1098. if (cmd.dest->invalid)
  1099. return;
  1100. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
  1101. do_drop(cmd.dest, true);
  1102. MGB_RECORD_EVENT(
  1103. TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
  1104. } else if constexpr (std::is_same_v<T, SetOption>) {
  1105. options.set_option(cmd.key, cmd.value);
  1106. } else if constexpr (std::is_same_v<T, StartProfile>) {
  1107. MGB_RECORD_EVENT(StartProfileEvent);
  1108. CompNode::sync_all();
  1109. for (auto* info : cmd.capture_tensors) {
  1110. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  1111. if (info->status == TensorInfo::Produced) {
  1112. // TODO: handle drop
  1113. MGB_RECORD_EVENT(
  1114. TensorProduceEvent, info->id, info->desc.layout,
  1115. info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
  1116. }
  1117. }
  1118. CompNode::foreach ([&](CompNode device) {
  1119. sample_on_device(device, true);
  1120. MGB_RECORD_EVENT_IF(
  1121. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1122. Timer::record_device(device));
  1123. });
  1124. MGB_RECORD_EVENT(StartProfileFinishEvent);
  1125. } else if constexpr (std::is_same_v<T, StopProfile>) {
  1126. MGB_RECORD_EVENT(StopProfileEvent);
  1127. for (auto* info : cmd.escape_tensors) {
  1128. bool has_value = info->status == TensorInfo::Produced;
  1129. if (has_value) {
  1130. MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
  1131. }
  1132. MGB_RECORD_EVENT(TensorEraseEvent, info->id);
  1133. }
  1134. CompNode::foreach (
  1135. [&](CompNode device) { sample_on_device(device, true); });
  1136. MGB_RECORD_EVENT(StopProfileFinishEvent);
  1137. } else if constexpr (std::is_same_v<T, PushScope>) {
  1138. MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
  1139. } else if constexpr (std::is_same_v<T, PopScope>) {
  1140. MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
  1141. } else {
  1142. static_assert(!std::is_same_v<T, T>);
  1143. }
  1144. };
  1145. std::visit(
  1146. [&](const auto& cmd) {
  1147. using T = std::decay_t<decltype(cmd)>;
  1148. if (!options.catch_worker_execption) {
  1149. cmd_visitor(cmd);
  1150. return;
  1151. }
  1152. try {
  1153. cmd_visitor(cmd);
  1154. } catch (...) {
  1155. MGB_LOCK_GUARD(m_mutex);
  1156. if constexpr (std::is_same_v<T, ApplyOp>) {
  1157. for (auto oup : cmd.outputs) {
  1158. oup->invalid = true;
  1159. }
  1160. } else if constexpr (std::is_same_v<T, Put>) {
  1161. cmd.dest->invalid = true;
  1162. }
  1163. m_worker_exc = std::current_exception();
  1164. MGB_RECORD_EVENT(WorkerExceptionEvent);
  1165. if (m_waitee) {
  1166. notify_tensor_unsafe(m_waitee);
  1167. }
  1168. }
  1169. },
  1170. icmd.data);
  1171. }
  1172. void ChannelImpl::check_worker_exc_unsafe() {
  1173. if (m_worker_exc) {
  1174. // allow interpreter_for_py to be reused after some exception tests
  1175. m_waitee = nullptr;
  1176. std::exception_ptr exc;
  1177. std::swap(exc, m_worker_exc);
  1178. try {
  1179. std::rethrow_exception(exc);
  1180. } catch (...) {
  1181. throw AsyncError();
  1182. }
  1183. }
  1184. }
  1185. void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
  1186. auto& state = m_owner->get_channel_state();
  1187. if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
  1188. return;
  1189. }
  1190. m_commands.push_back(
  1191. {Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
  1192. auto flush_pos = flush_pos_for(m_commands.back());
  1193. flush(flush_pos);
  1194. }
  1195. void ChannelImpl::CommandBuffer::flush() {
  1196. flush(m_commands.end());
  1197. }
  1198. void ChannelImpl::CommandBuffer::flush(Handle pos) {
  1199. for (auto iter = m_commands.begin(); iter != pos; ++iter) {
  1200. if (Profiler::is_profiling()) {
  1201. mgb_log_debug("%s Flushed", to_string(*iter).c_str());
  1202. }
  1203. m_owner->m_worker.add_task(std::move(*iter));
  1204. }
  1205. m_commands.erase(m_commands.begin(), pos);
  1206. }
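// Note: flush_pos_for decides how much of the buffer to submit eagerly: ops with side
// effects (RemoteSend/RemoteRecv/CollectiveComm/InputCallback/OutputCallback) and
// GetValue flush everything, otherwise only the oldest commands that overflow
// options.buffer_length are flushed.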
  1207. auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
  1208. auto& state = m_owner->get_channel_state();
  1209. return std::visit(
  1210. [this, &state](const auto& cmd) {
  1211. using T = std::decay_t<decltype(cmd)>;
  1212. if constexpr (std::is_same_v<T, ApplyOp>) {
  1213. auto* op_type = cmd.op->dyn_typeinfo();
  1214. if (op_type == RemoteRecv::typeinfo() ||
  1215. op_type == RemoteSend::typeinfo() ||
  1216. op_type == CollectiveComm::typeinfo() ||
  1217. op_type == opr::InputCallback::typeinfo() ||
  1218. op_type == opr::OutputCallback::typeinfo()) {
  1219. return m_commands.end();
  1220. }
  1221. } else if constexpr (std::is_same_v<T, GetValue>) {
  1222. return m_commands.end();
  1223. }
  1224. size_t buffer_length = state.options.buffer_length;
  1225. if (m_commands.size() > buffer_length) {
  1226. return m_commands.begin() + (m_commands.size() - buffer_length);
  1227. }
  1228. return m_commands.begin();
  1229. },
  1230. cmd.data);
  1231. }
  1232. /**
  1233. * 1. Find ApplyOp(dest) in buffered commands
  1234. * 2. Check if there are other usages between ApplyOp and Del, return false if not
  1235. * 3. Fuse Del into ApplyOp, return true
  1236. */
  1237. bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
  1238. auto* dest = cmd.dest;
  1239. // TODO: eliminate Puts
  1240. auto begin = m_commands.begin(), end = m_commands.end();
  1241. auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
  1242. if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
  1243. return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
  1244. }
  1245. return false;
  1246. });
  1247. if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
  1248. return false;
  1249. }
  1250. std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
  1251. return true;
  1252. }
  1253. auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
  1254. -> Handle {
  1255. auto found = range[1];
  1256. for (auto iter = range[0]; iter != range[1]; ++iter) {
  1257. std::visit(
  1258. [&](const auto& cmd) {
  1259. using T = std::decay_t<decltype(cmd)>;
  1260. if constexpr (std::is_same_v<T, ApplyOp>) {
  1261. if (std::count(cmd.inputs.begin(), cmd.inputs.end(), dest) >
  1262. 0) {
  1263. found = iter;
  1264. }
  1265. } else if constexpr (std::is_same_v<T, GetValue>) {
  1266. if (cmd.dest == dest) {
  1267. found = iter;
  1268. }
  1269. } else if constexpr (std::is_same_v<T, Drop>) {
  1270. // TODO: ignore swap-like commands, just remove them from buffer
  1271. if (cmd.dest == dest) {
  1272. found = iter;
  1273. }
  1274. }
  1275. },
  1276. iter->data);
  1277. };
  1278. return found;
  1279. }
  1280. auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) -> Handle {
  1281. return std::find_if(range[0], range[1], [dest](auto& cmd) {
  1282. return std::visit(
  1283. [dest](const auto& cmd) {
  1284. using T = std::decay_t<decltype(cmd)>;
  1285. if constexpr (std::is_same_v<T, ApplyOp>) {
  1286. return std::count(
  1287. cmd.outputs.begin(), cmd.outputs.end(), dest) >
  1288. 0;
  1289. } else if constexpr (std::is_same_v<T, Put>) {
  1290. return cmd.dest == dest;
  1291. }
  1292. return false;
  1293. },
  1294. cmd.data);
  1295. });
  1296. }
  1297. void ChannelImpl::start_profile() {
  1298. MGB_LOCK_GUARD(m_spin);
  1299. mgb_assert(check_available(), "Channel already closed");
  1300. auto capture_tensors = collect_valid_tensors();
  1301. if (capture_tensors.size() > 0) {
  1302. m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
  1303. }
  1304. }
  1305. void ChannelImpl::stop_profile() {
  1306. MGB_LOCK_GUARD(m_spin);
  1307. mgb_assert(check_available(), "Channel already closed");
  1308. m_buffer.flush();
  1309. auto escape_tensors = collect_valid_tensors();
  1310. if (escape_tensors.size() > 0) {
  1311. m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
  1312. }
  1313. }
  1314. void ChannelImpl::push_scope(std::string name) {
  1315. MGB_LOCK_GUARD(m_spin);
  1316. mgb_assert(check_available(), "Channel already closed");
  1317. auto& state = get_channel_state();
  1318. state.stack_manager.enter(name);
  1319. MGB_RECORD_EVENT(ScopeEvent, name);
  1320. m_buffer.enqueue(PushScope{name});
  1321. }
  1322. void ChannelImpl::pop_scope(std::string name) {
  1323. MGB_LOCK_GUARD(m_spin);
  1324. mgb_assert(check_available(), "Channel already closed");
  1325. auto& state = get_channel_state();
  1326. state.stack_manager.exit(name);
  1327. MGB_RECORD_EVENT(ScopeFinishEvent, name);
  1328. m_buffer.enqueue(PopScope{name});
  1329. }
  1330. void ChannelImpl::assert_in_channel() {
  1331. mgb_assert(
  1332. get_worker_tid() != std::this_thread::get_id(),
  1333. "this method cannot be called in worker thread");
  1334. }
  1335. void ChannelImpl::assert_in_worker() {
  1336. mgb_assert(
  1337. get_worker_tid() == std::this_thread::get_id(),
  1338. "this method can only be called in worker thread");
  1339. }
  1340. void ChannelImpl::sample_on_device(CompNode device, bool force) {
  1341. if (!force) {
  1342. thread_local int last_sample_id = 0;
  1343. int sample_rate =
  1344. Profiler::is_profiling() ? Profiler::get_option("sample_rate", 0) : 0;
  1345. if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
  1346. return;
  1347. }
  1348. }
  1349. MGB_RECORD_EVENT(SampleDeviceEvent, device);
  1350. auto [total, free] = device.get_mem_status_bytes();
  1351. MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
  1352. }
  1353. void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
  1354. for (auto i : vec) {
  1355. i->pin();
  1356. erase_candidate(i);
  1357. }
  1358. }
  1359. void ChannelImpl::DynamicSublinear::unpin(
  1360. const SmallVector<TensorInfo*>& vec, WorkerState& state) {
  1361. for (auto i : vec) {
  1362. i->unpin();
  1363. if (i->pinned == 0 &&
  1364. i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
  1365. i->cand_index == UINT_MAX) {
  1366. insert_candidate(i);
  1367. }
  1368. }
  1369. }
  1370. void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
  1371. auto&& dsu_fa = find_father(ptr->dsu_ptr);
  1372. dsu_fa->t -= ptr->compute_time;
  1373. ptr->dsu_ptr->parent.reset();
  1374. ptr->dsu_ptr->t = ptr->compute_time;
  1375. }
  1376. void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
  1377. for (auto i : ptr->producer->inputs) {
  1378. if (i->evict_type == EvictType::DROP) {
  1379. merge(i->dsu_ptr, ptr->dsu_ptr);
  1380. }
  1381. }
  1382. for (auto i : ptr->producer->outputs) {
  1383. if (i && i->evict_type == EvictType::DROP) {
  1384. merge(ptr->dsu_ptr, i->dsu_ptr);
  1385. }
  1386. }
  1387. }
  1388. double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
  1389. double cost = 0;
  1390. for (auto i : ptr->producer->inputs) {
  1391. if (i->evict_type == EvictType::DROP) {
  1392. double t = find_father(i->dsu_ptr)->t;
  1393. if (t < i->compute_time) {
  1394. t = i->compute_time;
  1395. }
  1396. cost += t;
  1397. }
  1398. }
  1399. for (auto i : ptr->producer->outputs) {
  1400. if (i && i->evict_type == EvictType::DROP) {
  1401. double t = find_father(i->dsu_ptr)->t;
  1402. if (t < i->compute_time) {
  1403. t = i->compute_time;
  1404. }
  1405. cost += t;
  1406. }
  1407. }
  1408. return cost;
  1409. }
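// Note: find_best_tensor scores eviction candidates with TensorInfo::eval_func,
// feeding it the estimated recompute cost of the dropped neighborhood, the free memory
// adjacent to the tensor's storage, and the current estimated timestamp; with
// enable_dtr_sqrt_sampling only about sqrt(n) randomly sampled candidates are scored.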
  1410. TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
  1411. bool enable_dtr_sqrt_sampling = false) {
  1412. if (candidates.empty())
  1413. return nullptr;
  1414. double min_msps = -1;
  1415. TensorInfo* best = nullptr;
  1416. size_t sz = 1;
  1417. if (enable_dtr_sqrt_sampling) {
  1418. while (sz * sz <= candidates.size())
  1419. sz++;
  1420. sz--;
  1421. } else {
  1422. sz = candidates.size();
  1423. }
  1424. size_t ti = rand() % sz;
  1425. for (size_t vi = 0; vi < sz; vi++) {
  1426. if (!enable_dtr_sqrt_sampling) {
  1427. ti = vi;
  1428. }
  1429. auto i = candidates[ti];
  1430. if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
  1431. double neighbor_cost = estimate_neighbor_cost(i);
  1432. size_t begin_ptr =
  1433. reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
  1434. auto side_info = i->ptr->comp_node().get_free_left_and_right(
  1435. begin_ptr, begin_ptr + i->ptr->blob()->size());
  1436. double free_mem = side_info.first + side_info.second;
  1437. double msps = i->eval_func(
  1438. neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
  1439. if (min_msps < 0 || msps < min_msps) {
  1440. min_msps = msps;
  1441. best = i;
  1442. }
  1443. }
  1444. if (enable_dtr_sqrt_sampling) {
  1445. ti += rand() % sz;
  1446. if (ti >= candidates.size())
  1447. break;
  1448. }
  1449. }
  1450. return best;
  1451. }
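// Note: the DsuNode helpers below implement a union-find over dropped tensors:
// merge() accumulates the recompute time t of one set into the root of the other, and
// find_father() performs path compression; estimate_neighbor_cost sums these per-set
// totals over the dropped inputs and outputs of a candidate's producer.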
  1452. void ChannelImpl::DynamicSublinear::merge(
  1453. std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
  1454. auto&& f_x = find_father(x);
  1455. auto&& f_y = find_father(y);
  1456. if (f_x.get() == f_y.get()) {
  1457. return;
  1458. }
  1459. f_y->t += f_x->t;
  1460. f_x->parent = f_y;
  1461. }
  1462. std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
  1463. std::shared_ptr<DsuNode>& x) {
  1464. if (x->is_root()) {
  1465. return x;
  1466. } else {
  1467. auto&& fa = find_father(x->parent);
  1468. return x->parent = fa;
  1469. }
  1470. }
  1471. void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
  1472. // tensor to be inserted must be brand new
  1473. mgb_assert(
  1474. ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
  1475. ptr->cand_index);
  1476. ptr->cand_index = candidates.size();
  1477. candidates.push_back(ptr);
  1478. if (!comp_node.valid()) {
  1479. comp_node = ptr->ptr->comp_node();
  1480. }
  1481. }
  1482. void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
  1483. // closing dtr just clears candidates, so there is nothing to erase
  1484. if (candidates.empty()) {
  1485. ptr->cand_index = UINT_MAX;
  1486. return;
  1487. }
  1488. // some tensors may be erased already, just skip them
  1489. if (ptr->cand_index != UINT_MAX) {
  1490. std::swap(candidates[ptr->cand_index], candidates.back());
  1491. candidates[ptr->cand_index]->cand_index = ptr->cand_index;
  1492. candidates.pop_back();
  1493. ptr->cand_index = UINT_MAX;
  1494. }
  1495. }
  1496. void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
  1497. ptr->last_used_time = estimate_timestamp;
  1498. }