
interpreter_impl.cpp 56 kB

  1. /**
  2. * \file imperative/src/impl/interpreter/interpreter_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./interpreter_impl.h"
  12. #include "range/v3/all.hpp"
  13. #include "megbrain/common.h"
  14. #include "megbrain/imperative/opr_utility.h"
  15. #include "megbrain/imperative/ops/autogen.h"
  16. #include "megbrain/imperative/ops/backward_graph.h"
  17. #include "megbrain/imperative/ops/opr_attr.h"
  18. #include "megbrain/imperative/ops/utility.h"
  19. #include "megbrain/imperative/utils/to_string.h"
  20. #include "../blob_manager_impl.h"
  21. #include "../event_pool.h"
  22. #include "../op_trait.h"
  23. using namespace mgb;
  24. using namespace imperative;
  25. using namespace interpreter;
  26. using namespace interpreter::intl;
  27. namespace {
  28. auto tinfo_to_tid(SmallVector<TensorInfo*> tinfo) {
  29. SmallVector<uint64_t> tid;
  30. for (auto* ptinfo : tinfo) {
  31. tid.push_back(ptinfo->id);
  32. }
  33. return tid;
  34. };
  35. } // namespace
  36. namespace mgb {
  37. using namespace profiler;
  38. }
  39. #if defined(_WIN32) || defined(_WIN64)
  40. #define SYMBOL_EXPORT __declspec(dllexport)
  41. #else
  42. #define SYMBOL_EXPORT __attribute__((visibility("default")))
  43. #endif
  44. namespace mgb {
  45. /**
  46. * USAGE
  47. *
  48. * header:
  49. * namespace mgb { void imperative_log_profile(const char* message); }
  50. *
  51. * code:
  52. * mgb::imperative_log_profile("MY MESSAGE");
  53. *
  54. **/
  55. SYMBOL_EXPORT
  56. void imperative_log_profile_begin(const char* message) {
  57. MGB_RECORD_EVENT(CustomEvent, std::string{message});
  58. }
  59. SYMBOL_EXPORT
  60. void imperative_log_profile_end(const char* message) {
  61. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message});
  62. }
  63. SYMBOL_EXPORT
  64. void imperative_log_profile(const char* message) {
  65. imperative_log_profile_begin(message);
  66. imperative_log_profile_end(message);
  67. }
  68. SYMBOL_EXPORT
  69. void imperative_log_profile_begin(const char* message, const char* device) {
  70. auto comp_node = CompNode::load(device);
  71. MGB_RECORD_EVENT(CustomEvent, std::string{message}, {}, comp_node);
  72. MGB_RECORD_EVENT(
  73. RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  74. }
  75. SYMBOL_EXPORT
  76. void imperative_log_profile_end(const char* message, const char* device) {
  77. auto comp_node = CompNode::load(device);
  78. MGB_RECORD_EVENT(
  79. RecordDeviceEvent, EventPool::with_timer().alloc_shared(comp_node));
  80. MGB_RECORD_EVENT(CustomFinishEvent, std::string{message}, {}, comp_node);
  81. }
  82. } // namespace mgb
  83. std::thread::id ChannelImpl::get_worker_tid() {
  84. return m_worker_state.tid;
  85. }
  86. ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
  87. assert_in_channel();
  88. return m_channel_state;
  89. }
  90. ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
  91. assert_in_worker();
  92. return m_worker_state;
  93. }
  94. void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
  95. sys::set_thread_name("worker");
  96. m_owner->m_worker_state.tid = std::this_thread::get_id();
  97. auto custom_allocator = [&](CompNode device, size_t size) {
  98. auto blob = Blob::make(device, size);
  99. m_owner->alloc_tensor_with_evict(blob.get());
  100. return blob->storage();
  101. };
  102. OpDef::set_allocator(custom_allocator);
  103. }
  104. // Do not use m_channel_state / m_worker_state directly below this point; the empty macro definitions poison direct access so that get_channel_state() / get_worker_state() must be used instead.
  105. #define m_channel_state
  106. #define m_worker_state
  107. std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
  108. return std::make_unique<ChannelImpl>();
  109. }
  110. Interpreter& Interpreter::inst() {
  111. static InterpreterImpl inst_;
  112. return inst_;
  113. }
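// put(HostTensorND): registers a new TensorInfo for a host value and enqueues a
// Put command on the worker thread; when m_async_level == 0 it synchronizes
// immediately and surfaces any async device error.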
  114. Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
  115. MGB_LOCK_GUARD(m_spin);
  116. mgb_assert(check_available(), "Channel already closed");
  117. std::optional<StackManager::Guard> guard;
  118. if (Profiler::is_profiling()) {
  119. auto& state = get_channel_state();
  120. guard.emplace("Put", &state.stack_manager);
  121. }
  122. auto info = put_impl(value, no_cache);
  123. return reinterpret_cast<Handle>(info);
  124. }
  125. TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
  126. if (value.empty()) {
  127. auto layout = value.layout();
  128. layout.init_contiguous_stride();
  129. const_cast<HostTensorND&>(value).reset(value.storage(), layout);
  130. }
  131. auto info = alloc();
  132. constexpr int size_threshold = TensorShape::MAX_NDIM;
  133. init(info, {value.layout(), value.comp_node()});
  134. if (value.layout().total_nr_elems() <= size_threshold) {
  135. info->h_value = value;
  136. info->desc.value = value.proxy_to_default_cpu();
  137. }
  138. if (Profiler::is_profiling()) {
  139. m_worker.add_task(
  140. {Profiler::next_id(), Put{info, value, no_cache},
  141. get_channel_state().stack_manager.dump()});
  142. } else {
  143. m_worker.add_task({
  144. Profiler::next_id(),
  145. Put{info, value, no_cache},
  146. });
  147. }
  148. if (m_async_level == 0) {
  149. sync_impl();
  150. info->desc.comp_node.sync();
  151. auto err = info->desc.comp_node.check_async_error();
  152. mgb_assert(!err, "%s", err->what());
  153. }
  154. return info;
  155. }
  156. Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
  157. MGB_LOCK_GUARD(m_spin);
  158. mgb_assert(check_available(), "Channel already closed");
  159. return reinterpret_cast<Handle>(put_impl(data, hvalue));
  160. }
  161. TensorInfo* ChannelImpl::put_impl(
  162. const DeviceTensorND& data, const HostTensorND& hvalue) {
  163. std::optional<StackManager::Guard> guard;
  164. if (Profiler::is_profiling()) {
  165. auto& state = get_channel_state();
  166. guard.emplace("Put", &state.stack_manager);
  167. }
  168. auto info = alloc();
  169. MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
  170. constexpr int size_threshold = TensorShape::MAX_NDIM;
  171. init(info, {data.layout(), data.comp_node()});
  172. if ((!hvalue.empty()) && info->desc.layout.total_nr_elems() <= size_threshold) {
  173. info->desc.value = hvalue.proxy_to_default_cpu();
  174. }
  175. info->ptr = Tensor::make(data, hvalue);
  176. MGB_RECORD_EVENT(
  177. TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
  178. data.raw_ptr());
  179. info->status = TensorInfo::Produced;
  180. MGB_RECORD_EVENT(TensorCommandFinishEvent, info->id, TensorCommandKind::Put);
  181. return info;
  182. }
  183. void ChannelImpl::del(Handle handle) {
  184. MGB_LOCK_GUARD(m_spin);
  185. if (!check_available()) {
  186. return;
  187. }
  188. del_impl(handle);
  189. }
  190. void ChannelImpl::del_impl(Handle handle) {
  191. mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
  192. auto* info = reinterpret_cast<TensorInfo*>(handle);
  193. m_valid_handle.erase(handle);
  194. if (Profiler::is_profiling()) {
  195. m_worker.add_task(
  196. {Profiler::next_id(), Del{info},
  197. get_channel_state().stack_manager.dump()});
  198. } else {
  199. m_worker.add_task({
  200. Profiler::next_id(),
  201. Del{info},
  202. });
  203. }
  204. }
  205. void ChannelImpl::drop(Handle handle) {
  206. MGB_LOCK_GUARD(m_spin);
  207. mgb_assert(check_available(), "Channel already closed");
  208. auto& state = get_channel_state();
  209. if (state.options.enable_drop) {
  210. mgb_assert(
  211. m_valid_handle.find(handle) != m_valid_handle.end(),
  212. "invalid handle: %p", handle);
  213. auto* info = reinterpret_cast<TensorInfo*>(handle);
  214. if (Profiler::is_profiling()) {
  215. m_worker.add_task(
  216. {Profiler::next_id(), Drop{info},
  217. get_channel_state().stack_manager.dump()});
  218. } else {
  219. m_worker.add_task({
  220. Profiler::next_id(),
  221. Drop{info},
  222. });
  223. }
  224. }
  225. }
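// dispatch_default_cpu: executes the op eagerly on the default CPU device using
// the host values of the inputs, then re-registers the results through put_impl
// so they behave like ordinary channel tensors; taken when host compute is
// enabled and the op supports the DEFAULT_CPU dispatch mode.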
  226. void ChannelImpl::dispatch_default_cpu(
  227. std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
  228. const SmallVector<LogicalTensorDesc>& input_descs,
  229. SmallVector<Handle>* outputs) {
  230. auto& state = get_channel_state();
  231. std::optional<StackManager::Guard> guard;
  232. if (Profiler::is_profiling()) {
  233. guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
  234. }
  235. auto [output_descs, validated] =
  236. OpDef::infer_output_attrs_fallible(*op, input_descs);
  237. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  238. SmallVector<DeviceTensorND> input_tensornds;
  239. CompNode output_cn;
  240. {
  241. MGB_LOCK_GUARD(m_mutex);
  242. for (auto&& info : input_infos) {
  243. auto input_cn = info->desc.comp_node;
  244. if (!output_cn.valid()) {
  245. output_cn = input_cn;
  246. } else {
  247. mgb_assert(output_cn == input_cn, "cannot decide output comp node");
  248. }
  249. if (info->ptr && info->ptr->try_get_value()) {
  250. input_tensornds.emplace_back(
  251. info->ptr->get_value().proxy_to_default_cpu());
  252. } else {
  253. // we assign h_value before dropping ptr
  254. mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
  255. input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
  256. }
  257. }
  258. }
  259. SmallVector<DeviceTensorND> output_tensornds;
  260. for (auto&& desc : output_descs) {
  261. // TODO: may conflict with condtake, which needs alloc inside
  262. mgb_assert(!desc.layout.is_empty());
  263. // use HostTensorND alloc_host for cuda pinned memory
  264. output_tensornds.emplace_back(
  265. HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
  266. }
  267. uint64_t op_id = Profiler::next_id();
  268. if (op->trait()->apply_on_device_tensornd) {
  269. OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
  270. } else {
  271. // proxy to apply_on_physical_tensor
  272. SmallVector<TensorPtr> input_tensors;
  273. for (auto&& input_tensornd : input_tensornds) {
  274. input_tensors.push_back(Tensor::make(
  275. input_tensornd, HostTensorND::make_proxy(input_tensornd)));
  276. }
  277. auto output_tensors = OpDef::apply_on_physical_tensor(
  278. *op, input_tensors, output_descs, validated);
  279. for (size_t i = 0; i < output_tensors.size(); ++i) {
  280. output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
  281. }
  282. }
  283. SmallVector<TensorInfo*> output_infos;
  284. for (auto&& tensornd : output_tensornds) {
  285. HostTensorND host_tensornd =
  286. HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
  287. // use `put` for consistency
  288. auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
  289. mgb_assert(info->desc.layout.ndim != 0);
  290. output_infos.push_back(info);
  291. outputs->push_back(reinterpret_cast<Handle>(info));
  292. }
  293. auto op_info_getter = [op] {
  294. std::unordered_map<std::string, std::string> op_info;
  295. auto props = OpDef::props(*op);
  296. for (auto&& [key, value] : props) {
  297. op_info[key] = value;
  298. }
  299. return op_info;
  300. };
  301. MGB_RECORD_EVENT(
  302. OpDispatchEvent, op_id, guard.value().name(), op_info_getter,
  303. tinfo_to_tid(input_infos), tinfo_to_tid(output_infos),
  304. state.stack_manager.dump());
  305. }
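// dispatch_kernel: allocates output TensorInfos from the inferred descriptors
// and enqueues an ApplyOp command for asynchronous execution on the worker;
// with async_level == 1 it synchronizes when shape inference was not validated,
// and with async_level == 0 it always synchronizes and checks for device errors.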
  306. void ChannelImpl::dispatch_kernel(
  307. std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
  308. const SmallVector<LogicalTensorDesc>& input_descs,
  309. SmallVector<Handle>* outputs) {
  310. auto& state = get_channel_state();
  311. auto& options = state.options;
  312. std::optional<StackManager::Guard> guard;
  313. if (Profiler::is_profiling()) {
  314. guard.emplace(op->trait()->make_name(*op), &state.stack_manager);
  315. }
  316. auto [output_descs, validated] =
  317. OpDef::infer_output_attrs_fallible(*op, input_descs);
  318. MGB_RECORD_EVENT(ShapeInferEvent, validated);
  319. SmallVector<TensorInfo*> output_infos;
  320. output_infos.reserve(output_descs.size());
  321. outputs->reserve(output_descs.size());
  322. for (int i = 0; i < output_descs.size(); ++i) {
  323. auto&& desc = output_descs[i];
  324. auto info = alloc();
  325. init(info, std::move(desc));
  326. // make sure desc's value is consistent with h_value
  327. if (!info->desc.value.empty()) {
  328. info->h_value = HostTensorND::make_proxy(info->desc.value)
  329. .proxy_to_comp_node(info->desc.comp_node);
  330. }
  331. output_infos.push_back(info);
  332. outputs->push_back(reinterpret_cast<Handle>(info));
  333. }
  334. ApplyOp cmd{
  335. Profiler::next_id(), std::move(op), std::move(input_infos),
  336. std::move(output_infos), validated};
  337. if (Profiler::is_profiling()) {
  338. auto op_info_getter = [op = cmd.op] {
  339. std::unordered_map<std::string, std::string> op_info;
  340. auto props = OpDef::props(*op);
  341. for (auto&& [key, value] : props) {
  342. op_info[key] = value;
  343. }
  344. return op_info;
  345. };
  346. MGB_RECORD_EVENT(
  347. OpDispatchEvent, cmd.id, guard.value().name(), op_info_getter,
  348. tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs),
  349. state.stack_manager.dump());
  350. m_worker.add_task(
  351. {Profiler::next_id(), std::move(cmd),
  352. get_channel_state().stack_manager.dump()});
  353. } else {
  354. m_worker.add_task({
  355. Profiler::next_id(),
  356. std::move(cmd),
  357. });
  358. }
  359. if (!validated && options.async_level == 1) {
  360. sync_impl();
  361. } else if (options.async_level == 0) {
  362. sync_impl();
  363. // check device error
  364. for (auto&& oup : *outputs) {
  365. auto info = reinterpret_cast<TensorInfo*>(oup);
  366. info->ptr->comp_node().sync();
  367. auto err = info->ptr->comp_node().check_async_error();
  368. mgb_assert(!err, "%s", err->what());
  369. }
  370. }
  371. }
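// apply_op: fast path for GetVarShape when the input shape is already known
// (the shape tensor is materialized on the spot without touching the worker);
// all other cases go through apply_op_impl.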
  372. SmallVector<Handle> ChannelImpl::apply_op(
  373. std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
  374. MGB_LOCK_GUARD(m_spin);
  375. mgb_assert(check_available(), "Channel already closed");
  376. auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
  377. if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
  378. size_t ndim = input->desc.layout.ndim;
  379. auto& gvs = op->cast_final_safe<GetVarShape>();
  380. if (gvs.axis == MEGDNN_MAX_NDIM) {
  381. HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()};
  382. DeviceTensorND shape_tensor_device = shape_tensor.proxy_to_default_cpu();
  383. cg::copy_shape_to_tensor_value(shape_tensor_device, input->desc.layout);
  384. return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
  385. }
  386. }
  387. return apply_op_impl(std::move(op), inputs);
  388. }
  389. SmallVector<Handle> ChannelImpl::apply_op_impl(
  390. std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
  391. auto& state = get_channel_state();
  392. for (auto i : inputs) {
  393. mgb_assert(
  394. m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p",
  395. i);
  396. }
  397. SmallVector<TensorInfo*> input_infos;
  398. SmallVector<LogicalTensorDesc> input_descs;
  399. {
  400. MGB_LOCK_GUARD(m_info_spin);
  401. for (auto i : inputs) {
  402. auto info = reinterpret_cast<TensorInfo*>(i);
  403. mgb_assert(
  404. !info->invalid,
  405. "an input tensor is unusable due to previous error");
  406. input_infos.push_back(info);
  407. input_descs.push_back(info->desc);
  408. }
  409. }
  410. SmallVector<Handle> outputs;
  411. DispatchMode dispatch_mode = state.options.enable_host_compute
  412. ? OpDef::decide_dispatch_mode(*op, input_descs)
  413. : DispatchMode::KERNEL;
  414. switch (dispatch_mode) {
  415. case DEFAULT_CPU: {
  416. dispatch_default_cpu(op, input_infos, input_descs, &outputs);
  417. break;
  418. }
  419. case KERNEL: {
  420. dispatch_kernel(op, input_infos, input_descs, &outputs);
  421. break;
  422. }
  423. }
  424. return outputs;
  425. }
  426. HostTensorND ChannelImpl::get_value(Handle handle) {
  427. MGB_LOCK_GUARD(m_spin);
  428. mgb_assert(check_available(), "Channel already closed");
  429. mgb_assert(
  430. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  431. handle);
  432. auto info = reinterpret_cast<TensorInfo*>(handle);
  433. // do not use info->value_fetched, it's unsafe
  434. mgb_assert(!info->invalid, "tensor is unusable due to previous error");
  435. return wait_tensor(info, TensorProp::HostValue)->get_value();
  436. }
  437. TensorShape ChannelImpl::get_shape(Handle handle) {
  438. MGB_LOCK_GUARD(m_spin);
  439. mgb_assert(check_available(), "Channel already closed");
  440. mgb_assert(
  441. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  442. handle);
  443. auto info = reinterpret_cast<TensorInfo*>(handle);
  444. if (info->desc.layout.ndim != 0) {
  445. return info->desc.layout;
  446. }
  447. TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
  448. mgb_assert(ret.ndim != 0);
  449. return ret;
  450. }
  451. DType ChannelImpl::get_dtype(Handle handle) {
  452. MGB_LOCK_GUARD(m_spin);
  453. mgb_assert(check_available(), "Channel already closed");
  454. mgb_assert(
  455. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  456. handle);
  457. auto info = reinterpret_cast<TensorInfo*>(handle);
  458. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::DType);
  459. auto ret = info->desc.layout.dtype;
  460. mgb_assert(ret.valid());
  461. return ret;
  462. }
  463. CompNode ChannelImpl::get_device(Handle handle) {
  464. MGB_LOCK_GUARD(m_spin);
  465. mgb_assert(check_available(), "Channel already closed");
  466. mgb_assert(
  467. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  468. handle);
  469. auto info = reinterpret_cast<TensorInfo*>(handle);
  470. MGB_RECORD_EVENT(TensorGetPropEvent, info->id, TensorProp::Device);
  471. auto ret = info->desc.comp_node;
  472. mgb_assert(ret.valid());
  473. return ret;
  474. }
  475. DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
  476. MGB_LOCK_GUARD(m_spin);
  477. mgb_assert(check_available(), "Channel already closed");
  478. mgb_assert(
  479. m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
  480. handle);
  481. auto info = reinterpret_cast<TensorInfo*>(handle);
  482. return wait_tensor(info, TensorProp::DevValue)->dev_tensor();
  483. }
  484. void ChannelImpl::sync() {
  485. MGB_LOCK_GUARD(m_spin);
  486. mgb_assert(check_available(), "Channel already closed");
  487. sync_impl();
  488. }
  489. void ChannelImpl::sync_impl() {
  490. m_worker.wait_all_task_finish();
  491. MGB_LOCK_GUARD(m_mutex);
  492. check_worker_exc_unsafe();
  493. }
  494. void ChannelImpl::close() {
  495. MGB_LOCK_GUARD(m_spin);
  496. if (!check_available()) {
  497. return;
  498. }
  499. std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
  500. for (auto* handle : valid_handles) {
  501. del_impl(handle);
  502. }
  503. mgb_assert(m_valid_handle.empty());
  504. mgb_log_debug("%ld tensors exist before channel close", (long)valid_handles.size());
  505. sync_impl();
  506. m_closed = true;
  507. }
  508. size_t ChannelImpl::get_option(std::string name) {
  509. MGB_LOCK_GUARD(m_spin);
  510. mgb_assert(check_available(), "Channel already closed");
  511. auto& state = get_channel_state();
  512. return state.options.get_option(name);
  513. }
  514. void ChannelImpl::set_option(std::string name, size_t value) {
  515. MGB_LOCK_GUARD(m_spin);
  516. mgb_assert(check_available(), "Channel already closed");
  517. auto& state = get_channel_state();
  518. state.options.set_option(name, value);
  519. // FIXME
  520. if (name == "enable_dtr_auto_drop" && value) {
  521. auto custom_allocator = [&](CompNode device, size_t size) {
  522. auto blob = Blob::make(device, size);
  523. alloc_tensor_with_evict(blob.get());
  524. return blob->storage();
  525. };
  526. BlobManager::inst()->set_allocator(custom_allocator);
  527. }
  528. if (Profiler::is_profiling()) {
  529. m_worker.add_task(
  530. {Profiler::next_id(), SetOption{name, value},
  531. get_channel_state().stack_manager.dump()});
  532. } else {
  533. m_worker.add_task({
  534. Profiler::next_id(),
  535. SetOption{name, value},
  536. });
  537. }
  538. }
  539. void ChannelImpl::clear_candidates() {
  540. MGB_LOCK_GUARD(m_spin);
  541. mgb_assert(check_available(), "Channel already closed");
  542. m_dtr.candidates.clear();
  543. }
  544. TensorInfo* ChannelImpl::alloc() {
  545. auto& state = get_channel_state();
  546. auto info = [this] {
  547. MGB_LOCK_GUARD(m_pool_spin);
  548. return m_pool.alloc();
  549. }();
  550. info->id = Profiler::next_id();
  551. if (Profiler::is_profiling()) {
  552. size_t tensor_id = state.stack_manager.current()->next_id("tensor");
  553. info->name =
  554. state.stack_manager.dump().to_string() + ssprintf(":%zu", tensor_id);
  555. }
  556. return info;
  557. }
  558. void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
  559. m_valid_handle.insert(reinterpret_cast<Handle>(info));
  560. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  561. info->status = TensorInfo::Allocated;
  562. info->desc = std::move(desc);
  563. }
  564. void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
  565. if (!ptr->producer) {
  566. if (user) {
  567. mgb_log_warn(
  568. "the input that produced tensor %p has been deleted, this drop "
  569. "operation will be ignored",
  570. ptr);
  571. }
  572. return;
  573. }
  574. if (ptr->evict_type != EvictType::NONE) {
  575. return;
  576. }
  577. ptr->evict_type = EvictType::DROP;
  578. ptr->status = TensorInfo::Dropped;
  579. release_tensor(ptr);
  580. }
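// free: under DTR auto-drop, a tensor whose ref_cnt is still nonzero is dropped
// (kept recomputable) rather than freed; otherwise recursive_free releases it
// together with any producer inputs whose ref_cnt reaches zero. Without DTR the
// tensor is freed directly.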
  581. void ChannelImpl::free(TensorInfo* ptr) {
  582. auto& state = get_worker_state();
  583. if (state.options.enable_dtr_auto_drop) {
  584. // Evicting a tensor, rather than freeing it, avoids pinning
  585. // potentially exploding amounts of memory and allows us to save
  586. // more memory.
  587. ptr->allow_delete = true;
  588. if (!ptr->ref_cnt) {
  589. recursive_free(ptr);
  590. } else {
  591. do_drop(ptr);
  592. }
  593. } else {
  594. real_free(ptr);
  595. }
  596. }
  597. void ChannelImpl::recursive_free(TensorInfo* ptr) {
  598. MGB_RECORD_EVENT(TensorCommandEvent, ptr->id, TensorCommandKind::RecFree);
  599. SmallVector<TensorInfo*> inps;
  600. if (ptr->producer) {
  601. for (auto i : ptr->producer->inputs) {
  602. if (i && --i->ref_cnt == 0) {
  603. inps.push_back(i);
  604. }
  605. }
  606. }
  607. real_free(ptr);
  608. for (auto i : inps) {
  609. if (i->allow_delete) {
  610. recursive_free(i);
  611. }
  612. }
  613. MGB_RECORD_EVENT(TensorCommandFinishEvent, ptr->id, TensorCommandKind::RecFree);
  614. }
  615. void ChannelImpl::real_free(TensorInfo* ptr) {
  616. auto& state = get_worker_state();
  617. if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  618. m_dtr.erase_candidate(ptr);
  619. }
  620. detach_users(ptr);
  621. ptr->detach_producer();
  622. bool has_value = ptr->ptr != nullptr;
  623. if (has_value) {
  624. MGB_RECORD_EVENT(TensorReleaseEvent, ptr->id);
  625. }
  626. MGB_RECORD_EVENT(TensorEraseEvent, ptr->id, ptr->ptr_use_count);
  627. ptr->status = TensorInfo::Deleted;
  628. MGB_LOCK_GUARD(m_pool_spin);
  629. m_pool.free(ptr);
  630. }
  631. ChannelImpl::ChannelImpl() : m_worker(this) {}
  632. ChannelImpl::~ChannelImpl() {
  633. close();
  634. }
  635. void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
  636. auto& state = get_worker_state();
  637. MGB_LOCK_GUARD(m_mutex);
  638. m_dtr.update_used_time(dest);
  639. MGB_RECORD_EVENT(
  640. TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
  641. ptr->raw_ptr_not_for_readwrite());
  642. // update tensor desc for static infer
  643. if (dest->desc.layout.ndim) {
  644. mgb_assert(
  645. dest->desc.layout.eq_shape(ptr->layout()),
  646. "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
  647. ptr->layout().to_string().c_str());
  648. }
  649. // in order to avoid performance impact,
  650. // memory forwarding is disabled when DTR is enabled
  651. if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
  652. ptr->to_contiguous_inplace();
  653. }
  654. dest->desc.comp_node = ptr->comp_node();
  655. dest->memory = ptr->blob()->size();
  656. dest->ptr = std::move(ptr);
  657. dest->evict_type = EvictType::NONE;
  658. dest->status = TensorInfo::Produced;
  659. if (dest->pinned == 0 &&
  660. dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  661. m_dtr.insert_candidate(dest);
  662. }
  663. notify_tensor_unsafe(dest);
  664. }
  665. void ChannelImpl::release_tensor(TensorInfo* dest) {
  666. MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
  667. MGB_LOCK_GUARD(m_mutex);
  668. dest->ptr.reset();
  669. auto& state = get_worker_state();
  670. if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
  671. m_dtr.erase_candidate(dest);
  672. }
  673. }
  674. void ChannelImpl::regenerate(TensorInfo* dest) {
  675. if (dest->evict_type == EvictType::DROP) {
  676. auto&& path = dest->producer;
  677. m_apply_stack.push(
  678. {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
  679. "dtr"});
  680. if (!m_applying)
  681. flush_apply_stack();
  682. }
  683. }
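// do_apply_op: runs a single ApplyOp on the worker thread. Ops that provide
// make_forward_graph are expanded and applied recursively; input layout
// constraints are enforced by making tensors contiguous in place, and
// profiling/DTR bookkeeping is recorded around the actual kernel execution.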
  684. void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
  685. using namespace ranges;
  686. using namespace ranges::views;
  687. auto& state = get_worker_state();
  688. bool profiling_device =
  689. Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
  690. uint64_t apply_id = cmd.id;
  691. SmallVector<TensorPtr> inputs;
  692. inputs.reserve(cmd.inputs.size());
  693. // refcnt == 1, owners: [TensorInfo::ptr]
  694. for (auto i : cmd.inputs) {
  695. mgb_assert(i->ptr, "Invalid input tensor ptr!");
  696. // refcnt ++, owners: [i->ptr, tensor_inputs]
  697. // tensor_inputs.push_back(i->ptr);
  698. inputs.push_back(i->ptr);
  699. }
  700. if (state.options.enable_dtr_auto_drop &&
  701. state.options.dtr_eviction_threshold > 0) {
  702. auto_evict(0);
  703. }
  704. auto apply_on_physical_tensor =
  705. [&](auto&& self, const OpDef& def, SmallVector<TensorPtr>&& inputs,
  706. SmallVector<LogicalTensorDesc>& output_descs,
  707. const bool& validated) -> SmallVector<TensorPtr> {
  708. if (def.trait()->make_forward_graph) {
  709. auto apply_functor = [&](std::shared_ptr<OpDef> op,
  710. SmallVector<TensorPtr> inputs,
  711. size_t nr_outputs) -> SmallVector<TensorPtr> {
  712. auto opname = op->trait()->make_name(*op);
  713. imperative_log_profile_begin(opname.c_str());
  714. auto outputs = self(self, *op, std::move(inputs), output_descs, false);
  715. imperative_log_profile_end(opname.c_str());
  716. return outputs;
  717. };
  718. auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
  719. // apply recursively
  720. SmallVector<LogicalTensorDesc> input_descs;
  721. for (auto&& input : inputs) {
  722. input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
  723. }
  724. auto forward_graph = OpDef::make_forward_graph(def, input_descs);
  725. auto outputs = forward_graph.apply<TensorPtr>(
  726. inputs, apply_functor, const_functor);
  727. return outputs;
  728. }
  729. // Check input layout
  730. // Get the input layout constraints; if a constraint is not satisfied,
  731. // update the layout and blob in place to make the tensor contiguous
  732. auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
  733. for (size_t idx = 0; idx < inputs.size(); ++idx) {
  734. auto&& layout_checker = constraints[idx];
  735. if (layout_checker) {
  736. inputs[idx]->to_contiguous_inplace(layout_checker);
  737. }
  738. }
  739. auto outputs = OpDef::apply_on_physical_tensor(
  740. def, std::move(inputs), output_descs, validated);
  741. for (auto& o : outputs) {
  742. o->set_ready_event(
  743. record_event(o->comp_node(), def.same_type<imperative::Barrier>()));
  744. }
  745. return outputs;
  746. };
  747. MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
  748. SmallVector<std::pair<CompNode, uint64_t>> kernels;
  749. if (profiling_device) {
  750. // Collecting devices
  751. SmallVector<CompNode> devices;
  752. for (auto&& i : concat(cmd.inputs, cmd.outputs)) {
  753. if (i != nullptr && count(devices, i->desc.comp_node) == 0) {
  754. devices.push_back(i->desc.comp_node);
  755. kernels.push_back({i->desc.comp_node, Profiler::next_id()});
  756. }
  757. }
  758. }
  759. for (auto* input : cmd.inputs) {
  760. auto input_id = input->id;
  761. MGB_RECORD_EVENT(OpInputEvent, input_id);
  762. MGB_RECORD_EVENT(TensorUsageEvent, input_id);
  763. MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
  764. }
  765. // Before wait
  766. // TODO: split operator wait and execute so that OpWait could be correctly recorded.
  767. // Before execute
  768. for (auto&& [device, kernel_id] : kernels) {
  769. MGB_RECORD_EVENT(KernelLaunchEvent, apply_id, kernel_id, device);
  770. MGB_RECORD_EVENT_IF(
  771. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  772. Timer::record_device(device));
  773. }
  774. // Apply op
  775. SmallVector<LogicalTensorDesc> output_descs;
  776. bool validated = cmd.validated;
  777. if (!state.options.enable_dtr_auto_drop) {
  778. for (auto i : cmd.outputs) {
  779. output_descs.push_back(i->desc);
  780. }
  781. } else {
  782. validated = false;
  783. }
  784. // Here std::move is REQUIRED for removing duplicated references.
  785. auto outputs = apply_on_physical_tensor(
  786. apply_on_physical_tensor, *cmd.op, std::move(inputs), output_descs,
  787. validated);
  788. // After execute
  789. for (auto&& [device, kernel_id] : kernels) {
  790. MGB_RECORD_EVENT_IF(
  791. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  792. Timer::record_device(device));
  793. MGB_RECORD_EVENT(KernelLaunchFinishEvent, apply_id, kernel_id, device);
  794. }
  795. // End profiling operator
  796. mgb_assert(outputs.size() == cmd.outputs.size());
  797. for (size_t i = 0; i < outputs.size(); ++i) {
  798. auto output = cmd.outputs[i];
  799. if (mgb_unlikely(output == nullptr)) {
  800. MGB_RECORD_EVENT(OpOutputEvent, 0);
  801. MGB_RECORD_EVENT(OpOutputFinishEvent, 0);
  802. } else if (mgb_unlikely(output->ptr != nullptr)) {
  803. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  804. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  805. } else {
  806. MGB_RECORD_EVENT(OpOutputEvent, output->id);
  807. produce_tensor(output, outputs[i]);
  808. MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
  809. sample_on_device(output->desc.comp_node, false);
  810. }
  811. }
  812. if (state.options.enable_dtr_auto_drop) {
  813. double estimate_compute_time = 0;
  814. for (auto i : cmd.inputs) {
  815. estimate_compute_time += i->memory;
  816. }
  817. for (auto i : outputs) {
  818. estimate_compute_time += i->blob()->size();
  819. }
  820. m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
  821. for (auto i : cmd.outputs) {
  822. if (i != nullptr) {
  823. i->compute_time = estimate_compute_time;
  824. }
  825. }
  826. m_dtr.unpin(cmd.inputs, state);
  827. }
  828. MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
  829. // End profiling operator
  830. }
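// flush_apply_stack: drains m_apply_stack. If an input of the top command has
// been dropped, regenerate() pushes its producing op onto the stack and the
// loop restarts, so dropped ancestors are recomputed before the command itself
// runs.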
  831. void ChannelImpl::flush_apply_stack() {
  832. m_applying = true;
  833. auto& state = get_worker_state();
  834. while (!m_apply_stack.empty()) {
  835. auto& [cmd, idx, recomp, reason] =
  836. m_apply_stack.top(); // cmd.inputs[0~idx-1] is in memory
  837. if (idx == 0) {
  838. if (state.options.enable_dtr_auto_drop) {
  839. m_dtr.pin(cmd.inputs);
  840. }
  841. if (recomp) {
  842. MGB_RECORD_EVENT(
  843. TensorCommandEvent, recomp->id, TensorCommandKind::ReGen);
  844. }
  845. }
  846. bool regen = false;
  847. for (size_t i = idx; i < cmd.inputs.size(); i++) {
  848. auto&& p = cmd.inputs[i];
  849. if (state.options.enable_dtr_auto_drop) {
  850. m_dtr.update_used_time(p);
  851. }
  852. if (!p->ptr && p->evict_type != EvictType::NONE) {
  853. idx = i + 1;
  854. regenerate(p); // add ApplyOp to the stack
  855. regen = true;
  856. break;
  857. }
  858. }
  859. if (regen)
  860. continue;
  861. // the required input tensors are already in memory
  862. auto [cmd_backup, recomp_backup, reason_backup] =
  863. std::make_tuple(cmd, recomp, reason);
  864. m_apply_stack.pop();
  865. do_apply_op(cmd_backup, reason_backup);
  866. if (recomp_backup) {
  867. MGB_RECORD_EVENT(
  868. TensorCommandFinishEvent, recomp_backup->id,
  869. TensorCommandKind::ReGen);
  870. for (auto o : cmd_backup.outputs) {
  871. if (o) {
  872. m_dtr.update_dsu_after_recompute(o);
  873. }
  874. }
  875. }
  876. }
  877. m_applying = false;
  878. }
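// auto_evict: evicts DTR candidates while used memory exceeds
// dtr_eviction_threshold or while force_num tensors still need to be freed;
// returns whether any memory was actually released.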
  879. bool ChannelImpl::auto_evict(size_t force_num) {
  880. auto& state = get_worker_state();
  881. if (!m_dtr.comp_node.valid()) {
  882. return false;
  883. }
  884. size_t current_memory = m_dtr.comp_node.get_used_memory();
  885. bool flag = false;
  886. while ((state.options.dtr_eviction_threshold > 0 &&
  887. current_memory > state.options.dtr_eviction_threshold) ||
  888. force_num > 0) {
  889. MGB_RECORD_EVENT(AutoEvictEvent);
  890. sample_on_device(m_dtr.comp_node, false);
  891. auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
  892. if (!best) {
  893. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  894. break;
  895. }
  896. if (best->ptr.unique() && best->ptr->blob().unique()) {
  897. current_memory -= best->memory;
  898. if (force_num > 0) {
  899. force_num--;
  900. }
  901. flag = true;
  902. }
  903. do_drop(best);
  904. if (best->evict_type == EvictType::DROP) {
  905. m_dtr.update_dsu_after_evict(best);
  906. }
  907. sample_on_device(m_dtr.comp_node, false);
  908. MGB_RECORD_EVENT(AutoEvictFinishEvent);
  909. }
  910. return flag;
  911. }
  912. void ChannelImpl::detach_users(TensorInfo* dest) {
  913. SmallVector<TensorInfo::ComputePath*> users = dest->users;
  914. for (auto* user : users) {
  915. SmallVector<TensorInfo*> outputs = user->outputs;
  916. SmallVector<TensorInfo*> inputs = user->inputs;
  917. for (auto* output : outputs) {
  918. // When a `ComputePath` is detached from its input,
  919. // there is no need to preserve it,
  920. // so we detach all outputs of this path
  921. // to decrease its `ref_cnt` to zero.
  922. if (output == nullptr) {
  923. continue;
  924. }
  925. regenerate(output);
  926. output->detach_producer();
  927. for (auto* input : inputs) {
  928. input->ref_cnt--;
  929. }
  930. }
  931. // now user is dead
  932. }
  933. mgb_assert(dest->users.empty(), "ComputePath leaking");
  934. }
  935. bool ChannelImpl::check_available() {
  936. return !m_closed;
  937. }
  938. TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
  939. std::unique_lock<decltype(m_mutex)> lock(m_mutex);
  940. mgb_assert(!m_waitee, "duplicate waitee");
  941. m_waitee = info;
  942. m_waitee_id = Profiler::next_id();
  943. MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
  944. bool require_host = prop == TensorProp::HostValue;
  945. auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
  946. bool wait_host = false;
  947. if (require_host && !host_available()) {
  948. // avoid deadlock
  949. lock.unlock();
  950. if (Profiler::is_profiling()) {
  951. m_worker.add_task(
  952. {Profiler::next_id(), GetValue{info},
  953. get_channel_state().stack_manager.dump()});
  954. } else {
  955. m_worker.add_task({
  956. Profiler::next_id(),
  957. GetValue{info},
  958. });
  959. }
  960. lock.lock();
  961. wait_host = true;
  962. }
  963. m_cv.wait(lock, [&]() {
  964. check_worker_exc_unsafe();
  965. return require_host ? host_available() : static_cast<bool>(info->ptr);
  966. });
  967. MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
  968. m_waitee = nullptr;
  969. if (wait_host) {
  970. auto err = info->ptr->comp_node().check_async_error();
  971. mgb_assert(!err, "%s", err->what());
  972. }
  973. return info->ptr;
  974. }
  975. void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
  976. if (info == m_waitee) {
  977. MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
  978. m_cv.notify_all();
  979. }
  980. }
  981. std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
  982. std::unordered_set<TensorInfo*> valid_tensors;
  983. for (auto* handle : m_valid_handle) {
  984. auto* info = reinterpret_cast<TensorInfo*>(handle);
  985. valid_tensors.insert(info);
  986. }
  987. return valid_tensors;
  988. }
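// alloc_tensor_with_evict: tries a direct allocation first; on MemAllocError it
// evicts candidates one at a time and retries, and as a last resort defragments
// the allocator before retrying once more.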
  989. void ChannelImpl::alloc_tensor_with_evict(OwnedBlob* x) {
  990. bool in_worker = (get_worker_tid() == std::this_thread::get_id());
  991. auto reserve_size = [&](size_t size) {
  992. if (!m_dtr.comp_node.valid()) {
  993. return false;
  994. }
  995. while (size > m_dtr.comp_node.get_max_block_size_available()) {
  996. bool evict_suc = auto_evict(1);
  997. if (!evict_suc)
  998. return false;
  999. }
  1000. return true;
  1001. };
  1002. auto pre_level = set_log_level(LogLevel::NO_LOG);
  1003. if (in_worker) {
  1004. reserve_size(x->size());
  1005. }
  1006. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  1007. MGB_CATCH(MemAllocError&, {
  1008. bool suc = false;
  1009. if (in_worker) {
  1010. while (!suc) {
  1011. if (!auto_evict(1)) {
  1012. break;
  1013. }
  1014. MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
  1015. MGB_CATCH(MemAllocError&, { continue; });
  1016. suc = true;
  1017. }
  1018. }
  1019. if (!suc) {
  1020. set_log_level(pre_level);
  1021. mgb_log_warn(
  1022. "reallocating all cuda memory to alleviate fragmentation, the "
  1023. "performance may be affected");
  1024. set_log_level(LogLevel::NO_LOG);
  1025. imperative_log_profile_begin("defrag");
  1026. BlobManager::inst()->defrag(x->comp_node());
  1027. imperative_log_profile_end("defrag");
  1028. BlobManager::inst()->alloc_direct(x, x->size());
  1029. }
  1030. });
  1031. set_log_level(pre_level);
  1032. }
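// process_one_task: worker-side dispatcher over the Command variants (Put,
// ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile, PushScope,
// PopScope). When catch_worker_execption is set, exceptions are stored in
// m_worker_exc, the affected outputs are marked invalid, and the waiting channel
// thread is notified.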
  1033. void ChannelImpl::process_one_task(Command& icmd) {
  1034. using namespace ranges;
  1035. using namespace ranges::views;
  1036. auto& state = get_worker_state();
  1037. auto& options = state.options;
  1038. // TODO: remove std::visit to support OSX 10.12
  1039. auto cmd_visitor = [&](const auto& cmd) {
  1040. using T = std::decay_t<decltype(cmd)>;
  1041. if constexpr (std::is_same_v<T, Put>) {
  1042. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Put);
  1043. MGB_RECORD_EVENT_IF(
  1044. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1045. Timer::record_device(cmd.value.comp_node()));
  1046. auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value)
  1047. : Tensor::make(cmd.value);
  1048. MGB_RECORD_EVENT_IF(
  1049. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1050. Timer::record_device(cmd.value.comp_node()));
  1051. produce_tensor(cmd.dest, std::move(value));
  1052. MGB_RECORD_EVENT(
  1053. TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Put);
  1054. sample_on_device(cmd.dest->desc.comp_node, false);
  1055. } else if constexpr (std::is_same_v<T, ApplyOp>) {
  1056. for (auto& i : cmd.inputs) {
  1057. if (mgb_unlikely(i->invalid)) {
  1058. MGB_LOCK_GUARD(m_mutex);
  1059. for (auto& i : cmd.outputs) {
  1060. i->invalid = true;
  1061. }
  1062. return;
  1063. }
  1064. }
  1065. if (state.options.enable_dtr_auto_drop) {
  1066. m_apply_stack.push({cmd, 0, nullptr, "cmd"});
  1067. flush_apply_stack();
  1068. for (size_t i = 0; i < cmd.outputs.size(); ++i) {
  1069. auto output = cmd.outputs[i];
  1070. if (output == nullptr) {
  1071. continue;
  1072. }
  1073. output->dsu_ptr = std::make_shared<DsuNode>(output->compute_time);
  1074. }
  1075. } else {
  1076. do_apply_op(cmd, "cmd");
  1077. }
  1078. if (state.options.enable_drop && state.options.record_computing_path) {
  1079. auto is_inplace = [](std::tuple<TensorInfo*, TensorInfo*> tuple2) {
  1080. auto& input = std::get<0>(tuple2);
  1081. auto& output = std::get<1>(tuple2);
  1082. if (!input->ptr || !output->ptr) {
  1083. return false;
  1084. }
  1085. return input->ptr->blob()->storage() ==
  1086. output->ptr->blob()->storage();
  1087. };
  1088. // FIXME: do not use opname as identifier
  1089. auto get_name = [](const OpDef& opdef) {
  1090. if (auto attr = opdef.try_cast_final<OprAttr>()) {
  1091. return attr->type.c_str();
  1092. }
  1093. return opdef.dyn_typeinfo()->name;
  1094. };
  1095. auto is_cross_cn = [comp_node = m_dtr.comp_node](TensorInfo* info) {
  1096. return info->desc.comp_node != comp_node;
  1097. };
  1098. bool cross_cn = any_of(concat(cmd.inputs, cmd.outputs), is_cross_cn);
  1099. bool inplace =
  1100. any_of(cartesian_product(cmd.inputs, cmd.outputs), is_inplace);
  1101. if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
  1102. TensorInfo::ComputePath::make(
  1103. cmd.id, cmd.op, cmd.inputs, cmd.outputs);
  1104. size_t detach_cnt = 0;
  1105. if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
  1106. cmd.outputs.size() == 6) {
  1107. cmd.outputs[0]->detach_producer(); // detach running_mean
  1108. cmd.outputs[1]->detach_producer(); // detach running_var
  1109. for (auto input : cmd.inputs) {
  1110. input->ref_cnt -= 2;
  1111. }
  1112. }
  1113. for (auto output : cmd.outputs) {
  1114. if (output->producer &&
  1115. !output->size_exceeds_thd(
  1116. state.options.dtr_evictee_minimum_size)) {
  1117. output->detach_producer();
  1118. detach_cnt++;
  1119. }
  1120. }
  1121. for (auto input : cmd.inputs) {
  1122. input->ref_cnt -= detach_cnt;
  1123. }
  1124. }
  1125. }
  1126. } else if constexpr (std::is_same_v<T, Del>) {
  1127. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Del);
  1128. CompNode device = cmd.dest->desc.comp_node;
  1129. uint64_t tensor_id = cmd.dest->id;
  1130. free(cmd.dest);
  1131. MGB_RECORD_EVENT(
  1132. TensorCommandFinishEvent, tensor_id, TensorCommandKind::Del);
  1133. sample_on_device(device, false);
  1134. } else if constexpr (std::is_same_v<T, GetValue>) {
  1135. if (cmd.dest->invalid)
  1136. return;
  1137. imperative_log_profile_begin("GetValue");
  1138. if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
  1139. regenerate(cmd.dest);
  1140. }
  1141. cmd.dest->ptr->fetch_value();
  1142. MGB_LOCK_GUARD(m_mutex);
  1143. notify_tensor_unsafe(cmd.dest);
  1144. imperative_log_profile_end("GetValue");
  1145. } else if constexpr (std::is_same_v<T, Drop>) {
  1146. if (cmd.dest->invalid)
  1147. return;
  1148. MGB_RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandKind::Drop);
  1149. do_drop(cmd.dest, true);
  1150. MGB_RECORD_EVENT(
  1151. TensorCommandFinishEvent, cmd.dest->id, TensorCommandKind::Drop);
  1152. } else if constexpr (std::is_same_v<T, SetOption>) {
  1153. options.set_option(cmd.key, cmd.value);
  1154. } else if constexpr (std::is_same_v<T, StartProfile>) {
  1155. MGB_RECORD_EVENT(StartProfileEvent);
  1156. CompNode::sync_all();
  1157. for (auto* info : cmd.capture_tensors) {
  1158. MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
  1159. if (info->status == TensorInfo::Produced) {
  1160. // TODO: handle drop
  1161. MGB_RECORD_EVENT(
  1162. TensorProduceEvent, info->id, info->desc.layout,
  1163. info->desc.comp_node, info->ptr->dev_tensor().raw_ptr());
  1164. }
  1165. }
  1166. CompNode::foreach ([&](CompNode device) {
  1167. sample_on_device(device, true);
  1168. MGB_RECORD_EVENT_IF(
  1169. (Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
  1170. Timer::record_device(device));
  1171. });
  1172. MGB_RECORD_EVENT(StartProfileFinishEvent);
  1173. } else if constexpr (std::is_same_v<T, StopProfile>) {
  1174. MGB_RECORD_EVENT(StopProfileEvent);
  1175. for (auto* info : cmd.escape_tensors) {
  1176. bool has_value = info->status == TensorInfo::Produced;
  1177. if (has_value) {
  1178. MGB_RECORD_EVENT(TensorReleaseEvent, info->id);
  1179. }
  1180. MGB_RECORD_EVENT(TensorEraseEvent, info->id);
  1181. }
  1182. CompNode::foreach (
  1183. [&](CompNode device) { sample_on_device(device, true); });
  1184. MGB_RECORD_EVENT(StopProfileFinishEvent);
  1185. } else if constexpr (std::is_same_v<T, PushScope>) {
  1186. MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name);
  1187. } else if constexpr (std::is_same_v<T, PopScope>) {
  1188. MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
  1189. } else {
  1190. static_assert(!std::is_same_v<T, T>);
  1191. }
  1192. };
  1193. std::visit(
  1194. [&](const auto& cmd) {
  1195. using T = std::decay_t<decltype(cmd)>;
  1196. if (!options.catch_worker_execption) {
  1197. cmd_visitor(cmd);
  1198. return;
  1199. }
  1200. try {
  1201. cmd_visitor(cmd);
  1202. } catch (...) {
  1203. MGB_LOCK_GUARD(m_mutex);
  1204. if constexpr (std::is_same_v<T, ApplyOp>) {
  1205. for (auto oup : cmd.outputs) {
  1206. oup->invalid = true;
  1207. }
  1208. } else if constexpr (std::is_same_v<T, Put>) {
  1209. cmd.dest->invalid = true;
  1210. }
  1211. m_worker_exc = std::current_exception();
  1212. MGB_RECORD_EVENT(WorkerExceptionEvent);
  1213. if (m_waitee) {
  1214. notify_tensor_unsafe(m_waitee);
  1215. }
  1216. }
  1217. },
  1218. icmd.data);
  1219. }
  1220. void ChannelImpl::check_worker_exc_unsafe() {
  1221. if (m_worker_exc) {
  1222. // allow interpreter_for_py to be reused after some exception tests
  1223. m_waitee = nullptr;
  1224. std::exception_ptr exc;
  1225. std::swap(exc, m_worker_exc);
  1226. try {
  1227. std::rethrow_exception(exc);
  1228. } catch (...) {
  1229. throw AsyncError();
  1230. }
  1231. }
  1232. }
  1233. void ChannelImpl::start_profile() {
  1234. MGB_LOCK_GUARD(m_spin);
  1235. mgb_assert(check_available(), "Channel already closed");
  1236. auto capture_tensors = collect_valid_tensors();
  1237. if (capture_tensors.size() > 0) {
  1238. if (Profiler::is_profiling()) {
  1239. m_worker.add_task(
  1240. {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
  1241. get_channel_state().stack_manager.dump()});
  1242. } else {
  1243. m_worker.add_task({
  1244. Profiler::next_id(),
  1245. StartProfile{std::move(capture_tensors)},
  1246. });
  1247. }
  1248. }
  1249. }
  1250. void ChannelImpl::stop_profile() {
  1251. MGB_LOCK_GUARD(m_spin);
  1252. mgb_assert(check_available(), "Channel already closed");
  1253. auto escape_tensors = collect_valid_tensors();
  1254. if (escape_tensors.size() > 0) {
  1255. if (Profiler::is_profiling()) {
  1256. m_worker.add_task(
  1257. {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
  1258. get_channel_state().stack_manager.dump()});
  1259. } else {
  1260. m_worker.add_task({
  1261. Profiler::next_id(),
  1262. StopProfile{std::move(escape_tensors)},
  1263. });
  1264. }
  1265. }
  1266. }
  1267. void ChannelImpl::push_scope(std::string name) {
  1268. MGB_LOCK_GUARD(m_spin);
  1269. mgb_assert(check_available(), "Channel already closed");
  1270. auto& state = get_channel_state();
  1271. state.stack_manager.enter(name);
  1272. MGB_RECORD_EVENT(ScopeEvent, name);
  1273. if (Profiler::is_profiling()) {
  1274. m_worker.add_task(
  1275. {Profiler::next_id(), PushScope{name},
  1276. get_channel_state().stack_manager.dump()});
  1277. } else {
  1278. m_worker.add_task({
  1279. Profiler::next_id(),
  1280. PushScope{name},
  1281. });
  1282. }
  1283. }
  1284. void ChannelImpl::pop_scope(std::string name) {
  1285. MGB_LOCK_GUARD(m_spin);
  1286. mgb_assert(check_available(), "Channel already closed");
  1287. auto& state = get_channel_state();
  1288. state.stack_manager.exit(name);
  1289. MGB_RECORD_EVENT(ScopeFinishEvent, name);
  1290. if (Profiler::is_profiling()) {
  1291. m_worker.add_task(
  1292. {Profiler::next_id(), PopScope{name},
  1293. get_channel_state().stack_manager.dump()});
  1294. } else {
  1295. m_worker.add_task({
  1296. Profiler::next_id(),
  1297. PopScope{name},
  1298. });
  1299. }
  1300. }
  1301. void ChannelImpl::assert_in_channel() {
  1302. mgb_assert(
  1303. get_worker_tid() != std::this_thread::get_id(),
  1304. "this method cannot be called in worker thread");
  1305. }
  1306. void ChannelImpl::assert_in_worker() {
  1307. mgb_assert(
  1308. get_worker_tid() == std::this_thread::get_id(),
  1309. "this method can only be called in worker thread");
  1310. }
  1311. void ChannelImpl::sample_on_device(CompNode device, bool force) {
  1312. if (!Profiler::is_profiling()) {
  1313. return;
  1314. }
  1315. if (!force) {
  1316. thread_local int last_sample_id = 0;
  1317. int sample_rate = Profiler::get_option("sample_rate", 0);
  1318. if (!sample_rate || ((++last_sample_id) % sample_rate != 0)) {
  1319. return;
  1320. }
  1321. }
  1322. MGB_RECORD_EVENT(SampleDeviceEvent, device);
  1323. auto [total, free] = device.get_mem_status_bytes();
  1324. MGB_RECORD_EVENT(SampleDeviceFinishEvent, device, total, free);
  1325. }
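// pin/unpin: operands of the op currently being executed are pinned so they
// cannot be selected for eviction; once unpinned they re-enter the candidate
// set if they exceed the evictee size threshold.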
  1326. void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
  1327. for (auto i : vec) {
  1328. i->pin();
  1329. erase_candidate(i);
  1330. }
  1331. }
  1332. void ChannelImpl::DynamicSublinear::unpin(
  1333. const SmallVector<TensorInfo*>& vec, WorkerState& state) {
  1334. for (auto i : vec) {
  1335. i->unpin();
  1336. if (i->pinned == 0 &&
  1337. i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
  1338. i->cand_index == UINT_MAX) {
  1339. insert_candidate(i);
  1340. }
  1341. }
  1342. }
  1343. void ChannelImpl::DynamicSublinear::update_dsu_after_recompute(TensorInfo* ptr) {
  1344. auto&& dsu_fa = find_father(ptr->dsu_ptr);
  1345. dsu_fa->t -= ptr->compute_time;
  1346. ptr->dsu_ptr->parent.reset();
  1347. ptr->dsu_ptr->t = ptr->compute_time;
  1348. }
  1349. void ChannelImpl::DynamicSublinear::update_dsu_after_evict(TensorInfo* ptr) {
  1350. for (auto i : ptr->producer->inputs) {
  1351. if (i->evict_type == EvictType::DROP) {
  1352. merge(i->dsu_ptr, ptr->dsu_ptr);
  1353. }
  1354. }
  1355. for (auto i : ptr->producer->outputs) {
  1356. if (i && i->evict_type == EvictType::DROP) {
  1357. merge(ptr->dsu_ptr, i->dsu_ptr);
  1358. }
  1359. }
  1360. }
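// estimate_neighbor_cost: sums, over the dropped inputs and outputs of the
// producer, the recompute time accumulated at their DSU roots (at least each
// tensor's own compute_time), approximating how expensive it would be to bring
// this tensor's neighborhood back.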
  1361. double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
  1362. double cost = 0;
  1363. for (auto i : ptr->producer->inputs) {
  1364. if (i->evict_type == EvictType::DROP) {
  1365. double t = find_father(i->dsu_ptr)->t;
  1366. if (t < i->compute_time) {
  1367. t = i->compute_time;
  1368. }
  1369. cost += t;
  1370. }
  1371. }
  1372. for (auto i : ptr->producer->outputs) {
  1373. if (i && i->evict_type == EvictType::DROP) {
  1374. double t = find_father(i->dsu_ptr)->t;
  1375. if (t < i->compute_time) {
  1376. t = i->compute_time;
  1377. }
  1378. cost += t;
  1379. }
  1380. }
  1381. return cost;
  1382. }
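// find_best_tensor: picks the eviction candidate with the smallest heuristic
// score (eval_func combines the neighbor recompute cost, the free memory around
// the tensor's storage, and the current DTR timestamp). With sqrt sampling
// enabled, only about sqrt(N) randomly stepped candidates are examined instead
// of the whole list.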
  1383. TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
  1384. bool enable_dtr_sqrt_sampling = false) {
  1385. if (candidates.empty())
  1386. return nullptr;
  1387. double min_msps = -1;
  1388. TensorInfo* best = nullptr;
  1389. size_t sz = 1;
  1390. if (enable_dtr_sqrt_sampling) {
  1391. while (sz * sz <= candidates.size())
  1392. sz++;
  1393. sz--;
  1394. } else {
  1395. sz = candidates.size();
  1396. }
  1397. size_t ti = rand() % sz;
  1398. for (size_t vi = 0; vi < sz; vi++) {
  1399. if (!enable_dtr_sqrt_sampling) {
  1400. ti = vi;
  1401. }
  1402. auto i = candidates[ti];
  1403. if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
  1404. double neighbor_cost = estimate_neighbor_cost(i);
  1405. size_t begin_ptr =
  1406. reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
  1407. auto side_info = i->ptr->comp_node().get_free_left_and_right(
  1408. begin_ptr, begin_ptr + i->ptr->blob()->size());
  1409. double free_mem = side_info.first + side_info.second;
  1410. double msps = i->eval_func(
  1411. neighbor_cost, free_mem, estimate_timestamp, 1.0, 1.0, 1.0, 1.0001);
  1412. if (min_msps < 0 || msps < min_msps) {
  1413. min_msps = msps;
  1414. best = i;
  1415. }
  1416. }
  1417. if (enable_dtr_sqrt_sampling) {
  1418. ti += rand() % sz;
  1419. if (ti >= candidates.size())
  1420. break;
  1421. }
  1422. }
  1423. return best;
  1424. }
  1425. void ChannelImpl::DynamicSublinear::merge(
  1426. std::shared_ptr<DsuNode>& x, std::shared_ptr<DsuNode>& y) {
  1427. auto&& f_x = find_father(x);
  1428. auto&& f_y = find_father(y);
  1429. if (f_x.get() == f_y.get()) {
  1430. return;
  1431. }
  1432. f_y->t += f_x->t;
  1433. f_x->parent = f_y;
  1434. }
  1435. std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
  1436. std::shared_ptr<DsuNode>& x) {
  1437. if (x->is_root()) {
  1438. return x;
  1439. } else {
  1440. auto&& fa = find_father(x->parent);
  1441. return x->parent = fa;
  1442. }
  1443. }
  1444. void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
  1445. // tensor to be inserted must be brand new
  1446. mgb_assert(
  1447. ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
  1448. ptr->cand_index);
  1449. ptr->cand_index = candidates.size();
  1450. candidates.push_back(ptr);
  1451. if (!comp_node.valid()) {
  1452. comp_node = ptr->ptr->comp_node();
  1453. }
  1454. }
  1455. void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
  1456. // closing dtr just clears candidates, so there is nothing to erase
  1457. if (candidates.empty()) {
  1458. ptr->cand_index = UINT_MAX;
  1459. return;
  1460. }
  1461. // some tensors may be erased already, just skip them
  1462. if (ptr->cand_index != UINT_MAX) {
  1463. std::swap(candidates[ptr->cand_index], candidates.back());
  1464. candidates[ptr->cand_index]->cand_index = ptr->cand_index;
  1465. candidates.pop_back();
  1466. ptr->cand_index = UINT_MAX;
  1467. }
  1468. }
  1469. void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
  1470. ptr->last_used_time = estimate_timestamp;
  1471. }