
utility.cpp 30 kB

/**
 * \file imperative/src/impl/ops/utility.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include <atomic>
#include <deque>

#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#if MGB_JIT
#include "megbrain/jit/executor_opr.h"
#endif

#include "../event_pool.h"
#include "../op_trait.h"

namespace mgb::imperative {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);

OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback();

namespace {
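// FastpathCopy acts as an identity on var nodes: apply_on_var_node returns
// the inputs unchanged, and the backward graph below has no exprs of its own
// and simply forwards its third input (which, by the usual backward-graph
// input layout, appears to be the incoming output gradient) as its only
// output.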
namespace fastpathcopy {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return inputs;
}

auto make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    Subgraph graph;
    graph.inputs = {1, 2, 3};
    graph.outputs = {3};
    graph.exprs = {};
    return EncodedSubgraph::make(graph);
}

OP_TRAIT_REG(FastpathCopy, FastpathCopy)
        .apply_on_var_node(apply_on_var_node)
        .make_backward_graph(make_backward_graph)
        .fallback();

} // namespace fastpathcopy
} // namespace

namespace {
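// ShapeInfer evaluates only the output *shapes* of a wrapped op: the inputs
// are described by the (device, dtype) pairs stored on the op, each input is
// replaced by a placeholder allocation, and every output is converted into a
// 1-d Int32 tensor holding the inferred shape.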
namespace shape_infer {

auto apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    mgb_assert(validated, "shape inference incomplete");
    SmallVector<TensorPtr> outputs;
    for (auto&& output_desc : output_descs) {
        HostTensorND shape_tensor{
                output_desc.comp_node, {output_desc.layout.ndim}, dtype::Int32()};
        for (size_t i = 0; i < output_desc.layout.ndim; ++i) {
            shape_tensor.ptr<int32_t>()[i] = output_desc.layout[i];
        }
        auto output = Tensor::make(shape_tensor);
        outputs.push_back(output);
    }
    return outputs;
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    size_t nr_inputs = inputs.size();
    VarNodeArray input_values, outputs;
    mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
    for (size_t i = 0; i < nr_inputs; ++i) {
        auto input_value =
                opr::Alloc::make(SymbolVar(inputs[i]), op.dtypes[i], {op.devices[i]});
        input_values.push_back(input_value.node());
    }
    auto output_values = OpDef::apply_on_var_node(*op.op, input_values);
    for (auto&& output_value : output_values) {
        outputs.push_back(opr::GetVarShape::make(output_value).node());
    }
    return outputs;
}

auto infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    SmallVector<LogicalTensorDesc> input_shape_descs;
    size_t nr_inputs = op.devices.size();
    mgb_assert(
            op.dtypes.size() == nr_inputs,
            "number of input devices and dtypes mismatch");
    for (size_t i = 0; i < nr_inputs; ++i) {
        LogicalTensorDesc input_shape_desc;
        input_shape_desc.comp_node = op.devices[i];
        input_shape_desc.layout.ndim = 0;
        input_shape_desc.layout.dtype = op.dtypes[i];
        input_shape_descs.push_back(input_shape_desc);
    }
    auto [output_shape_descs, _] =
            OpDef::infer_output_attrs_fallible(*op.op, input_shape_descs);
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& output_shape_desc : output_shape_descs) {
        LogicalTensorDesc output_desc;
        output_desc.comp_node = output_shape_desc.comp_node;
        output_desc.layout.ndim = 1;
        output_desc.layout.dtype = dtype::Int32();
        output_descs.push_back(output_desc);
    }
    return std::make_tuple(output_descs, false);
}

auto props(const OpDef& def) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    return OpDef::props(*op.op);
}

auto make_name(const OpDef& def) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    MGB_MARK_USED_VAR(op);
    return ssprintf("ShapeInfer[%s]", op.op->make_name().c_str());
}

auto hash(const OpDef& def) {
    auto& op = def.cast_final_safe<ShapeInfer>();
    return op.op->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<ShapeInfer>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<ShapeInfer>();
    auto& rhs = another.cast_final_safe<ShapeInfer>();
    if (!lhs.op->is_same(*rhs.op)) {
        return false;
    }
    return std::tie(lhs.devices, lhs.dtypes) == std::tie(rhs.devices, rhs.dtypes);
}

OP_TRAIT_REG(ShapeInfer, ShapeInfer)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .make_name(make_name)
        .props(props)
        .hash(hash)
        .is_same_st(is_same_st)
        .fallback();

} // namespace shape_infer
} // namespace

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShapeInfer);

namespace {
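// Identity forwards its single input unchanged; on physical tensors the
// input TensorPtr is returned directly, without a copy.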
namespace identity {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = def.cast_final_safe<Identity>();
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::Identity::make(inputs[0], config);
}

auto apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return SmallVector<TensorPtr>{inputs[0]};
}

OP_TRAIT_REG(Identity, Identity)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();

} // namespace identity
} // namespace

namespace {
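// SubgraphOp executes a user-supplied Subgraph. Its backward graph is built
// by subgraph_detail and wrapped into a new SubgraphOp named "<name>Grad",
// with output gradients masked by output_grad_mask.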
namespace subgraph {

EncodedSubgraph make_forward_graph(
        const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
    return EncodedSubgraph::make(*def.cast_final_safe<SubgraphOp>().graph);
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        SmallVector<bool> output_has_grad) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    mgb_assert(output_has_grad.size() == op.output_grad_mask.size());
    for (size_t i = 0; i < output_has_grad.size(); ++i) {
        if (!op.output_grad_mask[i]) {
            output_has_grad[i] = false;
        }
    }
    auto bgraph = subgraph_detail::make_backward_graph(
            def, inputs, input_requires_grad, output_has_grad);
    return EncodedSubgraph::make_single(
            SubgraphOp::make(
                    op.name + "Grad", std::make_shared<Subgraph>(bgraph.graph)),
            bgraph.input_mask, bgraph.output_mask);
}

std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    return {
            {"name", op.name},
            {"inputs", mgb::imperative::to_string(op.graph->inputs)},
            {"exprs", mgb::imperative::to_string(op.graph->exprs)},
            {"outputs", mgb::imperative::to_string(op.graph->outputs)},
    };
}

std::string make_name(const OpDef& def) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    if (op.name.empty()) {
        return "SubgraphOp";
    } else {
        return op.name;
    }
}

auto hash(const OpDef& def) {
    auto& op = def.cast_final_safe<SubgraphOp>();
    if (!op.graph_key) {
        return (size_t) reinterpret_cast<uintptr_t>(op.graph.get());
    }
    return op.graph_key->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<SubgraphOp>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<SubgraphOp>();
    auto& rhs = another.cast_final_safe<SubgraphOp>();
    auto has_graph_key = bool(lhs.graph_key);
    bool graph_same = false;
    if (has_graph_key) {
        graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
    } else {
        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
    }
    return graph_same;
}

OP_TRAIT_REG(SubgraphOp, SubgraphOp)
        .make_forward_graph(make_forward_graph)
        .make_backward_graph(make_backward_graph)
        .props(props)
        .make_name(make_name)
        .hash(hash)
        .is_same_st(is_same_st)
        .fallback();

} // namespace subgraph
} // namespace

namespace {
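// CompiledOp lowers its inner op onto a dedicated ComputingGraph that is
// compiled once per (op, input descs) key and replayed afterwards. Two holder
// kinds are built below: HolderKind::ShapeInfer graphs answer shape/value
// queries through the static-infer manager, while HolderKind::Execute graphs
// run the compiled executable on real device tensors.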
namespace compiled_op {

struct DeviceMemoryAllocatorImpl : cg::DeviceMemoryAllocator {
    std::shared_ptr<OpDef> current_op;
    void alloc_static(
            ComputingGraph* graph, DeviceTensorStorage& dest, size_t size) override {
        mgb_assert(0, "alloc_static is not allowed in CompiledOp");
    }
    void alloc_dynamic(VarNode* var, DeviceTensorStorage& dest, size_t size) override {
        auto comp_node = var->comp_node();
        auto storage = current_op->allocate(comp_node, size);
        dest.reset(comp_node, size, storage);
    }
};

enum class HolderKind {
    ShapeInfer,
    Execute,
};
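// ComputingGraphHolder owns one compiled graph plus its I/O buffers and
// per-comp-node events. For Execute graphs the inputs are MutableTensor nodes
// backed by device values; for ShapeInfer graphs the inputs are fed as host
// shapes (or small host values) through Host2DeviceCopy.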
template <HolderKind Kind>
struct ComputingGraphHolder {
    struct Input {
        std::shared_ptr<DeviceTensorND> device_value;
        std::shared_ptr<HostTensorND> host_value;
        std::shared_ptr<HostTensorND> host_shape;
    };
    std::shared_ptr<ComputingGraph> graph;
    std::unique_ptr<cg::AsyncExecutable> executable;
    SmallVector<Input> inputs;
    SmallVector<std::shared_ptr<DeviceTensorND>> device_outputs;
    SmallVector<VarNode*> input_vars;
    SmallVector<VarNode*> output_vars;
    std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
    SmallVector<std::shared_ptr<CompNode::Event>> events;
    std::unique_ptr<cg::static_infer::StaticInferUpdater> updater;

    void initialize(
            const CompiledOp& op, const SmallVector<LogicalTensorDesc>& input_descs) {
        allocator = std::make_shared<DeviceMemoryAllocatorImpl>();
        graph = ComputingGraph::make();
        graph->options().force_dynamic_alloc = true;
        graph->options().async_exec_level = 0;
        graph->options().graph_opt_level = op.gopt_level;
        graph->options().enable_var_mem_defragment = false;
        graph->options().comp_seq_sync_device = false;
        // set allocator for DTR support
        graph->set_device_memory_allocator(allocator);
        if constexpr (Kind == HolderKind::ShapeInfer) {
            updater = cg::static_infer::StaticInferUpdater::make();
        }
        for (auto&& desc : input_descs) {
            Input input;
            VarNode* input_var = nullptr;
            if constexpr (Kind == HolderKind::Execute) {
                input.device_value = std::make_shared<DeviceTensorND>();
                input.device_value->dtype(desc.layout.dtype);
                input.device_value->comp_node(desc.comp_node);
                input.device_value->resize(desc.layout);
                auto callback = [value = input.device_value] { return *value; };
                if (!desc.value.empty()) {
                    input.host_value = std::make_shared<HostTensorND>();
                    input.host_value->dtype(desc.layout.dtype);
                    input.host_value->comp_node(desc.comp_node);
                }
                input_var = opr::MutableTensor::make(
                                    *graph, input.device_value, input.host_value, {})
                                    .node();
                // input_var = opr::VolatileSharedDeviceTensor::make(*graph,
                // input.device_value).node();
            } else if constexpr (Kind == HolderKind::ShapeInfer) {
                if (desc.value.empty()) {
                    input.host_shape = std::make_shared<HostTensorND>();
                    input.host_shape->dtype(dtype::Int32());
                    input.host_shape->comp_node(desc.comp_node);
                    auto input_shape_var =
                            opr::Host2DeviceCopy::make(*graph, input.host_shape);
                    input_var =
                            opr::Alloc::make(input_shape_var, desc.layout.dtype).node();
                } else {
                    input.host_value = std::make_shared<HostTensorND>();
                    input.host_value->dtype(desc.layout.dtype);
                    input.host_value->comp_node(desc.comp_node);
                    input_var =
                            opr::Host2DeviceCopy::make(*graph, input.host_value).node();
                }
            } else {
                static_assert((Kind != Kind), "unknown holder kind");
            }
            input_vars.push_back(input_var);
            inputs.push_back(input);
        }
        // forward to inner op
        output_vars = OpDef::apply_on_var_node(*op.op, input_vars);
        ComputingGraph::OutputSpec output_spec;
        CompNode::UnorderedSet comp_nodes;
        for (auto&& output_var : output_vars) {
            using namespace cg::static_infer;
            auto output_ptr = std::make_shared<DeviceTensorND>();
            auto callback = [output_ptr](DeviceTensorND output) {
                output_ptr->reset(output.storage(), output.layout());
                output = {};
            };
            if constexpr (Kind == HolderKind::ShapeInfer) {
                output_spec.push_back({output_var, callback});
                auto it = graph->static_infer_manager().get_infer_type(output_var);
                if (it.shape == InferType::RT_STATIC) {
                    updater->add_dest({output_var, DepType::SHAPE});
                }
                if (it.value == InferType::RT_STATIC) {
                    updater->add_dest({output_var, DepType::VALUE});
                }
            } else {
                auto output_callback_var =
                        opr::OutputCallback::make({callback}, output_var);
                output_spec.push_back({output_callback_var, {}});
            }
            device_outputs.push_back(output_ptr);
        }
        executable = graph->compile(output_spec);
        executable->iter_opr_seq([&](cg::OperatorNodeBase* opr) -> bool {
            for (auto&& output : opr->output()) {
                comp_nodes.insert(output->comp_node());
            }
            return true;
        });
        for (auto&& comp_node : comp_nodes) {
            events.push_back(EventPool::without_timer().alloc_shared(comp_node));
            events.back()->record();
        }
    }

    template <
            HolderKind ThisKind = Kind,
            typename = std::enable_if_t<ThisKind == HolderKind::Execute>>
    SmallVector<TensorPtr> apply_on_physical_tensor(
            const OpDef& def, const SmallVector<LogicalTensorDesc> input_descs,
            const SmallVector<TensorPtr>& input_tensors) {
        // wait for last execution
        executable->wait();
        size_t nr_inputs = inputs.size();
        for (size_t i = 0; i < nr_inputs; ++i) {
            auto input_dev_tensor = input_tensors[i]->dev_tensor();
            inputs[i].device_value->reset(
                    input_dev_tensor.storage(), input_dev_tensor.layout());
            if (inputs[i].host_value) {
                inputs[i].host_value->copy_from(input_descs[i].value);
            }
        }
        allocator->current_op = const_cast<OpDef&>(def).shared_from_this();
        executable->execute();
        for (auto&& event : events) {
            event->record();
        }
        SmallVector<TensorPtr> outputs_tensors;
        for (auto input : inputs) {
            *input.device_value = {};
            if (input.host_value) {
                *input.host_value = {};
            }
        }
        for (auto output_nd : device_outputs) {
            outputs_tensors.push_back(Tensor::make(*output_nd));
            *output_nd = {};
        }
        executable->clear_device_memory();
        allocator->current_op = nullptr;
        return outputs_tensors;
    }

    template <
            HolderKind ThisKind = Kind,
            typename = std::enable_if_t<ThisKind == HolderKind::ShapeInfer>>
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
            const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
        executable->wait();
        size_t nr_inputs = input_vars.size(), nr_outputs = output_vars.size();
        SmallVector<LogicalTensorDesc> output_descs(nr_outputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (inputs[i].host_shape) {
                DeviceTensorND input_shape_device_nd;
                cg::copy_shape_to_tensor_value(
                        input_shape_device_nd, input_descs[i].layout);
                inputs[i].host_shape->copy_from(input_shape_device_nd);
                mgb_assert(input_descs[i].layout.ndim, "ndim == 0");
            } else if (inputs[i].host_value) {
                inputs[i].host_value->copy_from(input_descs[i].value);
            }
        }
        updater->update();
        bool validated = true;
        for (size_t i = 0; i < nr_outputs; ++i) {
            auto infer_type =
                    graph->static_infer_manager().get_infer_type(output_vars[i]);
            const TensorShape* output_shape = nullptr;
            const DeviceTensorND* output_value = nullptr;
            auto& desc = output_descs[i];
            if (infer_type.shape != cg::static_infer::InferType::NO_DESC) {
                output_shape = graph->static_infer_manager().infer_shape_fallible(
                        output_vars[i]);
            }
            if (infer_type.value != cg::static_infer::InferType::NO_DESC) {
                output_value = graph->static_infer_manager().infer_value_fallible(
                        output_vars[i]);
            }
            if (output_shape && output_value) {
                mgb_assert(
                        output_shape->eq_shape(output_value->shape()),
                        "shape infer result mismatch, %s vs %s",
                        output_shape->to_string().c_str(),
                        output_value->shape().to_string().c_str());
            }
            if (output_shape) {
                ((TensorShape&)desc.layout) = *output_shape;
            }
            if (output_value) {
                ((TensorShape&)desc.layout) = output_value->shape();
                desc.value = *output_value;
            }
            desc.layout.dtype = output_vars[i]->dtype();
            desc.comp_node = output_vars[i]->comp_node();
            if (!desc.layout.ndim) {
                validated = false;
            }
            desc.layout.init_contiguous_stride();
        }
        return {output_descs, validated};
    }
};

static std::atomic<size_t> nr_cg_cache = 0;
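// get_computing_graph keeps a thread_local cache of holders keyed by
// (compiled op, input descs). A recycled holder is reused only when its
// previous run has finished (or after waiting, when the per-key queue already
// holds many graphs); otherwise it is left in the queue and a fresh holder is
// created. The nr_cg_cache counter above caps how many threads may create
// such caches.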
template <HolderKind Kind>
ComputingGraphHolder<Kind>& get_computing_graph(
        std::shared_ptr<OpDef> compiled_op,
        const SmallVector<LogicalTensorDesc>& descs) {
    using ComputingGraphHolderCache =
            OpMethResultCache<std::deque<std::unique_ptr<ComputingGraphHolder<Kind>>>>;
    thread_local auto& cache = ([]() -> auto& {
        mgb_assert(
                nr_cg_cache++ < 5,
                "using subgraph in too many threads, this causes resource leakage");
#if MGB_CUDA && defined(WIN32)
        // FIXME: create as a global to skip resource finalization; Windows
        // with CUDA does not clean up global resources
        return *ResourceManager::create_global<ComputingGraphHolderCache>();
#else
        // Otherwise this should be thread-local, because a comp node may be
        // unusable during global resource finalization.
        // For example, CpuCompNode.sync can hang because the underlying
        // thread has died.
        return *ResourceManager::create_local<ComputingGraphHolderCache>();
#endif
    })();
    thread_local size_t nr_cg_holders = 0;
    typename ComputingGraphHolderCache::key_t cache_key = {compiled_op, descs};
    auto& cg_holder_queue = cache[cache_key];
    std::unique_ptr<ComputingGraphHolder<Kind>> holder;
    if (!cg_holder_queue.empty()) {
        // pick one
        std::swap(cg_holder_queue.front(), holder);
        // check all events finished
        for (auto&& event : holder->events) {
            if (!event->finished()) {
                bool queue_limited =
                        event->comp_node().contain_flag(CompNode::Flag::QUEUE_LIMITED);
                bool many_graph = cg_holder_queue.size() > 10;
                if (queue_limited || !many_graph) {
                    std::swap(cg_holder_queue.front(), holder);
                    break;
                } else {
                    // graph limit
                    mgb_log_debug(
                            "computing graph limit for compiled op exceeded, waiting "
                            "for prev graph");
                    event->host_wait();
                }
            } else {
                event->host_wait();
            }
        }
        if (holder) {
            cg_holder_queue.pop_front();
        }
    }
    if (!holder) {
        // create new computing graph
        auto create_holder = [&] {
            auto holder = std::make_unique<ComputingGraphHolder<Kind>>();
            auto& cg_holder = *holder;
            cg_holder.initialize(compiled_op->cast_final_safe<CompiledOp>(), descs);
            nr_cg_holders++;
            mgb_log_debug(
                    "add new computing graph for compiled op, now %zu graphs",
                    nr_cg_holders);
            return holder;
        };
        size_t nr_graphs = std::max(cg_holder_queue.size(), (size_t)1);
        for (size_t i = 1; i < nr_graphs; ++i) {
            cg_holder_queue.push_front(create_holder());
        }
        holder = create_holder();
    }
    cg_holder_queue.push_back(std::move(holder));
    return *cg_holder_queue.back();
}
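// Physical-tensor execution for CompiledOp: describe the inputs (capturing
// host values only for tensors with at most MEGDNN_MAX_NDIM elements), fetch
// or build the Execute-kind graph for this (op, descs) signature, and run it.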
auto apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& input : inputs) {
        input_descs.push_back({input->layout(), input->comp_node()});
        if (auto* host_value = input->try_get_value()) {
            if (host_value->layout().total_nr_elems() <=
                MEGDNN_MAX_NDIM) {  // infer small tensor
                input_descs.back().value = host_value->proxy_to_default_cpu();
            }
        }
    }
    auto shared_def = const_cast<OpDef&>(def).shared_from_this();
    auto& cg_holder = get_computing_graph<HolderKind::Execute>(shared_def, input_descs);
    return cg_holder.apply_on_physical_tensor(def, input_descs, inputs);
}

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<CompiledOp>();
    op.op->set_scope(op.scope());
    return OpDef::apply_on_var_node(*op.op, inputs);
}

auto infer_output_attrs_fallible(
        const OpDef& def, SmallVector<LogicalTensorDesc> input_descs) {
    bool shape_all_valid = true;
    for (auto&& input_desc : input_descs) {
        if (!input_desc.layout.ndim) {
            shape_all_valid = false;
            break;
        }
    }
    if (!shape_all_valid) {
        return OpDef::infer_output_attrs_fallible(
                *def.cast_final_safe<CompiledOp>().op, input_descs);
    }
    auto shared_def = const_cast<OpDef&>(def).shared_from_this();
    for (auto& input_desc : input_descs) {
        if (input_desc.layout.total_nr_elems() >
            MEGDNN_MAX_NDIM) {  // skip large tensor
            input_desc.value = {};
        }
    }
    auto& cg_holder =
            get_computing_graph<HolderKind::ShapeInfer>(shared_def, input_descs);
    return cg_holder.infer_output_attrs_fallible(def, input_descs);
}

auto props(const OpDef& def) {
    return OpDef::props(*def.cast_final_safe<CompiledOp>().op);
}

auto make_name(const OpDef& def) {
    auto& op = def.cast_final_safe<CompiledOp>();
    MGB_MARK_USED_VAR(op);
    return ssprintf("CompiledOp[%s]", op.op->make_name().c_str());
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    auto& op = def.cast_final_safe<CompiledOp>();
    auto backward_graph = OpDef::make_backward_graph(
            *op.op, inputs, input_requires_grad, output_has_grad);
    auto name = def.trait()->make_name(def);
    std::shared_ptr<OpDef> bgraph_op =
            SubgraphOp::wrap(name + "Grad", backward_graph.graph);
    auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
    auto encoded_graph = EncodedSubgraph::make_single(
            compiled_op, backward_graph.input_mask, backward_graph.output_mask);
    return encoded_graph;
}

auto hash(const OpDef& def) {
    auto& op = def.cast_final_safe<CompiledOp>();
    return mgb::hash_pair_combine(op.op->hash(), op.gopt_level);
}

auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<CompiledOp>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<CompiledOp>();
    auto& rhs = another.cast_final_safe<CompiledOp>();
    return lhs.op->is_same(*rhs.op) && lhs.gopt_level == rhs.gopt_level;
}

OP_TRAIT_REG(CompiledOp, CompiledOp)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .make_backward_graph(make_backward_graph)
        .make_name(make_name)
        .props(props)
        .hash(hash)
        .is_same_st(is_same_st)
        .fallback();

} // namespace compiled_op
} // namespace

namespace {
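// JITFusionOp applies the wrapped op and then, when fusion is enabled and
// MGB_JIT is built in, rewrites each output's operator subtree into a single
// jit::JITExecutor node so that the region is compiled as one fused kernel.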
namespace jit_fusion {

static thread_local bool tm_enabled = true;

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto& op = def.cast_final_safe<JITFusionOp>();
    op.op->set_scope(op.scope());
    auto outputs = OpDef::apply_on_var_node(*op.op, inputs);
    if (!tm_enabled) {
        // skip for dump (JITExecutor cannot be dumped)
        return outputs;
    }
#if MGB_JIT
    for (auto& output : outputs) {
        jit::InternalGraphGenerator igg{output->owner_opr()};
        std::vector<cg::OperatorNodeBase*> reverse_order;
        cg::DepOprIter iter{
                [&](cg::OperatorNodeBase* opr) { reverse_order.push_back(opr); }};
        for (auto&& input : inputs) {
            iter.set_visited(input->owner_opr());
        }
        iter.add(output->owner_opr());
        std::reverse(reverse_order.begin(), reverse_order.end());
        for (auto&& opr : reverse_order) {
            igg.add_opr(opr);
        }
        auto ig = igg.generate();
        output = jit::JITExecutor::make(ig, igg.orig_inps()).node();
    }
#else
    mgb_assert(false, "MGB_WITH_JIT was disabled");
#endif
    return outputs;
}

auto infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
    TensorShape shape;
    DType dtype = input_descs[0].layout.dtype;
    CompNode comp_node = input_descs[0].comp_node;
    for (auto&& desc : input_descs) {
        if (desc.layout.ndim) {
            shape = desc.layout;
            break;
        }
    }
    for (size_t i = 0; i < input_descs.size(); ++i) {
        if (input_descs[i].layout.ndim) {
            mgb_assert(
                    input_descs[i].layout.eq_shape(shape),
                    "inputs of JITFusionOp should have same shapes");
        }
        mgb_assert(
                input_descs[i].layout.dtype == dtype,
                "inputs of JITFusionOp should have same dtypes");
        mgb_assert(
                input_descs[i].comp_node == comp_node,
                "inputs of JITFusionOp should have same devices");
    }
    return OpDef::infer_output_attrs_fallible(
            *def.cast_final_safe<JITFusionOp>().op, input_descs);
}

auto props(const OpDef& def) {
    return OpDef::props(*def.cast_final_safe<JITFusionOp>().op);
}

auto hash(const OpDef& def) {
    return def.cast_final_safe<JITFusionOp>().op->hash();
}
auto is_same_st(const OpDef& def, const OpDef& another) {
    if (!another.same_type<JITFusionOp>()) {
        return false;
    }
    auto& lhs = def.cast_final_safe<JITFusionOp>();
    auto& rhs = another.cast_final_safe<JITFusionOp>();
    return lhs.op->is_same(*rhs.op);
}

EncodedSubgraph make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    return {};
}

OP_TRAIT_REG(JITFusionOp, JITFusionOp)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .props(props)
        .hash(hash)
        .is_same_st(is_same_st)
        .make_backward_graph(make_backward_graph)
        .fallback();

} // namespace jit_fusion
} // namespace
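// Thread-local switch for JIT fusion; returns the previous value. Fusion is
// turned off e.g. while dumping, since JITExecutor nodes cannot be dumped.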
bool JITFusionOp::set_enabled(bool enabled) {
    std::swap(enabled, jit_fusion::tm_enabled);
    return enabled;
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompiledOp);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITFusionOp);

} // namespace mgb::imperative