
interpreter_impl.cpp

  1. /**
  2. * \file imperative/src/impl/interpreter/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. #include "range/v3/all.hpp"
  13. #include "megbrain/common.h"
  14. #include "megbrain/imperative/opr_utility.h"
  15. #include "megbrain/imperative/ops/autogen.h"
  16. #include "megbrain/imperative/ops/backward_graph.h"
  17. #include "megbrain/imperative/ops/opr_attr.h"
  18. #include "megbrain/imperative/ops/utility.h"
  19. #include "megbrain/imperative/utils/stats.h"
  20. #include "megbrain/imperative/utils/to_string.h"
  21. #include "../blob_manager_impl.h"
  22. #include "../event_pool.h"
  23. #include "../op_trait.h"
  24. using namespace mgb;
  25. using namespace imperative;
  26. using namespace interpreter;
  27. using namespace interpreter::intl;
  28. namespace {
  29. auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
  30. SmallVector<uint64_t> tid;
  31. for (auto* ptinfo : tinfo) {
  32. tid.push_back(ptinfo->id);
  33. }
  34. return tid;
  35. };
  36. } // namespace
  37. namespace mgb {
  38. using namespace profiler;
  39. }
  40. #if defined(_WIN32) || defined(_WIN64)
  41. #define SYMBOL_EXPORT __declspec(dllexport)
  42. #else
  43. #define SYMBOL_EXPORT __attribute__((visibility("default")))
  44. #endif
  45. namespace mgb {
  46. /**
  47. * USAGE
  48. *
  49. * header:
  50. * namespace mgb { void imperative_log_profile(const char* message); }
  51. *
  52. * code:
  53. * mgb::imperative_log_profile("MY MESSAGE");
  54. *
  55. **/
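/**
 * Note: the overloads further below that also take a `device` argument
 * additionally record a device timing event on the CompNode loaded from that
 * string, so a scope can be profiled against a specific device. A minimal
 * sketch (any string accepted by CompNode::load works as the device argument):
 *
 * mgb::imperative_log_profile_begin("MY MESSAGE", "xpu0");
 * mgb::imperative_log_profile_end("MY MESSAGE", "xpu0");
 **/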
  56. SYMBOL_EXPORT
  57. void imperative_log_profile_begin(const char* message) {
  58. MGB_RECORD_EVENT(CustomEvent, std::string{message});
  59. }
  60. SYMBOL_EXPORT
  61. void imperative_log_profile_end(const char* message) {
  62. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
  63. }
  64. SYMBOL_EXPORT
  65. void imperative_log_profile(const char* message) {
  66. imperative_log_profile_begin(message);
  67. imperative_log_profile_end(message);
  68. }
  69. SYMBOL_EXPORT
  70. void imperative_log_profile_begin(const char* message, const char* device) {
  71. auto comp_node = CompNode::load(device);
  72. MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
  73. MGB_RECORD_EVENT(
  74. RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  75. }
  76. SYMBOL_EXPORT
  77. void imperative_log_profile_end(const char* message, const char* device) {
  78. auto comp_node = CompNode::load(device);
  79. MGB_RECORD_EVENT(
  80. RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  81. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
  82. }
  83. } // namespace mgb
  84. std::thread::id ChannelImpl::get_worker_tid() {
  85. return m_worker_state.tid;
  86. }
  87. ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
  88. assert_in_channel();
  89. return m_channel_state;
  90. }
  91. ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
  92. assert_in_worker();
  93. return m_worker_state;
  94. }
  95. void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
  96. sys::set_thread_name("worker");
  97. m_owner->m_worker_state.tid = std::this_thread::get_id();
  98. auto custom_allocator = [&](CompNode device, size_t size) {
  99. auto blob = Blob::make(device, size);
  100. m_owner->alloc_tensor_with_evict(blob.get());
  101. return blob->storage();
  102. };
  103. OpDef::set_allocator(custom_allocator);
  104. }
  105. // Do not use m_xxx_state directly
  106. #define m_channel_state
  107. #define m_worker_state
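// The two empty #defines above expand m_channel_state / m_worker_state to
// nothing, so direct uses of those members below this point fail to compile;
// the rest of the file has to go through get_channel_state() /
// get_worker_state(), which assert that the calling thread is correct.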
  108. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  109. return std::make_unique<ChannelImpl>();
  110. }
  111. Interpreter& Interpreter::inst() {
  112. static InterpreterImpl inst_;
  113. return inst_;
  114. }
  115. Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
  116. MGB_LOCK_GUARD(m_spin);
  117. mgb_assert(check_available(), "Channel already closed");
  118. std::optional<StackManager::Guard> guard;
  119. if (Profiler::is_profiling()) {
  120. auto& state = get_channel_state();
  121. guard.emplace("Put", &state.stack_manager);
  122. }
  123. auto info = put_impl(value, no_cache);
  124. return reinterpret_cast<Handle>(info);
  125. }
  126. TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
  127. if (value.empty()) {
  128. auto layout = value.layout();
  129. layout.init_contiguous_stride();
  130. const_cast<HostTensorND&>(value).reset(value.storage(), layout);
  131. }
  132. auto info = alloc();
  133. constexpr int size_threshold = TensorShape::MAX_NDIM;
  134. init(info, {value.layout(), value.comp_node()});
  135. if (value.layout().total_nr_elems() <= size_threshold) {
  136. info->h_value = value;
  137. info->desc.value = value.proxy_to_default_cpu();
  138. }
  139. if (Profiler::is_profiling()) {
  140. m_worker.add_task(
  141. {Profiler::next_id(), Put{info, value, no_cache},
  142. get_channel_state().stack_manager.dump()});
  143. } else {
  144. m_worker.add_task({
  145. Profiler::next_id(),
  146. Put{info, value, no_cache},
  147. });
  148. }
  149. if (m_async_level == 0) {
  150. sync_impl();
  151. info->desc.comp_node.sync();
  152. auto err = info->desc.comp_node.check_async_error();
  153. mgb_assert(!err, "%s", err->what());
  154. }
  155. return info;
  156. }
  157. Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
  158. MGB_LOCK_GUARD(m_spin);
  159. mgb_assert(check_available(), "Channel already closed");
  160. return reinterpret_cast<Handle>(put_impl(data, hvalue));
  161. }
  162. TensorInfo* ChannelImpl::put_impl(
  163. const DeviceTensorND& data, const HostTensorND& hvalue) {
  164. std::optional<StackManager::Guard> guard;
  165. if (Profiler::is_profiling()) {
  166. auto& state = get_channel_state();
  167. guard.emplace("Put", &state.stack_manager);
  168. }
  169. auto info = alloc();
  170. MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
  171. constexpr int size_threshold = TensorShape::MAX_NDIM;
  172. init(info, {data.layout(), data.comp_node()});
  173. if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
  174. info->desc.value = hvalue.proxy_to_default_cpu();
  175. }
  176. info->ptr = Tensor::make(data, hvalue);
  177. MGB_RECORD_EVENT(
  178. TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
  179. data.raw_ptr());
  180. info->status = TensorInfo::Produced;
  181. MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
  182. return info;
  183. }
  184. void ChannelImpl::del(Handle handle) {
  185. MGB_LOCK_GUARD(m_spin);
  186. if (!check_available()) {
  187. return;
  188. }
  189. del_impl(handle);
  190. }
  191. void ChannelImpl::del_impl(Handle handle) {
  192. mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
  193. auto* info = reinterpret_cast<TensorInfo*>(handle);
  194. m_valid_handle.erase(handle);
  195. if (Profiler::is_profiling()) {
  196. m_worker.add_task(
  197. {Profiler::next_id(), Del{info},
  198. get_channel_state().stack_manager.dump()});
  199. } else {
  200. m_worker.add_task({
  201. Profiler::next_id(),
  202. Del{info},
  203. });
  204. }
  205. }
  206. void ChannelImpl::drop(Handle handle) {
  207. MGB_LOCK_GUARD(m_spin);
  208. mgb_assert(check_available(), "Channel already closed");
  209. auto& state = get_channel_state();
  210. if (state.options.enable_drop) {
  211. mgb_assert(
  212. m_valid_handle.find(handle) != m_valid_handle.end(),
  213. "invalid handle: %p", handle);
  214. auto* info = reinterpret_cast<TensorInfo*>(handle);
  215. if (Profiler::is_profiling()) {
  216. m_worker.add_task(
  217. {Profiler::next_id(), Drop{info},
  218. get_channel_state().stack_manager.dump()});
  219. } else {
  220. m_worker.add_task({
  221. Profiler::next_id(),
  222. Drop{info},
  223. });
  224. }
  225. }
  226. }
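// dispatch_default_cpu runs the op eagerly in the calling thread: inputs are
// proxied to the default CPU comp node (from the already-fetched device value
// or the cached h_value), the op is applied immediately, and the results are
// registered through put_impl. dispatch_kernel below instead allocates output
// TensorInfos up front and enqueues an ApplyOp command on the worker.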
  227. void ChannelImpl::dispatch_default_cpu(
  228. std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
  229. const SmallVector<LogicalTensorDesc>& input_descs,
  230. SmallVector<Handle>* outputs) {
  231. auto& state = get_channel_state();
  232. std::optional<StackManager::Guard> guard;
  233. if (Profiler::is_profiling()) {
  234. guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
  235. }
  236. auto [output_descs, validated] =
  237. OpDef::infer_output_attrs_fallible(*op, input_descs);
  238. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  239. SmallVector<DeviceTensorND> input_tensornds;
  240. CompNode output_cn;
  241. {
  242. MGB_LOCK_GUARD(m_mutex);
  243. for (auto&& info : input_infos) {
  244. auto input_cn = info->desc.comp_node;
  245. if (!output_cn.valid()) {
  246. output_cn = input_cn;
  247. } else {
  248. mgb_assert(output_cn == input_cn, "cannot decide output comp node");
  249. }
  250. if (info->ptr && info->ptr->try_get_value()) {
  251. input_tensornds.emplace_back(
  252. info->ptr->get_value().proxy_to_default_cpu());
  253. } else {
  254. // We assign h_value before dropping ptr
  255. mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
  256. input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
  257. }
  258. }
  259. }
  260. SmallVector<DeviceTensorND> output_tensornds;
  261. for (auto&& desc : output_descs) {
  262. // TODO: may conflict with condtake, which needs to alloc inside
  263. mgb_assert(!desc.layout.is_empty());
  264. // use HostTensorND alloc_host for cuda pinned memory
  265. output_tensornds.emplace_back(
  266. HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
  267. }
  268. uint64_t op_id = Profiler::next_id();
  269. if (op->trait()->apply_on_device_tensornd) {
  270. OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
  271. } else {
  272. // proxy to apply_on_physical_tensor
  273. SmallVector<TensorPtr> input_tensors;
  274. for (auto&& input_tensornd : input_tensornds) {
  275. input_tensors.push_back(Tensor::make(
  276. input_tensornd, HostTensorND::make_proxy(input_tensornd)));
  277. }
  278. auto output_tensors = OpDef::apply_on_physical_tensor(
  279. *op, input_tensors, output_descs, validated);
  280. for (size_t i = 0; i < output_tensors.size(); ++i) {
  281. output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
  282. }
  283. }
  284. SmallVector<TensorInfo*> output_infos;
  285. for (auto&& tensornd : output_tensornds) {
  286. HostTensorND host_tensornd =
  287. HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
  288. // use `put` for consistency
  289. auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
  290. mgb_assert(info->desc.layout.ndim != 0);
  291. output_infos.push_back(info);
  292. outputs->push_back(reinterpret_cast<Handle>(info));
  293. }
  294. auto op_info_getter = [op] {
  295. std::unordered_map<std::string, std::string> op_info;
  296. auto props = OpDef::props(*op);
  297. for (auto&& [key, value] : props) {
  298. op_info[key] = value;
  299. }
  300. return op_info;
  301. };
  302. MGB_RECORD_EVENT(
  303. OpDispatchEvent, op_id, guard.value().name(), op_info_getter,
  304. tinfo_to_tid(input_infos), tinfo_to_tid(output_infos),
  305. state.stack_manager.dump());
  306. }
  307. void ChannelImpl::dispatch_kernel(
  308. std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
  309. const SmallVector<LogicalTensorDesc>& input_descs,
  310. SmallVector<Handle>* outputs) {
  311. auto& state = get_channel_state();
  312. auto& options = state.options;
  313. std::optional<StackManager::Guard> guard;
  314. if (Profiler::is_profiling()) {
  315. guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
  316. }
  317. auto [output_descs, validated] =
  318. OpDef::infer_output_attrs_fallible(*op, input_descs);
  319. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  320. SmallVector<TensorInfo*> output_infos;
  321. output_infos.reserve(output_descs.size());
  322. outputs->reserve(output_descs.size());
  323. for (int i = 0; i < output_descs.size(); ++i) {
  324. auto&& desc = output_descs[i];
  325. auto info = alloc();
  326. init(info, std::move(desc));
  327. // make sure desc's value is consistent with h_value
  328. if (!info->desc.value.empty()) {
  329. info->h_value = HostTensorND::make_proxy(info->desc.value)
  330. .proxy_to_comp_node(info->desc.comp_node);
  331. }
  332. output_infos.push_back(info);
  333. outputs->push_back(reinterpret_cast<Handle>(info));
  334. }
  335. ApplyOp cmd{
  336. Profiler::next_id(), std::move(op), std::move(input_infos),
  337. std::move(output_infos), validated};
  338. if (Profiler::is_profiling()) {
  339. auto op_info_getter = [op = cmd.op] {
  340. std::unordered_map<std::string, std::string> op_info;
  341. auto props = OpDef::props(*op);
  342. for (auto&& [key, value] : props) {
  343. op_info[key] = value;
  344. }
  345. return op_info;
  346. };
  347. MGB_RECORD_EVENT(
  348. OpDispatchEvent, cmd.id, guard.value().name(), op_info_getter,
  349. tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs),
  350. state.stack_manager.dump());
  351. m_worker.add_task(
  352. {Profiler::next_id(), std::move(cmd),
  353. get_channel_state().stack_manager.dump()});
  354. } else {
  355. m_worker.add_task({
  356. Profiler::next_id(),
  357. std::move(cmd),
  358. });
  359. }
  360. if (!validated && options.async_level == 1) {
  361. sync_impl();
  362. } else if (options.async_level == 0) {
  363. sync_impl();
  364. // check device error
  365. for (auto&& oup : *outputs) {
  366. auto info = reinterpret_cast<TensorInfo*>(oup);
  367. info->ptr->comp_node().sync();
  368. auto err = info->ptr->comp_node().check_async_error();
  369. mgb_assert(!err, "%s", err->what());
  370. }
  371. }
  372. }
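// apply_op short-circuits full-shape GetVarShape when the input layout is
// already known: the shape is written into a host Int32 tensor and returned
// via put_impl without dispatching any work. Everything else goes through
// apply_op_impl, which validates the handles and picks a dispatch mode.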
  373. SmallVector<Handle> ChannelImpl::apply_op(
  374. std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
  375. MGB_LOCK_GUARD(m_spin);
  376. mgb_assert(check_available(), "Channel already closed");
  377. auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
  378. if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
  379. size_t ndim = input->desc.layout.ndim;
  380. auto& gvs = op->cast_final_safe<GetVarShape>();
  381. if (gvs.axis == MEGDNN_MAX_NDIM) {
  382. HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()};
  383. DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu();
  384. cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout);
  385. return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
  386. }
  387. }
  388. return apply_op_impl(std::move(op), inputs);
  389. }
  390. SmallVector<Handle> ChannelImpl::apply_op_impl(
  391. std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
  392. auto& state = get_channel_state();
  393. for (auto i : inputs) {
  394. mgb_assert(
  395. m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
  396. i);
  397. }
  398. SmallVector<TensorInfo*> input_infos;
  399. SmallVector<LogicalTensorDesc> input_descs;
  400. {
  401. MGB_LOCK_GUARD(m_info_spin);
  402. for (auto i : inputs) {
  403. auto info = reinterpret_cast<TensorInfo*>(i);
  404. mgb_assert(
  405. !info->invalid,
  406. "an input tensor is unusable due to previous error");
  407. input_infos.push_back(info);
  408. input_descs.push_back(info->desc);
  409. }
  410. }
  411. SmallVector<Handle> outputs;
  412. DispatchMode dispatch_mode = state.options.enable_host_compute
  413. ? OpDef::decide_dispatch_mode(*op, input_descs)
  414. : DispatchMode::KERNEL;
  415. switch (dispatch_mode) {
  416. case DEFAULT_CPU: {
  417. dispatch_default_cpu(op, input_infos, input_descs, &outputs);
  418. break;
  419. }
  420. case KERNEL: {
  421. dispatch_kernel(op, input_infos, input_descs, &outputs);
  422. break;
  423. }
  424. }
  425. return outputs;
  426. }
  427. HostTensorND ChannelImpl::get_value(Handle handle) {
  428. MGB_LOCK_GUARD(m_spin);
  429. mgb_assert(check_available(), "Channel already closed");
  430. mgb_assert(
  431. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  432. handle);
  433. auto info = reinterpret_cast<TensorInfo*>(handle);
  434. // do not use info->value_fetched, it's unsafe
  435. mgb_assert(!info->invalid, "tensor is unusable due to previous error");
  436. return wait_tensor(info, TensorProp::HostValue)->get_value();
  437. }
  438. TensorShape ChannelImpl::get_shape(Handle handle) {
  439. MGB_LOCK_GUARD(m_spin);
  440. mgb_assert(check_available(), "Channel already closed");
  441. mgb_assert(
  442. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  443. handle);
  444. auto info = reinterpret_cast<TensorInfo*>(handle);
  445. if (info->desc.layout.ndim != 0) {
  446. return info->desc.layout;
  447. }
  448. TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
  449. mgb_assert(ret.ndim != 0);
  450. return ret;
  451. }
  452. DType ChannelImpl::get_dtype(Handle handle) {
  453. MGB_LOCK_GUARD(m_spin);
  454. mgb_assert(check_available(), "Channel already closed");
  455. mgb_assert(
  456. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  457. handle);
  458. auto info = reinterpret_cast<TensorInfo*>(handle);
  459. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
  460. auto ret = info->desc.layout.dtype;
  461. mgb_assert(ret.valid());
  462. return ret;
  463. }
  464. CompNode ChannelImpl::get_device(Handle handle) {
  465. MGB_LOCK_GUARD(m_spin);
  466. mgb_assert(check_available(), "Channel already closed");
  467. mgb_assert(
  468. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  469. handle);
  470. auto info = reinterpret_cast<TensorInfo*>(handle);
  471. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
  472. auto ret = info->desc.comp_node;
  473. mgb_assert(ret.valid());
  474. return ret;
  475. }
  476. DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
  477. MGB_LOCK_GUARD(m_spin);
  478. mgb_assert(check_available(), "Channel already closed");
  479. mgb_assert(
  480. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  481. handle);
  482. auto info = reinterpret_cast<TensorInfo*>(handle);
  483. return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
  484. }
  485. void ChannelImpl::sync() {
  486. MGB_LOCK_GUARD(m_spin);
  487. mgb_assert(check_available(), "Channel already closed");
  488. sync_impl();
  489. }
  490. void ChannelImpl::sync_impl() {
  491. m_worker.wait_all_task_finish();
  492. MGB_LOCK_GUARD(m_mutex);
  493. check_worker_exc_unsafe();
  494. }
  495. void ChannelImpl::close() {
  496. MGB_LOCK_GUARD(m_spin);
  497. if (!check_available()) {
  498. return;
  499. }
  500. std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
  501. for (auto* handle : valid_handles) {
  502. del_impl(handle);
  503. }
  504. mgb_assert(m_valid_handle.empty());
  505. mgb_log_debug("%ld tensors existed before channel close", (long)valid_handles.size());
  506. sync_impl();
  507. m_closed = true;
  508. }
  509. size_t ChannelImpl::get_option(std::string name) {
  510. MGB_LOCK_GUARD(m_spin);
  511. mgb_assert(check_available(), "Channel already closed");
  512. auto& state = get_channel_state();
  513. return state.options.get_option(name);
  514. }
  515. void ChannelImpl::set_option(std::string name, size_t value) {
  516. MGB_LOCK_GUARD(m_spin);
  517. mgb_assert(check_available(), "Channel already closed");
  518. auto& state = get_channel_state();
  519. state.options.set_option(name, value);
  520. // FIXME
  521. if (name == "enable_dtr_auto_drop" && value) {
  522. auto custom_allocator = [&](CompNode device, size_t size) {
  523. auto blob = Blob::make(device, size);
  524. alloc_tensor_with_evict(blob.get());
  525. return blob->storage();
  526. };
  527. BlobManager::inst()->set_allocator(custom_allocator);
  528. }
  529. if (Profiler::is_profiling()) {
  530. m_worker.add_task(
  531. {Profiler::next_id(), SetOption{name, value},
  532. get_channel_state().stack_manager.dump()});
  533. } else {
  534. m_worker.add_task({
  535. Profiler::next_id(),
  536. SetOption{name, value},
  537. });
  538. }
  539. }
  540. void ChannelImpl::clear_candidates() {
  541. MGB_LOCK_GUARD(m_spin);
  542. mgb_assert(check_available(), "Channel already closed");
  543. m_dtr.candidates.clear();
  544. }
  545. TensorInfo* ChannelImpl::alloc() {
  546. auto& state = get_channel_state();
  547. auto info = [this] {
  548. MGB_LOCK_GUARD(m_pool_spin);
  549. return m_pool.alloc();
  550. }();
  551. info->id = Profiler::next_id();
  552. if (Profiler::is_profiling()) {
  553. size_t tensor_id = state.stack_manager.current()->next_id("tensor");
  554. info->name =
  555. state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
  556. }
  557. return info;
  558. }
  559. void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
  560. m_valid_handle.insert(reinterpret_cast<Handle>(info));
  561. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  562. info->status = TensorInfo::Allocated;
  563. info->desc = std::move(desc);
  564. }
  565. void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
  566. if (!ptr->producer) {
  567. if (user) {
  568. mgb_log_warn(
  569. "the input that produced tensor %p has been deleted, this drop "
  570. "operation will be ignored",
  571. ptr);
  572. }
  573. return;
  574. }
  575. if (ptr->evict_type != EvictType::NONE) {
  576. return;
  577. }
  578. ptr->evict_type = EvictType::DROP;
  579. ptr->status = TensorInfo::Dropped;
  580. release_tensor(ptr);
  581. }
  582. void ChannelImpl::free(TensorInfo* ptr) {
  583. auto& state = get_worker_state();
  584. if (state.options.enable_dtr_auto_drop) {
  585. // Evicting a tensor, rather than freeing it, avoids pinning a
  586. // potentially exploding amount of memory and allows us to save
  587. // more memory.
  588. ptr->allow_delete = true;
  589. if (!ptr->ref_cnt) {
  590. recursive_free(ptr);
  591. } else {
  592. do_drop(ptr);
  593. }
  594. } else {
  595. real_free(ptr);
  596. }
  597. }
  598. void ChannelImpl::recursive_free(TensorInfo* ptr) {
  599. MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
  600. SmallVector<TensorInfo*> inps;
  601. if (ptr->producer) {
  602. for (auto i : ptr->producer->inputs) {
  603. if (i && --i->ref_cnt == 0) {
  604. inps.push_back(i);
  605. }
  606. }
  607. }
  608. real_free(ptr);
  609. for (auto i : inps) {
  610. if (i->allow_delete) {
  611. recursive_free(i);
  612. }
  613. }
  614. MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
  615. }
  616. void ChannelImpl::real_free(TensorInfo* ptr) {
  617. auto& state = get_worker_state();
  618. if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  619. m_dtr.erase_candidate(ptr);
  620. }
  621. detach_users(ptr);
  622. ptr->detach_producer();
  623. bool has_value = ptr->ptr != nullptr;
  624. if (has_value) {
  625. MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
  626. }
  627. MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
  628. ptr->status = TensorInfo::Deleted;
  629. MGB_LOCK_GUARD(m_pool_spin);
  630. m_pool.free(ptr);
  631. }
  632. ChannelImpl::ChannelImpl() : m_worker(this) {}
  633. ChannelImpl::~ChannelImpl() {
  634. close();
  635. }
  636. void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
  637. auto& state = get_worker_state();
  638. MGB_LOCK_GUARD(m_mutex);
  639. m_dtr.update_used_time(dest);
  640. MGB_RECORD_EVENT(
  641. TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
  642. ptr->dev_tensor(false).raw_ptr());
  643. // update tensor desc for static infer
  644. if (dest->desc.layout.ndim) {
  645. mgb_assert(
  646. dest->desc.layout.eq_shape(ptr->layout()),
  647. "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
  648. ptr->layout().to_string().c_str());
  649. }
  650. // in order to avoid performance impact,
  651. // memory forwarding is disabled when DTR is enabled
  652. if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
  653. ptr->to_contiguous_inplace();
  654. }
  655. dest->desc.comp_node = ptr->comp_node();
  656. dest->memory = ptr->blob()->size();
  657. dest->ptr = std::move(ptr);
  658. dest->evict_type = EvictType::NONE;
  659. dest->status = TensorInfo::Produced;
  660. if (dest->pinned == 0 &&
  661. dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  662. m_dtr.insert_candidate(dest);
  663. }
  664. notify_tensor_unsafe(dest);
  665. }
  666. void ChannelImpl::release_tensor(TensorInfo* dest) {
  667. MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
  668. MGB_LOCK_GUARD(m_mutex);
  669. dest->ptr.reset();
  670. auto& state = get_worker_state();
  671. if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  672. m_dtr.erase_candidate(dest);
  673. }
  674. }
  675. void ChannelImpl::regenerate(TensorInfo* dest) {
  676. if (dest->evict_type == EvictType::DROP) {
  677. auto&& path = dest->producer;
  678. m_apply_stack.push(
  679. {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
  680. "dtr"});
  681. if (!m_applying)
  682. flush_apply_stack();
  683. }
  684. }
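// do_apply_op executes one ApplyOp on the worker thread. Its recursive
// apply_on_physical_tensor lambda expands ops that provide make_forward_graph
// into their sub-ops, enforces the input layout constraints (making inputs
// contiguous where required), and finally calls OpDef::apply_on_physical_tensor.
// Around the kernel it records profiling events, produces the output tensors,
// and, when DTR auto-drop is enabled, updates the estimated compute time and
// unpins the inputs.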
  685. void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
  686. using namespace ranges;
  687. using namespace ranges::views;
  688. auto& state = get_worker_state();
  689. bool profiling_device =
  690. Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
  691. uint64_t apply_id = cmd.id;
  692. SmallVector<TensorPtr> inputs;
  693. inputs.reserve(cmd.inputs.size());
  694. // refcnt == 1, owners: [TensorInfo::ptr]
  695. for (auto i : cmd.inputs) {
  696. mgb_assert(i->ptr, "Invalid input tensor ptr!");
  697. // refcnt ++, owners: [i->ptr, tensor_inputs]
  698. // tensor_inputs.push_back(i->ptr);
  699. inputs.push_back(i->ptr);
  700. }
  701. if (state.options.enable_dtr_auto_drop &&
  702. state.options.dtr_eviction_threshold > 0) {
  703. auto_evict(0);
  704. }
  705. auto apply_on_physical_tensor =
  706. [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
  707. SmallVector<LogicalTensorDesc>& output_descs,
  708. const bool& validated) -> SmallVector<TensorPtr> {
  709. if (def.trait()->make_forward_graph) {
  710. auto apply_functor = [&](std::shared_ptr<OpDef> op,
  711. SmallVector<TensorPtr> inputs,
  712. size_t nr_outputs) -> SmallVector<TensorPtr> {
  713. auto opname = op->trait()->make_name(*op);
  714. imperative_log_profile_begin(opname.c_str());
  715. auto outputs = self(self, *op, std::move(inputs), output_descs, false);
  716. imperative_log_profile_end(opname.c_str());
  717. return outputs;
  718. };
  719. auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
  720. // apply recursively
  721. SmallVector<LogicalTensorDesc> input_descs;
  722. for (auto&& input : inputs) {
  723. input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
  724. }
  725. auto forward_graph = OpDef::make_forward_graph(def, input_descs);
  726. auto outputs = forward_graph.apply<TensorPtr>(
  727. inputs, apply_functor, const_functor);
  728. return outputs;
  729. }
  730. // Check input layout:
  731. // get the input layout constraints, and if a constraint is not satisfied,
  732. // update the layout and blob in place to make the tensor contiguous
  733. auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
  734. for (size_t idx = 0; idx < inputs.size(); ++idx) {
  735. auto&& layout_checker = constraints[idx];
  736. if (layout_checker) {
  737. inputs[idx]->to_contiguous_inplace(layout_checker);
  738. }
  739. }
  740. return OpDef::apply_on_physical_tensor(
  741. def, std::move(inputs), output_descs, validated);
  742. };
  743. MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
  744. SmallVector<std::pair<CompNode, uint64_t>> kernels;
  745. if (profiling_device) {
  746. // Collecting devices
  747. SmallVector<CompNode> devices;
  748. for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
  749. if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
  750. devices.push_back(i->desc.comp_node);
  751. kernels.push_back({i->desc.comp_node, Profiler::next_id()});
  752. }
  753. }
  754. }
  755. for (auto* input : cmd.inputs) {
  756. auto input_id = input->id;
  757. MGB_RECORD_EVENT(OpInputEvent, input_id);
  758. MGB_RECORD_EVENT(TensorUsageEvent, input_id);
  759. MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
  760. }
  761. // Before wait
  762. // TODO: split operator wait and execute so that OpWait could be correctly recorded.
  763. // Before execute
  764. for (auto&& [device, kernel_id] : kernels) {
  765. MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
  766. MGB_RECORD_EVENT_IF(
  767. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  768. Timer::record_device(device));
  769. }
  770. // Apply op
  771. SmallVector<LogicalTensorDesc> output_descs;
  772. bool validated = cmd.validated;
  773. if (!state.options.enable_dtr_auto_drop) {
  774. for (auto i : cmd.outputs) {
  775. output_descs.push_back(i->desc);
  776. }
  777. } else {
  778. validated = false;
  779. }
  780. // Here std::move is REQUIRED for removing duplicated references.
  781. auto outputs = apply_on_physical_tensor(
  782. apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
  783. validated);
  784. // After execute
  785. for (auto&& [device, kernel_id] : kernels) {
  786. MGB_RECORD_EVENT_IF(
  787. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  788. Timer::record_device(device));
  789. MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
  790. }
  791. // End profiling operator
  792. mgb_assert(outputs.size() == cmd.outputs.size());
  793. for (size_t i = 0; i < outputs.size(); ++i) {
  794. auto output = cmd.outputs[i];
  795. if (mgb_unlikely(output == nullptr)) {
  796. MGB_RECORD_EVENT(OpOutputEvent, 0);
  797. MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
  798. } else if (mgb_unlikely(output->ptr != nullptr)) {
  799. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  800. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  801. } else {
  802. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  803. produce_tensor(output, outputs[i]);
  804. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  805. sample_on_device(output->desc.comp_node, false);
  806. }
  807. }
  808. if (state.options.enable_dtr_auto_drop) {
  809. double estimate_compute_time = 0;
  810. for (auto i : cmd.inputs) {
  811. estimate_compute_time += i->memory;
  812. }
  813. for (auto i : outputs) {
  814. estimate_compute_time += i->blob()->size();
  815. }
  816. m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
  817. for (auto i : cmd.outputs) {
  818. if (i != nullptr) {
  819. i->compute_time = estimate_compute_time;
  820. }
  821. }
  822. m_dtr.unpin(cmd.inputs, state);
  823. }
  824. MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
  825. // End profiling operator
  826. }
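// flush_apply_stack drains m_apply_stack without recursion: for the ApplyOp on
// top it first makes sure every input is resident, pushing a regeneration
// ApplyOp (via regenerate) for any input that was dropped, and only once all
// inputs are in memory does it pop the entry and run do_apply_op. Recomputed
// tensors then have their DSU cost bookkeeping updated.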
  827. void ChannelImpl::flush_apply_stack() {
  828. m_applying = true;
  829. auto& state = get_worker_state();
  830. while (!m_apply_stack.empty()) {
  831. auto& [cmd, idx, recomp, reason] =
  832. m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
  833. if (idx == 0) {
  834. if (state.options.enable_dtr_auto_drop) {
  835. m_dtr.pin(cmd.inputs);
  836. }
  837. if (recomp) {
  838. MGB_RECORD_EVENT(
  839. TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
  840. }
  841. }
  842. bool regen = false;
  843. for (size_t i = idx; i < cmd.inputs.size(); i++) {
  844. auto&& p = cmd.inputs[i];
  845. if (state.options.enable_dtr_auto_drop) {
  846. m_dtr.update_used_time(p);
  847. }
  848. if (!p->ptr && p->evict_type != EvictType::NONE) {
  849. idx = i + 1;
  850. regenerate(p); // add ApplyOp to the stack
  851. regen = true;
  852. break;
  853. }
  854. }
  855. if (regen)
  856. continue;
  857. // the required input tensors are already in memory
  858. auto [cmd_backup, recomp_backup, reason_backup] =
  859. std::make_tuple(cmd, recomp, reason);
  860. m_apply_stack.pop();
  861. do_apply_op(cmd_backup, reason_backup);
  862. if (recomp_backup) {
  863. MGB_RECORD_EVENT(
  864. TensorCommandFinishEvent, recomp_backup->id,
  865. TensorCommandKind::ReGen);
  866. for (auto o : cmd_backup.outputs) {
  867. if (o) {
  868. m_dtr.update_dsu_after_recompute(o);
  869. }
  870. }
  871. }
  872. }
  873. m_applying = false;
  874. }
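// auto_evict keeps dropping the best DTR candidate while the used memory of
// the DTR comp node exceeds dtr_eviction_threshold, or while force_num tensors
// still have to be evicted. It returns whether at least one eviction actually
// released memory (i.e. the evicted tensor held the only reference to its blob).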
  875. bool ChannelImpl::auto_evict(size_t force_num) {
  876. auto& state = get_worker_state();
  877. if (!m_dtr.comp_node.valid()) {
  878. return false;
  879. }
  880. size_t current_memory = m_dtr.comp_node.get_used_memory();
  881. size_t flag = false;
  882. while ((state.options.dtr_eviction_threshold > 0 &&
  883. current_memory > state.options.dtr_eviction_threshold) ||
  884. force_num > 0) {
  885. MGB_RECORD_EVENT(AutoEvictEvent);
  886. sample_on_device(m_dtr.comp_node, false);
  887. auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
  888. if (!best) {
  889. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  890. break;
  891. }
  892. if (best->ptr.unique() && best->ptr->blob().unique()) {
  893. current_memory -= best->memory;
  894. if (force_num > 0) {
  895. force_num--;
  896. }
  897. flag = true;
  898. }
  899. do_drop(best);
  900. if (best->evict_type == EvictType::DROP) {
  901. m_dtr.update_dsu_after_evict(best);
  902. }
  903. sample_on_device(m_dtr.comp_node, false);
  904. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  905. }
  906. return flag;
  907. }
  908. void ChannelImpl::detach_users(TensorInfo* dest) {
  909. SmallVector<TensorInfo::ComputePath*> users = dest->users;
  910. for (auto* user : users) {
  911. SmallVector<TensorInfo*> outputs = user->outputs;
  912. SmallVector<TensorInfo*> inputs = user->inputs;
  913. for (auto* output : outputs) {
  914. // When a `ComputePath` is detached from its input,
  915. // there is no need to preserve it,
  916. // so we detach all outputs of this path
  917. // to decrease its `ref_cnt` to zero.
  918. if (output == nullptr) {
  919. continue;
  920. }
  921. regenerate(output);
  922. output->detach_producer();
  923. for (auto* input : inputs) {
  924. input->ref_cnt--;
  925. }
  926. }
  927. // now user is dead
  928. }
  929. mgb_assert(dest->users.empty(), "ComputePath leaking");
  930. }
  931. bool ChannelImpl::check_available() {
  932. return !m_closed;
  933. }
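// wait_tensor blocks the calling (channel) thread until the worker has made
// the requested property of `info` available. For HostValue it first enqueues
// a GetValue command (unlocking m_mutex while doing so to avoid deadlock),
// then waits on m_cv, re-checking worker exceptions on each wakeup; when it
// had to wait for a host value it also checks the comp node's async error state.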
  934. TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
  935. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  936. mgb_assert(!m_waitee, "duplicate waitee");
  937. m_waitee = info;
  938. m_waitee_id = Profiler::next_id();
  939. MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
  940. bool require_host = prop == TensorProp::HostValue;
  941. auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
  942. bool wait_host = false;
  943. if (require_host && !host_available()) {
  944. // avoid deadlock
  945. lock.unlock();
  946. if (Profiler::is_profiling()) {
  947. m_worker.add_task(
  948. {Profiler::next_id(), GetValue{info},
  949. get_channel_state().stack_manager.dump()});
  950. } else {
  951. m_worker.add_task({
  952. Profiler::next_id(),
  953. GetValue{info},
  954. });
  955. }
  956. lock.lock();
  957. wait_host = true;
  958. }
  959. m_cv.wait(lock, [&]() {
  960. check_worker_exc_unsafe();
  961. return require_host ? host_available() : static_cast<bool>(info->ptr);
  962. });
  963. MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
  964. m_waitee = nullptr;
  965. if (wait_host) {
  966. auto err = info->ptr->comp_node().check_async_error();
  967. mgb_assert(!err, "%s", err->what());
  968. }
  969. return info->ptr;
  970. }
  971. void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
  972. if (info == m_waitee) {
  973. MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
  974. m_cv.notify_all();
  975. }
  976. }
  977. std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
  978. std::unordered_set<TensorInfo*> valid_tensors;
  979. for (auto* handle : m_valid_handle) {
  980. auto* info = reinterpret_cast<TensorInfo*>(handle);
  981. valid_tensors.insert(info);
  982. }
  983. return valid_tensors;
  984. }
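// alloc_tensor_with_evict is the allocation path used by the custom allocators
// installed above: it first tries a direct allocation, and on MemAllocError it
// evicts DTR candidates one by one and retries; if that still fails it defrags
// the comp node's memory before a final attempt. Logging is silenced during
// the retries.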
  985. void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
  986. bool in_worker = (get_worker_tid() == std::this_thread::get_id());
  987. auto reserve_size = [&](size_t size) {
  988. if (!m_dtr.comp_node.valid()) {
  989. return false;
  990. }
  991. while (size > m_dtr.comp_node.get_max_block_size_available()) {
  992. bool evict_suc = auto_evict(1);
  993. if (!evict_suc)
  994. return false;
  995. }
  996. return true;
  997. };
  998. auto pre_level = set_log_level(LogLevel::NO_LOG);
  999. if (in_worker) {
  1000. reserve_size(x->size());
  1001. }
  1002. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  1003. MGB_CATCH(MemAllocError&, {
  1004. bool suc = false;
  1005. if (in_worker) {
  1006. while (!suc) {
  1007. if (!auto_evict(1)) {
  1008. break;
  1009. }
  1010. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  1011. MGB_CATCH(MemAllocError&, { continue; });
  1012. suc = true;
  1013. }
  1014. }
  1015. if (!suc) {
  1016. set_log_level(pre_level);
  1017. mgb_log_warn(
  1018. "reallocating all cuda memory to alleviate fragmentation, the "
  1019. "performance may be affected");
  1020. set_log_level(LogLevel::NO_LOG);
  1021. imperative_log_profile_begin("defrag");
  1022. BlobManager::inst()->defrag(x->comp_node());
  1023. imperative_log_profile_end("defrag");
  1024. BlobManager::inst()->alloc_direct(x, x->size());
  1025. }
  1026. });
  1027. set_log_level(pre_level);
  1028. }
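// process_one_task is the worker-side dispatcher: it visits the Command
// variant (Put, ApplyOp, Del, GetValue, Drop, SetOption, Start/StopProfile,
// Push/PopScope) and executes it. When catch_worker_execption is enabled, any
// exception is stored in m_worker_exc, the affected outputs are marked
// invalid, and the waiting channel thread is notified.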
  1029. void ChannelImpl::process_one_task(Command& icmd) {
  1030. using namespace ranges;
  1031. using namespace ranges::views;
  1032. auto& state = get_worker_state();
  1033. auto& options = state.options;
  1034. // TODO: remove std::visit to support OSX 10.12
  1035. auto cmd_visitor = [&](const auto& cmd) {
  1036. using T = std::decay_t<decltype(cmd)>;
  1037. if constexpr (std::is_same_v<T, Put>) {
  1038. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
  1039. MGB_RECORD_EVENT_IF(
  1040. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1041. Timer::record_device(cmd.value.comp_node()));
  1042. auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
  1043. : Tensor::make(cmd.value);
  1044. MGB_RECORD_EVENT_IF(
  1045. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1046. Timer::record_device(cmd.value.comp_node()));
  1047. produce_tensor(cmd.dest, std::move(value));
  1048. MGB_RECORD_EVENT(
  1049. TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
  1050. sample_on_device(cmd.dest->desc.comp_node, false);
  1051. } else if constexpr (std::is_same_v<T, ApplyOp>) {
  1052. for (auto& i : cmd.inputs) {
  1053. if (mgb_unlikely(i->invalid)) {
  1054. MGB_LOCK_GUARD(m_mutex);
  1055. for (auto& i : cmd.outputs) {
  1056. i->invalid = true;
  1057. }
  1058. return;
  1059. }
  1060. }
  1061. if (state.options.enable_dtr_auto_drop) {
  1062. m_apply_stack.push({cmd, 0, nullptr, "cmd"});
  1063. flush_apply_stack();
  1064. for (size_t i = 0; i < cmd.outputs.size(); ++i) {
  1065. auto output = cmd.outputs[i];
  1066. if (output == nullptr) {
  1067. continue;
  1068. }
  1069. output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
  1070. }
  1071. } else {
  1072. do_apply_op(cmd, "cmd");
  1073. }
  1074. if (state.options.enable_drop && state.options.record_computing_path) {
  1075. auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
  1076. auto& input = std::get<0>(tuple2);
  1077. auto& output = std::get<1>(tuple2);
  1078. if (!input->ptr || !output->ptr) {
  1079. return false;
  1080. }
  1081. return input->ptr->blob()->storage() ==
  1082. output->ptr->blob()->storage();
  1083. };
  1084. // FIXME: do not use opname as identifier
  1085. auto get_name = [](const OpDef& opdef) {
  1086. if (auto attr = opdef.try_cast_final<OprAttr>()) {
  1087. return attr->type.c_str();
  1088. }
  1089. return opdef.dyn_typeinfo()->name;
  1090. };
  1091. auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
  1092. return info->desc.comp_node != comp_node;
  1093. };
  1094. bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
  1095. bool inplace =
  1096. any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);
  1097. if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
  1098. TensorInfo::ComputePath::make(
  1099. cmd.id, cmd.op, cmd.inputs, cmd.outputs);
  1100. size_t detach_cnt = 0;
  1101. if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
  1102. cmd.outputs.size() == 6) {
  1103. cmd.outputs[0]->detach_producer(); // detach running_mean
  1104. cmd.outputs[1]->detach_producer(); // detach running_var
  1105. for (auto input : cmd.inputs) {
  1106. input->ref_cnt -= 2;
  1107. }
  1108. }
  1109. for (auto output : cmd.outputs) {
  1110. if (output->producer &&
  1111. !output->size_exceeds_thd(
  1112. state.options.dtr_evictee_minimum_size)) {
  1113. output->detach_producer();
  1114. detach_cnt++;
  1115. }
  1116. }
  1117. for (auto input : cmd.inputs) {
  1118. input->ref_cnt -= detach_cnt;
  1119. }
  1120. }
  1121. }
  1122. } else if constexpr (std::is_same_v<T, Del>) {
  1123. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
  1124. CompNode device = cmd.dest->desc.comp_node;
  1125. uint64_t tensor_id = cmd.dest->id;
  1126. free(cmd.dest);
  1127. MGB_RECORD_EVENT(
  1128. TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
  1129. sample_on_device(device, false);
  1130. } else if constexpr (std::is_same_v<T, GetValue>) {
  1131. if (cmd.dest->invalid)
  1132. return;
  1133. imperative_log_profile_begin("GetValue");
  1134. if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
  1135. regenerate(cmd.dest);
  1136. }
  1137. cmd.dest->ptr->fetch_value();
  1138. MGB_LOCK_GUARD(m_mutex);
  1139. notify_tensor_unsafe(cmd.dest);
  1140. imperative_log_profile_end("GetValue");
  1141. } else if constexpr (std::is_same_v<T, Drop>) {
  1142. if (cmd.dest->invalid)
  1143. return;
  1144. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
  1145. do_drop(cmd.dest, true);
  1146. MGB_RECORD_EVENT(
  1147. TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
  1148. } else if constexpr (std::is_same_v<T, SetOption>) {
  1149. options.set_option(cmd.key, cmd.value);
  1150. } else if constexpr (std::is_same_v<T, StartProfile>) {
  1151. MGB_RECORD_EVENT(StartProfileEvent);
  1152. CompNode::sync_all();
  1153. for (auto* info : cmd.capture_tensors) {
  1154. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  1155. if (info->status == TensorInfo::Produced) {
  1156. // TODO: handle drop
  1157. MGB_RECORD_EVENT(
  1158. TensorProduceEvent, info->id, info->desc.layout,
  1159. info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
  1160. }
  1161. }
  1162. CompNode::foreach ([&](CompNode device) {
  1163. sample_on_device(device, true);
  1164. MGB_RECORD_EVENT_IF(
  1165. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1166. Timer::record_device(device));
  1167. });
  1168. MGB_RECORD_EVENT(StartProfileFinishEvent);
  1169. } else if constexpr (std::is_same_v<T, StopProfile>) {
  1170. MGB_RECORD_EVENT(StopProfileEvent);
  1171. for (auto* info : cmd.escape_tensors) {
  1172. bool has_value = info->status == TensorInfo::Produced;
  1173. if (has_value) {
  1174. MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
  1175. }
  1176. MGB_RECORD_EVENT(TensorEraseEvent, info->id);
  1177. }
  1178. CompNode::foreach (
  1179. [&](CompNode device) { sample_on_device(device, true); });
  1180. MGB_RECORD_EVENT(StopProfileFinishEvent);
  1181. } else if constexpr (std::is_same_v<T, PushScope>) {
  1182. MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
  1183. } else if constexpr (std::is_same_v<T, PopScope>) {
  1184. MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
  1185. } else {
  1186. static_assert(!std::is_same_v<T, T>);
  1187. }
  1188. };
  1189. std::visit(
  1190. [&](const auto& cmd) {
  1191. using T = std::decay_t<decltype(cmd)>;
  1192. if (!options.catch_worker_execption) {
  1193. cmd_visitor(cmd);
  1194. return;
  1195. }
  1196. try {
  1197. cmd_visitor(cmd);
  1198. } catch (...) {
  1199. MGB_LOCK_GUARD(m_mutex);
  1200. if constexpr (std::is_same_v<T, ApplyOp>) {
  1201. for (auto oup : cmd.outputs) {
  1202. oup->invalid = true;
  1203. }
  1204. } else if constexpr (std::is_same_v<T, Put>) {
  1205. cmd.dest->invalid = true;
  1206. }
  1207. m_worker_exc = std::current_exception();
  1208. MGB_RECORD_EVENT(WorkerExceptionEvent);
  1209. if (m_waitee) {
  1210. notify_tensor_unsafe(m_waitee);
  1211. }
  1212. }
  1213. },
  1214. icmd.data);
  1215. }
  1216. void ChannelImpl::check_worker_exc_unsafe() {
  1217. if (m_worker_exc) {
  1218. // allow reusing interpreter_for_py after some exception tests
  1219. m_waitee = nullptr;
  1220. std::exception_ptr exc;
  1221. std::swap(exc, m_worker_exc);
  1222. try {
  1223. std::rethrow_exception(exc);
  1224. } catch (...) {
  1225. throw AsyncError();
  1226. }
  1227. }
  1228. }
  1229. void ChannelImpl::start_profile() {
  1230. MGB_LOCK_GUARD(m_spin);
  1231. mgb_assert(check_available(), "Channel already closed");
  1232. auto capture_tensors = collect_valid_tensors();
  1233. if (capture_tensors.size() > 0) {
  1234. if (Profiler::is_profiling()) {
  1235. m_worker.add_task(
  1236. {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
  1237. get_channel_state().stack_manager.dump()});
  1238. } else {
  1239. m_worker.add_task({
  1240. Profiler::next_id(),
  1241. StartProfile{std::move(capture_tensors)},
  1242. });
  1243. }
  1244. }
  1245. }
  1246. void ChannelImpl::stop_profile() {
  1247. MGB_LOCK_GUARD(m_spin);
  1248. mgb_assert(check_available(), "Channel already closed");
  1249. auto escape_tensors = collect_valid_tensors();
  1250. if (escape_tensors.size() > 0) {
  1251. if (Profiler::is_profiling()) {
  1252. m_worker.add_task(
  1253. {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
  1254. get_channel_state().stack_manager.dump()});
  1255. } else {
  1256. m_worker.add_task({
  1257. Profiler::next_id(),
  1258. StopProfile{std::move(escape_tensors)},
  1259. });
  1260. }
  1261. }
  1262. }
  1263. void ChannelImpl::push_scope(std::string name) {
  1264. MGB_LOCK_GUARD(m_spin);
  1265. mgb_assert(check_available(), "Channel already closed");
  1266. auto& state = get_channel_state();
  1267. state.stack_manager.enter(name);
  1268. MGB_RECORD_EVENT(ScopeEvent, name);
  1269. if (Profiler::is_profiling()) {
  1270. m_worker.add_task(
  1271. {Profiler::next_id(), PushScope{name},
  1272. get_channel_state().stack_manager.dump()});
  1273. } else {
  1274. m_worker.add_task({
  1275. Profiler::next_id(),
  1276. PushScope{name},
  1277. });
  1278. }
  1279. }
  1280. void ChannelImpl::pop_scope(std::string name) {
  1281. MGB_LOCK_GUARD(m_spin);
  1282. mgb_assert(check_available(), "Channel already closed");
  1283. auto& state = get_channel_state();
  1284. state.stack_manager.exit(name);
  1285. MGB_RECORD_EVENT(ScopeFinishEvent, name);
  1286. if (Profiler::is_profiling()) {
  1287. m_worker.add_task(
  1288. {Profiler::next_id(), PopScope{name},
  1289. get_channel_state().stack_manager.dump()});
  1290. } else {
  1291. m_worker.add_task({
  1292. Profiler::next_id(),
  1293. PopScope{name},
  1294. });
  1295. }
  1296. }
  1297. void ChannelImpl::assert_in_channel() {
  1298. mgb_assert(
  1299. get_worker_tid() != std::this_thread::get_id(),
  1300. "this method cannot be called in worker thread");
  1301. }
  1302. void ChannelImpl::assert_in_worker() {
  1303. mgb_assert(
  1304. get_worker_tid() == std::this_thread::get_id(),
  1305. "this method can only be called in worker thread");
  1306. }
  1307. void ChannelImpl::sample_on_device(CompNode device, bool force) {
  1308. if (!Profiler::is_profiling()) {
  1309. return;
  1310. }
  1311. if (!force) {
  1312. thread_local int last_sample_id = 0;
  1313. int sample_rate = Profiler::get_option("sample_rate", 0);
  1314. if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
  1315. return;
  1316. }
  1317. }
  1318. MGB_RECORD_EVENT(SampleDeviceEvent, device);
  1319. auto [total, free] = device.get_mem_status_bytes();
  1320. MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
  1321. }
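// DynamicSublinear is the DTR bookkeeping attached to the channel:
// `candidates` holds evictable tensors (pinned tensors are removed and
// re-inserted on unpin), and each dropped tensor's DsuNode accumulates,
// through the disjoint-set union below, the recompute cost of the connected
// region of dropped tensors, which estimate_neighbor_cost / find_best_tensor
// use to pick what to evict next.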
  1322. void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
  1323. for (auto i : vec) {
  1324. i->pin();
  1325. erase_candidate(i);
  1326. }
  1327. }
  1328. void ChannelImpl::DynamicSublinear::unpin(
  1329. const SmallVector<TensorInfo*>& vec, WorkerState& state) {
  1330. for (auto i : vec) {
  1331. i->unpin();
  1332. if (i->pinned == 0 &&
  1333. i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
  1334. i->cand_index == UINT_MAX) {
  1335. insert_candidate(i);
  1336. }
  1337. }
  1338. }
  1339. void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
  1340. auto&& dsu_fa = find_father(ptr->dsu_ptr);
  1341. dsu_fa->t -= ptr->compute_time;
  1342. ptr->dsu_ptr->parent.reset();
  1343. ptr->dsu_ptr->t = ptr->compute_time;
  1344. }
  1345. void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
  1346. for (auto i : ptr->producer->inputs) {
  1347. if (i->evict_type == EvictType::DROP) {
  1348. merge(i->dsu_ptr, ptr->dsu_ptr);
  1349. }
  1350. }
  1351. for (auto i : ptr->producer->outputs) {
  1352. if (i && i->evict_type == EvictType::DROP) {
  1353. merge(ptr->dsu_ptr, i->dsu_ptr);
  1354. }
  1355. }
  1356. }
  1357. double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
  1358. double cost = 0;
  1359. for (auto i : ptr->producer->inputs) {
  1360. if (i->evict_type == EvictType::DROP) {
  1361. double t = find_father(i->dsu_ptr)->t;
  1362. if (t < i->compute_time) {
  1363. t = i->compute_time;
  1364. }
  1365. cost += t;
  1366. }
  1367. }
  1368. for (auto i : ptr->producer->outputs) {
  1369. if (i && i->evict_type == EvictType::DROP) {
  1370. double t = find_father(i->dsu_ptr)->t;
  1371. if (t < i->compute_time) {
  1372. t = i->compute_time;
  1373. }
  1374. cost += t;
  1375. }
  1376. }
  1377. return cost;
  1378. }
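// find_best_tensor scores candidates with eval_func, combining the estimated
// recompute cost of the surrounding dropped region, the free memory adjacent
// to the candidate's blob, and a staleness term based on the current estimated
// timestamp, and returns the candidate with the smallest score. With
// enable_dtr_sqrt_sampling it only inspects roughly sqrt(#candidates) randomly
// chosen entries instead of all of them.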
  1379. TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
  1380. bool enable_dtr_sqrt_sampling = false) {
  1381. if (candidates.empty())
  1382. return nullptr;
  1383. double min_msps = -1;
  1384. TensorInfo* best = nullptr;
  1385. size_t sz = 1;
  1386. if (enable_dtr_sqrt_sampling) {
  1387. while (sz * sz <= candidates.size())
  1388. sz++;
  1389. sz--;
  1390. } else {
  1391. sz = candidates.size();
  1392. }
  1393. size_t ti = rand() % sz;
  1394. for (size_t vi = 0; vi < sz; vi++) {
  1395. if (!enable_dtr_sqrt_sampling) {
  1396. ti = vi;
  1397. }
  1398. auto i = candidates[ti];
  1399. if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
  1400. double neighbor_cost = estimate_neighbor_cost(i);
  1401. size_t begin_ptr =
  1402. reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
  1403. auto side_info = i->ptr->comp_node().get_free_left_and_right(
  1404. begin_ptr, begin_ptr + i->ptr->blob()->size());
  1405. double free_mem = side_info.first + side_info.second;
  1406. double msps = i->eval_func(
  1407. neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
  1408. if (min_msps < 0 || msps < min_msps) {
  1409. min_msps = msps;
  1410. best = i;
  1411. }
  1412. }
  1413. if (enable_dtr_sqrt_sampling) {
  1414. ti += rand() % sz;
  1415. if (ti > candidates.size())
  1416. break;
  1417. }
  1418. }
  1419. return best;
  1420. }
  1421. void ChannelImpl::DynamicSublinear::merge(
  1422. std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
  1423. auto&& f_x = find_father(x);
  1424. auto&& f_y = find_father(y);
  1425. if (f_x.get() == f_y.get()) {
  1426. return;
  1427. }
  1428. f_y->t += f_x->t;
  1429. f_x->parent = f_y;
  1430. }
  1431. std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
  1432. std::shared_ptr<DsuNode>& x) {
  1433. if (x->is_root()) {
  1434. return x;
  1435. } else {
  1436. auto&& fa = find_father(x->parent);
  1437. return x->parent = fa;
  1438. }
  1439. }
  1440. void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
  1441. // tensor to be inserted must be brand new
  1442. mgb_assert(
  1443. ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
  1444. ptr->cand_index);
  1445. ptr->cand_index = candidates.size();
  1446. candidates.push_back(ptr);
  1447. if (!comp_node.valid()) {
  1448. comp_node = ptr->ptr->comp_node();
  1449. }
  1450. }
  1451. void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
  1452. // closing dtr just clears candidates, so there is nothing to erase
  1453. if (candidates.empty()) {
  1454. ptr->cand_index = UINT_MAX;
  1455. return;
  1456. }
  1457. // some tensors may be erased already, just skip them
  1458. if (ptr->cand_index != UINT_MAX) {
  1459. std::swap(candidates[ptr->cand_index], candidates.back());
  1460. candidates[ptr->cand_index]->cand_index = ptr->cand_index;
  1461. candidates.pop_back();
  1462. ptr->cand_index = UINT_MAX;
  1463. }
  1464. }
  1465. void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
  1466. ptr->last_used_time = estimate_timestamp;
  1467. }