
mini_graph.h

/**
 * \file imperative/src/impl/proxy_graph/mini_graph.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/graph/operator_node.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "../blob_manager_impl.h"
#include "./common.h"
#include "./proxy_graph_base.h"

#include <algorithm>
#include <array>
#include <deque>
#include <mutex>
#include <optional>

#include "range/v3/all.hpp"

namespace mgb::imperative::proxy_graph {

using cg::OperatorNodeBase;

template <typename C, typename E>
std::pair<bool, size_t> find_index(const C& container, const E& item) {
    auto&& it = std::find(container.begin(), container.end(), item);
    return {it != container.end(), it - container.begin()};
}
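
// TensorAdaptor (below) gives LogicalTensorDesc (used during shape/value
// inference) and physical Tensor (used during execution) a common accessor
// interface, so the code further down can be written once for both input
// kinds. Illustrative only: TensorAdaptor(desc).dtype() and
// TensorAdaptor(tensor).dtype() both compile, dispatching on the wrapped type.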
template <typename T, typename = void>
class TensorAdaptor;

template <typename T, typename U>
using enable_if_same_upto_cv_t =
        std::enable_if_t<std::is_same_v<std::remove_cv_t<T>, std::remove_cv_t<U>>>;

template <typename T>
class TensorAdaptor<T, enable_if_same_upto_cv_t<T, LogicalTensorDesc>> {
    T& wrapped;
    template <typename U>
    using maybe_add_const_t = std::conditional_t<std::is_const_v<T>, const U, U>;

public:
    using type = T;

    TensorAdaptor(T& desc) : wrapped(desc) {}
    TensorAdaptor(T* desc) : wrapped(*desc) {}

    DType dtype() { return wrapped.layout.dtype; }
    CompNode comp_node() { return wrapped.comp_node; }
    maybe_add_const_t<TensorShape>& shape() { return wrapped.layout; }
    bool has_value() { return wrapped.value.shape_valid(); }
    auto& value() { return wrapped.value; }

    auto* operator->() { return &wrapped; }
};

template <typename T>
class TensorAdaptor<T, enable_if_same_upto_cv_t<T, Tensor>> {
    Tensor& wrapped;

public:
    using type = Tensor;

    TensorAdaptor(Tensor& tensor) : wrapped(tensor) {}
    TensorAdaptor(Tensor* tensor) : wrapped(*tensor) {}

    DType dtype() { return wrapped.dtype(); }
    CompNode comp_node() { return wrapped.comp_node(); }
    const TensorShape& shape() { return wrapped.shape(); }

    type* operator->() { return &wrapped; }
};

// deduction guides
template <typename T>
TensorAdaptor(T&) -> TensorAdaptor<T, void>;
template <typename T>
TensorAdaptor(T*) -> TensorAdaptor<T, void>;
SmallVector<Tensor*> to_raw_ptr_array(
        const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) {
    SmallVector<Tensor*> ret;
    for (auto&& i : inputs) {
        mgb_assert(i);
        ret.push_back(i.get());
        if (ensure_storage) {
            // apply lazy allocation
            i->blob()->storage();
        }
    }
    return ret;
}

static size_t get_workspace_limit(CompNode cn, size_t old_limit) {
    size_t free = cn.get_free_mem();
    size_t lmt = cn.get_max_block_size_available();
    return std::max(lmt, free);
}
// single opr graph, for static inference and execution
// contains static inference descs
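// A MiniGraph wraps exactly one OperatorNodeBase, built by applying an OpDef
// on InputPlaceholder vars. It records the operator's static-infer descriptors
// (shape and value) so they can be re-evaluated for concrete inputs without
// building a full ComputingGraph, and it is cached and reused across calls
// that share the same OpDef hash and input dtypes/comp_nodes (see
// ProxyGraphTypeI::get_cached_minigraph below).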
class ProxyGraph::MiniGraph {
protected:
    struct InferDepItem {
        bool is_input : 1;
        size_t idx : 63;
        cg::static_infer::DepType type;
    };

    enum class InferStatus { UNKOWN, READY, FAILED };

    // inference desc and pre-allocated storage for a single var
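    // InferData memoizes one result per inference session: `status` starts as
    // UNKOWN, becomes READY or FAILED on the first query, and is cleared by
    // reset() when the session ends, so repeated queries within a session do
    // not re-run infer_func.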
    template <typename T>
    struct InferData {
        SmallVector<InferDepItem> deps;
        thin_function<bool(T&, const cg::static_infer::InpVal&)> infer_func;

        // pre-allocated infer states
        InferStatus status = InferStatus::UNKOWN;
        cg::static_infer::InpVal inp_val;
        T dest;

        void initialize(
                OperatorNodeBase* opr, const cg::static_infer::DepVal& dep_val,
                const thin_function<bool(T&, const cg::static_infer::InpVal&)>& func) {
            mgb_assert(!infer_func);
            infer_func = func;
            inp_val.val.resize(dep_val.size());
            deps.reserve(dep_val.size());
            for (auto&& dep : dep_val) {
                auto [found, i] = find_index(opr->input(), dep.dest);
                if (found) {
                    deps.push_back({true, i, dep.type});
                } else {
                    auto [found, i] = find_index(opr->output(), dep.dest);
                    mgb_assert(found);
                    deps.push_back({false, i, dep.type});
                }
            }
        }

        void reset() {
            status = InferStatus::UNKOWN;
            if constexpr (std::is_same_v<T, TensorShape>) {
                dest.ndim = 0;
            } else {
                static_assert(std::is_same_v<T, DeviceTensorND>);
                dest.storage({});
            }
        }
    };
    struct OutputData {
        InferData<TensorShape> shape_infer;
        InferData<DeviceTensorND> value_infer;
    };

    struct InferSessionBase {
        virtual const TensorShape& infer_shape(VarNode*) { mgb_assert(0); }
        virtual const TensorShape* infer_shape_fallible(VarNode*) { mgb_assert(0); }
        virtual const DeviceTensorND& infer_value(VarNode*) { mgb_assert(0); }
        virtual const DeviceTensorND* infer_value_fallible(VarNode*) { mgb_assert(0); }
    };

    size_t buf_size;
    SmallVector<size_t> hash_buf;

    OperatorNodeBase* m_opr = nullptr;
    SmallVector<std::unique_ptr<OperatorNodeBase>> opr_ref_keeper;

    size_t run_id = 0;
    SmallVector<OutputData> output_data;
    SmallVector<size_t> input_remap;
    SmallVector<size_t> output_remap;

    // pre-allocated buffer for converted inputs
    SmallVector<std::optional<DeviceTensorND>> input_value_storage;

    InferSessionBase* m_sess = nullptr;
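
    // InputAdaptor exposes the inputs of one inference session as shapes and
    // lazily converted host values. For physical Tensor inputs the converted
    // DeviceTensorND is cached in input_value_storage and dropped when the
    // adaptor is destroyed; for LogicalTensorDesc inputs the stored value is
    // used directly when its shape is valid.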
    template <typename T>
    struct InputAdaptor {
        T& wrapped;
        SmallVector<std::optional<DeviceTensorND>>& value_storage;

        InputAdaptor(MiniGraph& owner, T& inputs)
                : wrapped(inputs), value_storage(owner.input_value_storage) {}
        ~InputAdaptor() {
            for (auto& i : value_storage) {
                i.reset();
            }
        }

        const TensorShape* shape(size_t i) {
            TensorAdaptor tensor(wrapped[i]);
            auto& shape = tensor.shape();
            return shape.ndim ? &shape : nullptr;
        }

        const DeviceTensorND* value(size_t i, bool sync) {
            TensorAdaptor tensor(wrapped[i]);
            using tensor_t = std::remove_cv_t<typename decltype(tensor)::type>;
            if constexpr (std::is_same_v<tensor_t, Tensor>) {
                auto& storage = value_storage[i];
                if (!storage) {
                    if (sync) {
                        return &storage.emplace(
                                tensor->get_value().proxy_to_default_cpu());
                    } else {
                        if (auto* hv = tensor->try_get_value()) {
                            return &storage.emplace(hv->proxy_to_default_cpu());
                        }
                        return nullptr;
                    }
                }
                return &storage.value();
            } else {
                auto& value = tensor.value();
                return value.shape_valid() ? &value : nullptr;
            }
        }
    };

public:
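    // Build the single-operator graph for `opdef`: create one InputPlaceholder
    // per input, apply the OpDef on the resulting vars, collect the operator's
    // static-infer descriptors, and record the input/output index remaps plus
    // the hash buffer later used for cache lookup.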
    template <typename I, typename G>
    MiniGraph(
            G& graph, const OpDef& opdef, const I& inputs, const size_t* hash_buf_,
            const size_t buf_size_)
            : buf_size(buf_size_), input_value_storage(inputs.size()) {
        mgb_assert(!m_opr);
        auto _ = graph.scoped_attach(this);
        cg::VarNodeArray vinputs(inputs.size());
        for (auto&& [i, t] : ranges::views::enumerate(inputs)) {
            auto tensor = TensorAdaptor(t);
            opr_ref_keeper.emplace_back(
                    new InputPlaceholder(graph, tensor.dtype(), tensor.comp_node()));
            vinputs[i] = opr_ref_keeper.back()->output(0);
        }
        auto ovars = OpDef::apply_on_var_node(opdef, vinputs);
        mgb_assert(m_opr);
        output_data.resize(m_opr->output().size());
        for (auto* v : ovars) {
            mgb_assert(v->owner_opr() == m_opr);
        }
        m_opr->init_output_static_infer_desc();

        // fix permuted input: the order of m_opr->input() and vinputs may differ;
        // input_remap maps each index of m_opr->input() to the corresponding
        // index in vinputs
        input_remap.reserve(m_opr->input().size());
        for (auto* v : m_opr->input()) {
            auto [found, i] = find_index(vinputs, v);
            mgb_assert(found);
            input_remap.push_back(i);
        }
        auto fix_dep_idx = [&](SmallVector<InferDepItem>& deps) {
            for (auto& dep : deps) {
                if (dep.is_input) {
                    dep.idx = input_remap[dep.idx];
                }
            }
        };
        for (auto& data : output_data) {
            fix_dep_idx(data.shape_infer.deps);
            fix_dep_idx(data.value_infer.deps);
        }

        // fix permuted output
        output_remap.reserve(ovars.size());
        for (auto* v : ovars) {
            auto [found, i] = find_index(m_opr->output(), v);
            mgb_assert(found);
            output_remap.push_back(i);
        }

        hash_buf.resize(buf_size);
        for (size_t i = 0; i < buf_size; ++i) {
            hash_buf[i] = hash_buf_[i];
        }
    }
    bool is_same_buf(const size_t hash_buf_[], const size_t buf_size_) {
        if (buf_size != buf_size_) {
            return false;
        }
        for (size_t i = 0; i < buf_size; i++) {
            if (hash_buf[i] != hash_buf_[i]) {
                return false;
            }
        }
        return true;
    }

    // methods for containing graph
    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) {
        mgb_assert(!m_opr);
        m_opr = opr_uniqp.get();
        mgb_assert(opr_ref_keeper.back()->owner_graph() == m_opr->owner_graph());
        mgb_assert(!m_opr->inserted_in_graph());
        opr_ref_keeper.push_back(std::move(opr_uniqp));
        m_opr->set_inserted_in_graph();
        m_opr->init_output_comp_node();
        m_opr->init_output_dtype();
        return m_opr;
    }

    void init_input_tensor(const SmallVector<Tensor*>& inputs) {
        auto&& opr_inputs = m_opr->input();
        mgb_assert(opr_inputs.size() == inputs.size());
        size_t idx = 0;
        for (auto&& input : opr_inputs) {
            mgb_assert(input->owner_opr()->same_type<InputPlaceholder>());
            input->m_dev_tensor.storage({});
            auto&& dev_tensor = inputs[input_remap[idx]]->dev_tensor();
            auto&& layout = dev_tensor.layout();
            input->shape(dev_tensor.shape());
            auto&& chk = input->m_mem_plan.reset_from_owner_var().chunk();
            input->m_dev_tensor.reset(dev_tensor.storage(), layout);
            input->m_mem_plan.layout(layout);
            chk.mem_alloc_status.set_from_owner_var();
            mgb_assert(input->comp_node() == dev_tensor.comp_node());
            mgb_assert(input->shape().eq_shape(layout));
            mgb_assert(input->dtype() == layout.dtype);
            idx++;
        }
    }
    void init_output_tensor(const SmallVector<Tensor*>& outputs) {
        mgb_assert(m_opr->usable_output().size() == outputs.size());
        ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
                m_opr->owner_graph(), get_workspace_limit);
        size_t j = 0;
        for (auto&& var : m_opr->output()) {
            auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
            if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                TensorLayout layout{var->shape(), var->dtype(), var->format()};
                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
                        var->comp_node(), layout);
            } else {
                mgb_assert(j < outputs.size());
                auto&& tensor = outputs[j];
                auto&& layout = tensor->layout();
                mgb_assert(var->comp_node() == tensor->comp_node());
                mgb_assert(var->shape().eq_shape(layout));
                mgb_assert(var->dtype() == layout.dtype);
                var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
                ++j;
            }
            chk.mem_alloc_status.set_from_owner_var();
        }
        mgb_assert(j == outputs.size());

        // Memory forwarding is bypassed in megbrain when the graph option
        // imperative_proxy_graph is on; here we call mem_plan_fwd_in2out_readonly
        // to initialize the internal state of some oprs (e.g. Subtensor).
        // TODO: implement memory forwarding
        m_opr->mem_plan_fwd_in2out_readonly();
        {
            // some oprs (e.g. Reduce) rely on on_mem_status_changed to set
            // input/output tensors correctly; since we bypass var_node_mem_mgr,
            // on_mem_status_changed should be called here
            auto&& cb = m_opr->get_opr_event_callback().on_mem_status_changed;
            if (cb.valid()) {
                cb.val()();
            }
        }
    }
    void execute(
            const SmallVector<Tensor*>& inputs, const SmallVector<Tensor*>& outputs,
            cg::GraphExecutable::ExecEnv& env) {
        init_input_tensor(inputs);
        init_output_tensor(outputs);
        m_opr->execute(env);
        for (auto&& i : m_opr->input()) {
            i->m_dev_tensor.storage({});
        }
        for (auto&& i : m_opr->output()) {
            i->m_dev_tensor.storage({});
        }
    }

    void register_shape_infer(
            VarNode* varnode, const cg::static_infer::ShapeInferDesc& desc) {
        auto [found, i] = find_index(m_opr->output(), varnode);
        mgb_assert(found);
        output_data[i].shape_infer.initialize(m_opr, desc.deps, desc.infer_func);
    }

    void register_value_infer(
            VarNode* varnode, const cg::static_infer::ValueInferDesc& desc) {
        auto [found, i] = find_index(m_opr->output(), varnode);
        mgb_assert(found);
        output_data[i].value_infer.initialize(m_opr, desc.deps, desc.infer_func);
    }

    const TensorShape& infer_shape(VarNode* var) {
        mgb_assert(m_sess);
        return m_sess->infer_shape(var);
    }

    const DeviceTensorND& infer_value(VarNode* var) {
        mgb_assert(m_sess);
        return m_sess->infer_value(var);
    }

    OperatorNodeBase* opr() { return m_opr; }

    // inference routine template for type of input
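    // An InferSession holds the per-call inference state; at most one session
    // may be alive per MiniGraph (m_sess). Sketch of typical use, based on the
    // callers below (illustrative only):
    //     auto sess = minigraph.infer_session(inputs);
    //     if (auto* shape = sess.infer_shape(/*i=*/0, /*sync=*/false)) { ... }
    // With sync == false, inference gives up instead of waiting for device
    // values; with sync == true it blocks and asserts on failure.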
    template <typename I>
    class InferSession : protected InferSessionBase {
    public:
        MiniGraph& owner;
        SmallVector<OutputData>& output_data;
        InputAdaptor<I> inputs;

        template <typename T>
        const T* infer(InferData<T>& target, bool sync) {
            bool ret;
            if (target.status != InferStatus::UNKOWN) {
                ret = target.status == InferStatus::READY;
            } else {
                ret = target.infer_func && do_infer(target, sync);
                target.status = ret ? InferStatus::READY : InferStatus::FAILED;
            }
            return ret ? &target.dest : nullptr;
        }

        template <typename T>
        bool do_infer(InferData<T>& target, bool sync) {
            for (size_t i = 0; i < target.deps.size(); ++i) {
                target.inp_val.run_id = owner.run_id;
                auto& dep = target.deps[i];
                if (dep.is_input) {
                    if (dep.type == cg::static_infer::DepType::SHAPE) {
                        if (auto* val = inputs.shape(dep.idx)) {
                            target.inp_val.val[i].m_shape = val;
                        } else
                            return false;
                    } else {
                        if (auto* val = inputs.value(dep.idx, sync)) {
                            target.inp_val.val[i].m_value = val;
                        } else
                            return false;
                    }
                } else {
                    if (dep.type == cg::static_infer::DepType::SHAPE) {
                        if (auto* val = infer(output_data[dep.idx].shape_infer, sync)) {
                            target.inp_val.val[i].m_shape = val;
                        } else
                            return false;
                    } else {
                        if (auto* val = infer(output_data[dep.idx].value_infer, sync)) {
                            target.inp_val.val[i].m_value = val;
                        } else
                            return false;
                    }
                }
            }
            return target.infer_func(target.dest, target.inp_val);
        }
        // methods for owner mini graph
        // corresponding methods of the containing ComputingGraph are redirected here
        const TensorShape& infer_shape(VarNode* var) override {
            mgb_assert(owner.m_opr);
            auto [found, i] = find_index(owner.m_opr->input(), var);
            mgb_assert(found);
            i = owner.input_remap[i];
            auto* shape = inputs.shape(i);
            mgb_assert(shape);
            return *shape;
        }

        const DeviceTensorND& infer_value(VarNode* var) override {
            mgb_assert(owner.m_opr);
            auto [found, i] = find_index(owner.m_opr->input(), var);
            mgb_assert(found);
            i = owner.input_remap[i];
            auto* value = inputs.value(i, true);
            mgb_assert(value);
            return *value;
        }

    public:
        InferSession(MiniGraph& mgraph, I& inputs_)
                : owner(mgraph),
                  output_data(mgraph.output_data),
                  inputs(mgraph, inputs_) {
            mgraph.run_id++;
            mgb_assert(!owner.m_sess);
            owner.m_sess = this;
        }
        ~InferSession() {
            owner.m_sess = nullptr;
            for (auto& i : output_data) {
                i.shape_infer.reset();
                i.value_infer.reset();
            }
        }

        const TensorShape* infer_shape(size_t i, bool sync) {
            i = owner.output_remap[i];
            auto* p = infer(output_data[i].shape_infer, sync);
            if (sync)
                mgb_assert(p, "failed to infer shape");
            return p;
        }

        const DeviceTensorND* infer_value(size_t i, bool sync) {
            i = owner.output_remap[i];
            auto* p = infer(output_data[i].value_infer, sync);
            if (sync)
                mgb_assert(p, "failed to infer value");
            return p;
        }
    };
    template <typename T>
    InferSession<T> infer_session(T& inputs) {
        return InferSession(*this, inputs);
    }

    size_t output_size() { return output_remap.size(); }

    VarNode* output_var(size_t i) {
        i = output_remap[i];
        return m_opr->output(i);
    }
};
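
// CompNodeTracker keeps a ring of bucket_count events per comp node and hands
// out up to bucket_size slots between event records; progress() reports how
// many recorded events have already finished, which ExecMiniGraph below uses
// to decide when a previously released operator may be reused.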
class CompNodeTracker {
    static constexpr size_t bucket_size = 100;
    static constexpr size_t bucket_count = 10;

    CompNode comp_node;
    std::array<std::unique_ptr<CompNode::Event>, bucket_count> events;
    size_t free_slots = bucket_size;
    size_t head = 0;  // events[head] is not recorded
    size_t tail = 0;  // events[tail] is not finished

    void rotate() {
        while (tail < head && events[tail % bucket_count]->finished()) {
            ++tail;
        }
        auto& ev = events[head % bucket_count];
        if (head == tail + bucket_count) {
            // do not wait if head == tail
            ev->host_wait();
            ++tail;
        }
        ev->record();
        ++head;
        free_slots = bucket_size;
    }

public:
    CompNodeTracker(CompNode cn) : comp_node(cn) {
        for (auto& e : events) {
            e = cn.create_event();
        }
    }

    size_t add_opr() {
        if (!free_slots)
            rotate();
        --free_slots;
        return head;
    }

    size_t progress() { return tail; }
};
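
// ExecMiniGraph appears intended to recycle operator nodes across executions,
// using CompNodeTracker to know when an opr's last run has finished:
// acquire_opr() takes an idle opr, or reclaims the oldest busy one once every
// tracker has progressed past its recorded finish time. As written the class
// looks unfinished (e.g. acquire_opr() has no return on the
// not-yet-reclaimable path).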
class ExecMiniGraph : public ProxyGraph::MiniGraph {
    union BusyListItem {
        size_t finish_time;
        OperatorNodeBase* opr;
    };

    SmallVector<CompNodeTracker*> comp_node_trackers;
    std::deque<BusyListItem> busy_oprs;
    SmallVector<OperatorNodeBase*> idle_oprs;

    OperatorNodeBase* acquire_opr() {
        mgb_assert(!m_opr);
        if (!idle_oprs.empty()) {
            m_opr = idle_oprs.back();
            idle_oprs.pop_back();
            return m_opr;
        }
        mgb_assert(busy_oprs.size() > comp_node_trackers.size());
        bool can_pop = true;
        for (auto [item, tracker] : ranges::views::zip(busy_oprs, comp_node_trackers)) {
            if (item.finish_time >= tracker->progress()) {
                can_pop = false;
                break;
            }
        }
        if (can_pop) {
            for (auto _ : comp_node_trackers) {
                MGB_MARK_USED_VAR(_);
                busy_oprs.pop_front();
            }
            m_opr = busy_oprs.front().opr;
            busy_oprs.pop_front();
            return m_opr;
        }
    }

    template <bool in_use>
    void release_opr() {
        if constexpr (in_use) {
            for (auto tracker : comp_node_trackers) {
                tracker->add_opr();
            }
        }
    }
};
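
// ProxyGraphTypeI is a thread-local pseudo ComputingGraph used to evaluate
// single ops eagerly. It owns the minigraph cache and temporarily "attaches" a
// MiniGraph so that opr insertion and static-infer registration are redirected
// into it. A hypothetical caller (variable names here are assumptions, not
// part of this header) would look roughly like:
//     auto& graph = ProxyGraphTypeI::inst();
//     auto [descs, ok] = graph.infer_output_attrs_fallible(def, input_descs);
//     auto outputs = graph.apply_on_physical_tensor(def, input_tensors);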
class ProxyGraphTypeI : public ProxyGraphBase {
    class StaticInferManager : public StaticInferManagerBase {
        ProxyGraph::MiniGraph* target = nullptr;

        friend class ProxyGraphTypeI;

    public:
        void register_shape_infer(
                VarNode* var, const cg::static_infer::ShapeInferDesc& desc) override {
            mgb_assert(target);
            target->register_shape_infer(var, desc);
        };
        void register_value_infer(
                VarNode* var, const cg::static_infer::ValueInferDesc& desc) override {
            mgb_assert(target);
            target->register_value_infer(var, desc);
        };
        cg::static_infer::InferType get_infer_type(VarNode*) override {
            return {cg::static_infer::InferType::MISSING_INP,
                    cg::static_infer::InferType::MISSING_INP};
        }

        // some poorly written inference funcs would call infer_{shape,value}
        const TensorShape& infer_shape(VarNode* var) override {
            mgb_assert(target);
            return target->infer_shape(var);
        }
        const DeviceTensorND& infer_value(VarNode* var) override {
            mgb_assert(target);
            return target->infer_value(var);
        }
    };
    ProxyGraph::MiniGraph* target = nullptr;
    StaticInferManager m_static_infer_manager;
    std::unordered_multimap<size_t, ProxyGraph::MiniGraph> m_mini_graph_cache;
    std::mutex m_mini_graph_cache_mtx;
    size_t opr_count = 0;
    ExecEnvBase m_env;
    CompNode::UnorderedSet m_used_comp_node;

    static thread_local std::unique_ptr<ProxyGraphTypeI> sm_instance;

    friend class ProxyGraph::MiniGraph;

    size_t nr_oprs_in_graph() const override { return opr_count; }
    size_t next_node_id() override { return opr_count; }

    void add_used_comp_node(CompNode cn) { m_used_comp_node.insert(cn); }

    std::shared_ptr<void> on_comp_node_finalize() override {
        assert(!target);
        MGB_LOCK_GUARD(m_mini_graph_cache_mtx);
        m_mini_graph_cache.clear();
        return {};
    }

    cg::static_infer::StaticInferManager& static_infer_manager() override {
        return m_static_infer_manager;
    }

    void attach(ProxyGraph::MiniGraph* target_) {
        target = target_;
        m_static_infer_manager.target = target_;
    }

    struct AttachGuard {
        ProxyGraphTypeI* owner = nullptr;
        ProxyGraph::MiniGraph* target = nullptr;

        AttachGuard(
                ProxyGraphTypeI* owner_ = nullptr,
                ProxyGraph::MiniGraph* target_ = nullptr)
                : owner(owner_), target(target_) {}
        AttachGuard(AttachGuard&) = delete;
        AttachGuard& operator=(AttachGuard&) = delete;
        AttachGuard(AttachGuard&& rhs) : owner(rhs.owner), target(rhs.target) {
            rhs.owner = nullptr;
        }
        AttachGuard& operator=(AttachGuard&& rhs) = delete;
        ~AttachGuard() {
            if (owner)
                owner->attach(target);
        }
    };

    [[nodiscard]] AttachGuard scoped_attach(ProxyGraph::MiniGraph* target_) {
        attach(target_);
        return attach_guard();
    }

    [[nodiscard]] AttachGuard attach_guard(ProxyGraph::MiniGraph* target_ = nullptr) {
        return {this, target_};
    }
public:
    ~ProxyGraphTypeI() {
        if (is_finalized()) {
            return;
        }
        for (auto&& i : m_used_comp_node) {
            if (i.device_type() == CompNode::DeviceType::CUDA)
                continue;
            i.sync();
        }
    }

    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
        mgb_assert(target);
        return target->insert_opr(std::move(opr_uniqp));
    }

    static ProxyGraphTypeI& inst() {
        if (!sm_instance || sm_instance->is_finalized()) {
            sm_instance.reset(new ProxyGraphTypeI);
        }
        return *sm_instance;
    }
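
    // The minigraph cache key is an XXHash over [hash(OpDef), then hash(dtype)
    // and hash(comp_node) for each input]; the raw buffer is also stored in
    // each MiniGraph (hash_buf) so that is_same_buf() can resolve hash
    // collisions within an unordered_multimap bucket.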
    template <typename T>
    ProxyGraph::MiniGraph& get_cached_minigraph(const OpDef& def, const T& inputs) {
        mgb_assert(!is_finalized());
        size_t buf_size = 2 * inputs.size() + 1;
        size_t buf[buf_size];
        size_t pos = 0;
        buf[pos++] = def.hash();
        for (auto&& inp : inputs) {
            auto tensor = TensorAdaptor(inp);
            buf[pos++] = mgb::hash(tensor.dtype().handle());
            buf[pos++] = mgb::hash(tensor.comp_node());
        }
        mgb_assert(pos == buf_size);
        auto key = XXHash{}.update(buf, buf_size * sizeof(size_t)).digest();
        auto its = m_mini_graph_cache.equal_range(key);
        auto it = its.first;
        for (; it != its.second; ++it) {
            if (it->second.is_same_buf(buf, buf_size)) {
                return it->second;
            }
            mgb_log_warn("hash collision occurs in minigraph cache with key: %lu", key);
        }
        auto&& result = m_mini_graph_cache.emplace(
                std::piecewise_construct, std::make_tuple(key),
                std::forward_as_tuple(
                        *this, def, inputs, static_cast<size_t*>(buf), buf_size));
        mgb_assert(result->first);
        return result->second;
    }
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
        auto& minigraph = get_cached_minigraph(def, inputs);
        auto _ = scoped_attach(&minigraph);
        auto sess = minigraph.infer_session(inputs);
        std::tuple<SmallVector<LogicalTensorDesc>, bool> ret;
        auto& [descs, noerr] = ret;
        descs.reserve(minigraph.output_size());
        for (size_t i = 0; i < minigraph.output_size(); ++i) {
            descs.emplace_back();
            auto& desc = descs.back();
            desc.layout.dtype = minigraph.output_var(i)->dtype();
            desc.layout.format = minigraph.output_var(i)->format();
            desc.comp_node = minigraph.output_var(i)->comp_node();
            if (auto* shape = sess.infer_shape(i, false)) {
                desc.layout.init_contiguous_stride(*shape);
                noerr = true;
            } else {
                noerr = false;
            }
        }
        return ret;
    }
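
    // Execution path: look up (or build) the MiniGraph, synchronously infer
    // all output shapes, allocate output tensors, make each output comp node
    // wait on the events of inputs living on other comp nodes, run the opr,
    // then register release callbacks so inputs are not freed before that
    // cross-comp-node work completes.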
    SmallVector<TensorPtr> apply_on_physical_tensor(
            const OpDef& def, SmallVector<TensorPtr> inputs) {
        auto raw_inputs = to_raw_ptr_array(inputs);
        auto& minigraph = get_cached_minigraph(def, raw_inputs);
        auto _ = scoped_attach(&minigraph);
        auto sess = minigraph.infer_session(raw_inputs);
        ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
                minigraph.opr()->owner_graph(), get_workspace_limit);
        // some output vars in minigraph.opr()->output() may not appear in
        // minigraph.opr()->usable_output(), but execution may use the attrs of
        // those output vars, so we infer attrs for all outputs while only
        // returning LogicalTensorDesc for minigraph.opr()->usable_output()
        SmallVector<LogicalTensorDesc> output_descs;
        for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
            auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
            mgb_assert(shape);
            minigraph.opr()->output()[i]->shape(*shape);
        }
        output_descs.reserve(minigraph.output_size());
        for (size_t i = 0; i < minigraph.output_size(); ++i) {
            auto* ovar = minigraph.output_var(i);
            mgb_assert(ovar->dtype().valid() && ovar->comp_node().valid());
            mgb_assert(
                    ovar->shape().ndim ||
                    ovar->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
            output_descs.push_back({{ovar->shape(), ovar->dtype()}, ovar->comp_node()});
        }

        SmallVector<TensorPtr> outputs(output_descs.size(), {});
        for (size_t i = 0; i < outputs.size(); i++) {
            outputs[i] =
                    Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
        }
        auto raw_outputs = to_raw_ptr_array(outputs);
        CompNode::UnorderedSet used_cns;
        for (auto&& out : raw_outputs) {
            auto cn = out->comp_node();
            add_used_comp_node(cn);
            if (used_cns.insert(cn).second) {
                for (auto&& in : inputs) {
                    if (in->comp_node() != cn) {
                        auto&& e = in->get_or_create_event();
                        e->device_wait_by(cn);
                    }
                }
            }
        }
        // some oprs (e.g. Subtensor) may invoke infer_value during execution,
        // so the inference session created above needs to stay alive here
        minigraph.execute(raw_inputs, raw_outputs, m_env);
        for (auto&& cn : used_cns) {
            for (auto&& in : inputs) {
                if (in->comp_node() != cn) {
                    in->add_release_callback(cn);
                }
            }
        }
        return outputs;
    }
};

} // namespace mgb::imperative::proxy_graph