You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_options.cpp 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. /**
  2. * \file lite/load_and_run/src/options/optimize_options.cpp
  3. *
  4. * This file is part of MegEngine, a deep learning framework developed by
  5. * Megvii.
  6. *
  7. * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  8. */
  9. #include "megbrain/gopt/inference.h"
  10. #if MGB_ENABLE_TENSOR_RT
  11. #include "megbrain/tensorrt/tensorrt_engine_cache.h"
  12. #endif
  13. #include "lite/global.h"
  14. #include "misc.h"
  15. #include "models/model_lite.h"
  16. #include "models/model_mdl.h"
  17. #include "optimize_options.h"
  18. ///////////////////////// fuse and preprocess optimize options ///////////////
  19. namespace lar {
  20. template <>
  21. void FusePreprocessOption::config_model_internel<ModelLite>(
  22. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  23. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  24. if (enable_fuse_preprocess) {
  25. LITE_WARN("enable fuse-preprocess optimization");
  26. model->get_config().options.fuse_preprocess = true;
  27. }
  28. }
  29. }
  30. template <>
  31. void FusePreprocessOption::config_model_internel<ModelMdl>(
  32. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  33. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  34. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  35. if (enable_fuse_preprocess) {
  36. mgb_log_warn("enable fuse-preprocess optimization");
  37. graph_option.graph_opt.enable_fuse_preprocess();
  38. }
  39. }
  40. }
  41. } // namespace lar
  42. using namespace lar;
  43. FusePreprocessOption::FusePreprocessOption() {
  44. m_option_name = "fuse_preprocess";
  45. enable_fuse_preprocess = FLAGS_enable_fuse_preprocess;
  46. }
  47. bool FusePreprocessOption::is_valid() {
  48. bool ret = FLAGS_enable_fuse_preprocess;
  49. return ret;
  50. }
  51. std::shared_ptr<OptionBase> FusePreprocessOption::create_option() {
  52. static std::shared_ptr<FusePreprocessOption> option(new FusePreprocessOption);
  53. if (FusePreprocessOption::is_valid()) {
  54. return std::static_pointer_cast<OptionBase>(option);
  55. } else {
  56. return nullptr;
  57. }
  58. }
  59. void FusePreprocessOption::config_model(
  60. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  61. CONFIG_MODEL_FUN;
  62. }
  63. ///////////////////////// weight preprocess optimize options ///////////////
  64. namespace lar {
  65. template <>
  66. void WeightPreprocessOption::config_model_internel<ModelLite>(
  67. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  68. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  69. if (weight_preprocess) {
  70. LITE_WARN("enable weight-preprocess optimization");
  71. model->get_config().options.weight_preprocess = true;
  72. }
  73. }
  74. }
  75. template <>
  76. void WeightPreprocessOption::config_model_internel<ModelMdl>(
  77. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  78. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  79. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  80. if (weight_preprocess) {
  81. mgb_log_warn("enable weight-preprocess optimization");
  82. graph_option.graph_opt.enable_weight_preprocess();
  83. }
  84. }
  85. }
  86. } // namespace lar
  87. WeightPreprocessOption::WeightPreprocessOption() {
  88. m_option_name = "weight_preprocess";
  89. weight_preprocess = FLAGS_weight_preprocess;
  90. }
  91. bool WeightPreprocessOption::is_valid() {
  92. bool ret = FLAGS_weight_preprocess;
  93. return ret;
  94. }
  95. std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() {
  96. static std::shared_ptr<WeightPreprocessOption> option(new WeightPreprocessOption);
  97. if (WeightPreprocessOption::is_valid()) {
  98. return std::static_pointer_cast<OptionBase>(option);
  99. } else {
  100. return nullptr;
  101. }
  102. }
  103. void WeightPreprocessOption::config_model(
  104. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  105. CONFIG_MODEL_FUN;
  106. }
  107. ///// fuse conv bias and nonlinear activation opr optimize options ////////
  108. namespace lar {
  109. template <>
  110. void FuseConvBiasNonlinearOption::config_model_internel<ModelLite>(
  111. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  112. LITE_MARK_USED_VAR(model);
  113. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  114. if (enable_fuse_conv_bias_nonlinearity) {
  115. LITE_THROW("fuse conv+bias+nonlinearity not supported in lite model");
  116. }
  117. }
  118. }
  119. template <>
  120. void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>(
  121. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  122. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  123. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  124. if (enable_fuse_conv_bias_nonlinearity) {
  125. mgb_log_warn("enable fuse conv+bias+nonlinearity optimization");
  126. graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity();
  127. }
  128. }
  129. }
  130. } // namespace lar
  131. FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() {
  132. m_option_name = "fuse_conv_bias_nonlinear";
  133. enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity;
  134. }
  135. bool FuseConvBiasNonlinearOption::is_valid() {
  136. bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity;
  137. return ret;
  138. }
  139. std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() {
  140. static std::shared_ptr<FuseConvBiasNonlinearOption> option(
  141. new FuseConvBiasNonlinearOption);
  142. if (FuseConvBiasNonlinearOption::is_valid()) {
  143. return std::static_pointer_cast<OptionBase>(option);
  144. } else {
  145. return nullptr;
  146. }
  147. }
  148. void FuseConvBiasNonlinearOption::config_model(
  149. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  150. CONFIG_MODEL_FUN;
  151. }
  ////////////// fuse conv, bias and z (elemwise add) optimize options //////////
  153. namespace lar {
  154. template <>
  155. void FuseConvBiasElemwiseAddOption::config_model_internel<ModelLite>(
  156. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  157. LITE_MARK_USED_VAR(model);
  158. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  159. if (enable_fuse_conv_bias_with_z) {
  160. LITE_THROW(
  161. "fuse conv+bias+z optimization not supported in lite "
  162. "model");
  163. }
  164. }
  165. }
  166. template <>
  167. void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>(
  168. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  169. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  170. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  171. if (enable_fuse_conv_bias_with_z) {
  172. mgb_log_warn("enable fuse conv+bias+z optimization");
  173. graph_option.graph_opt.enable_fuse_conv_bias_with_z();
  174. }
  175. }
  176. }
  177. } // namespace lar
  178. FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() {
  179. m_option_name = "fuse_conv_bias_z";
  180. enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z;
  181. }
  182. bool FuseConvBiasElemwiseAddOption::is_valid() {
  183. bool ret = FLAGS_enable_fuse_conv_bias_with_z;
  184. return ret;
  185. }
  186. std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() {
  187. static std::shared_ptr<FuseConvBiasElemwiseAddOption> option(
  188. new FuseConvBiasElemwiseAddOption);
  189. if (FuseConvBiasElemwiseAddOption::is_valid()) {
  190. return std::static_pointer_cast<OptionBase>(option);
  191. } else {
  192. return nullptr;
  193. }
  194. }
  195. void FuseConvBiasElemwiseAddOption::config_model(
  196. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  197. CONFIG_MODEL_FUN;
  198. }
  ///////////////////////// graph record options /////////////////////////
  200. namespace lar {
  201. template <>
  202. void GraphRecordOption::config_model_internel<ModelLite>(
  203. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  204. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  205. auto&& config_option = model->get_config().options;
  206. if (const_shape) {
  207. LITE_WARN("enable const var shape");
  208. config_option.const_shape = true;
  209. }
  210. if (fake_first) {
  211. LITE_WARN("enable fake-first optimization");
  212. config_option.fake_next_exec = true;
  213. }
  214. if (no_sanity_check) {
  215. LITE_WARN("disable var sanity check optimization");
  216. config_option.var_sanity_check_first_run = false;
  217. }
  218. if (m_record_comp_seq == 1) {
  219. LITE_WARN("set record_comp_seq_level to 1");
  220. }
  221. if (m_record_comp_seq == 2) {
  222. mgb_assert(
  223. no_sanity_check,
  224. "--no-sanity-check should be set before "
  225. "--record-comp-seq2");
  226. LITE_WARN("set record_comp_seq_level to 2");
  227. }
  228. config_option.comp_node_seq_record_level = m_record_comp_seq;
  229. }
  230. }
  231. template <>
  232. void GraphRecordOption::config_model_internel<ModelMdl>(
  233. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  234. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  235. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  236. if (const_shape) {
  237. mgb_log_warn("enable const var shape");
  238. model->get_mdl_config().const_var_shape = true;
  239. }
  240. if (fake_first) {
  241. mgb_log_warn("enable fake-first optimization");
  242. graph_option.fake_next_exec = true;
  243. }
  244. if (no_sanity_check) {
  245. mgb_log_warn("disable var sanity check optimization");
  246. graph_option.var_sanity_check_first_run = false;
  247. }
  248. if (m_record_comp_seq == 1) {
  249. mgb_log_warn("set record_comp_seq_level to 1");
  250. }
  251. if (m_record_comp_seq == 2) {
  252. mgb_assert(
  253. no_sanity_check && !fake_first,
  254. "--no-sanity-check should be set before "
  255. "--record-comp-seq2 and --fake-first should not be set");
  256. mgb_log_warn("set record_comp_seq_level to 2");
  257. }
  258. graph_option.comp_node_seq_record_level = m_record_comp_seq;
  259. }
  260. }
  261. } // namespace lar
  262. GraphRecordOption::GraphRecordOption() {
  263. m_option_name = "graph_record";
  264. m_record_comp_seq = 0;
  265. const_shape = FLAGS_const_shape;
  266. fake_first = FLAGS_fake_first;
  267. no_sanity_check = FLAGS_no_sanity_check;
  268. if (FLAGS_record_comp_seq) {
  269. m_record_comp_seq = 1;
  270. }
  271. if (FLAGS_record_comp_seq2) {
  272. m_record_comp_seq = 2;
  273. }
  274. }
  275. bool GraphRecordOption::is_valid() {
  276. bool ret = FLAGS_const_shape;
  277. ret = ret || FLAGS_fake_first;
  278. ret = ret || FLAGS_no_sanity_check;
  279. ret = ret || FLAGS_record_comp_seq;
  280. ret = ret || FLAGS_record_comp_seq2;
  281. return ret;
  282. }
  283. std::shared_ptr<OptionBase> GraphRecordOption::create_option() {
  284. static std::shared_ptr<GraphRecordOption> option(new GraphRecordOption);
  285. if (GraphRecordOption::is_valid()) {
  286. return std::static_pointer_cast<OptionBase>(option);
  287. } else {
  288. return nullptr;
  289. }
  290. }
  291. void GraphRecordOption::config_model(
  292. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  293. CONFIG_MODEL_FUN;
  294. }
  ///////////////////////// memory optimize options /////////////////////////
  296. namespace lar {
  297. template <>
  298. void MemoryOptimizeOption::config_model_internel<ModelLite>(
  299. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  300. LITE_MARK_USED_VAR(model);
  301. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  302. if (disable_mem_opt) {
  303. LITE_THROW("lite model don't support disable memory optimization");
  304. }
  305. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  306. if (workspace_limit != SIZE_MAX) {
  307. LITE_WARN("set workspace limit to %ld", workspace_limit);
  308. lite::Runtime::set_network_algo_workspace_limit(
  309. model->get_lite_network(), workspace_limit);
  310. }
  311. }
  312. }
  313. template <>
  314. void MemoryOptimizeOption::config_model_internel<ModelMdl>(
  315. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  316. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  317. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  318. if (disable_mem_opt) {
  319. mgb_log_warn("disable memory optimization");
  320. graph_option.seq_opt.enable_mem_plan_opt = false;
  321. graph_option.seq_opt.enable_mem_reuse_alloc = false;
  322. }
  323. if (workspace_limit < SIZE_MAX) {
  324. mgb_log_warn("set workspace limit to %ld", workspace_limit);
  325. auto&& output_spec = model->get_output_spec();
  326. mgb::SymbolVarArray vars;
  327. for (auto i : output_spec) {
  328. vars.push_back(i.first);
  329. }
  330. mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
  331. }
  332. }
  333. }
  334. } // namespace lar
  335. MemoryOptimizeOption::MemoryOptimizeOption() {
  336. m_option_name = "memory_optimize";
  337. disable_mem_opt = FLAGS_disable_mem_opt;
  338. workspace_limit = FLAGS_workspace_limit;
  339. }
  340. bool MemoryOptimizeOption::is_valid() {
  341. bool ret = FLAGS_disable_mem_opt;
  342. ret = ret || FLAGS_workspace_limit < SIZE_MAX;
  343. return ret;
  344. }
  345. std::shared_ptr<OptionBase> MemoryOptimizeOption::create_option() {
  346. static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption);
  347. if (MemoryOptimizeOption::is_valid()) {
  348. return std::static_pointer_cast<OptionBase>(option);
  349. } else {
  350. return nullptr;
  351. }
  352. }
  353. void MemoryOptimizeOption::config_model(
  354. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  355. CONFIG_MODEL_FUN;
  356. }
  357. ///////////////////////// other options for optimization /////////////////
  358. namespace lar {
  359. template <>
  360. void JITOption::config_model_internel<ModelLite>(
  361. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  362. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  363. auto&& config_option = model->get_config().options;
  364. if (enable_jit) {
  365. LITE_WARN("enable JIT (level 1)");
  366. config_option.jit_level = 1;
  367. }
  368. }
  369. }
  370. template <>
  371. void JITOption::config_model_internel<ModelMdl>(
  372. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  373. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  374. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  375. if (enable_jit) {
  376. mgb_log_warn("enable JIT (level 1)");
  377. graph_option.graph_opt.jit = 1;
  378. }
  379. }
  380. }
  381. } // namespace lar
  382. JITOption::JITOption() {
  383. m_option_name = "JIT";
  384. enable_jit = FLAGS_enable_jit;
  385. }
  386. bool JITOption::is_valid() {
  387. bool ret = FLAGS_enable_jit;
  388. return ret;
  389. }
  390. std::shared_ptr<OptionBase> JITOption::create_option() {
  391. static std::shared_ptr<JITOption> option(new JITOption);
  392. if (JITOption::is_valid()) {
  393. return std::static_pointer_cast<OptionBase>(option);
  394. } else {
  395. return nullptr;
  396. }
  397. }
  398. void JITOption::config_model(
  399. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  400. CONFIG_MODEL_FUN;
  401. }
  ///////////////////////// TensorRT optimize options /////////////////////////
  403. #if MGB_ENABLE_TENSOR_RT
  404. namespace lar {
  405. template <>
  406. void TensorRTOption::config_model_internel<ModelLite>(
  407. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  408. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  409. if (!tensorrt_cache.empty()) {
  410. LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str());
  411. lite::set_tensor_rt_cache(tensorrt_cache);
  412. }
  413. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  414. if (enable_tensorrt) {
  415. LITE_WARN("enable TensorRT");
  416. lite::Runtime::use_tensorrt(model->get_lite_network());
  417. }
  418. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  419. if (!tensorrt_cache.empty()) {
  420. lite::dump_tensor_rt_cache();
  421. }
  422. }
  423. }
  424. template <>
  425. void TensorRTOption::config_model_internel<ModelMdl>(
  426. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  427. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  428. auto&& graph_option = model->get_mdl_config().comp_graph->options();
  429. if (enable_tensorrt) {
  430. mgb_log_warn("using tensorRT");
  431. graph_option.graph_opt.tensorrt = true;
  432. }
  433. if (!tensorrt_cache.empty()) {
  434. mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str());
  435. mgb::TensorRTEngineCache::enable_engine_cache(true);
  436. mgb::TensorRTEngineCache::set_impl(
  437. std::make_shared<mgb::TensorRTEngineCacheIO>(
  438. tensorrt_cache.c_str()));
  439. }
  440. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  441. if (!tensorrt_cache.empty()) {
  442. if (mgb::TensorRTEngineCache::enable_engine_cache()) {
  443. mgb::TensorRTEngineCache::inst().dump_cache();
  444. }
  445. }
  446. }
  447. }
  448. } // namespace lar
  449. TensorRTOption::TensorRTOption() {
  450. m_option_name = "tensorRT";
  451. enable_tensorrt = FLAGS_tensorrt;
  452. tensorrt_cache = FLAGS_tensorrt_cache;
  453. }
  454. bool TensorRTOption::is_valid() {
  455. bool ret = FLAGS_tensorrt;
  456. ret = ret || !FLAGS_tensorrt_cache.empty();
  457. return ret;
  458. }
  459. std::shared_ptr<OptionBase> TensorRTOption::create_option() {
  460. static std::shared_ptr<TensorRTOption> option(new TensorRTOption);
  461. if (TensorRTOption::is_valid()) {
  462. return std::static_pointer_cast<OptionBase>(option);
  463. } else {
  464. return nullptr;
  465. }
  466. }
  467. void TensorRTOption::config_model(
  468. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  469. CONFIG_MODEL_FUN;
  470. }
  471. #endif
  472. ///////////////////////// fuse and preprocess optimize options ///////////////
  473. DEFINE_bool(
  474. enable_fuse_preprocess, false,
  475. "Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr");
  476. DEFINE_bool(
  477. weight_preprocess, false,
  478. "Execute operators with weight preprocess, which can optimize the "
  479. "operator execution time with algo of winograd, im2col ,etc., but "
  480. "it may consume more memory.");
  481. DEFINE_bool(
  482. enable_fuse_conv_bias_nonlinearity, false,
  483. "whether to fuse conv+bias+nonlinearity");
  484. DEFINE_bool(
  485. enable_fuse_conv_bias_with_z, false,
  486. "fuse conv,bias (elemwise add),z(elemwise add) into one opr "
  487. "(only support on GPU)");
  488. ///////////////////////// graph retrict options /////////////////////////
  489. DEFINE_bool(
  490. const_shape, false,
  491. "set const_var_shape to reduce memory usage, since some static "
  492. "inference data structures can be omitted");
  493. DEFINE_bool(
  494. fake_first, false,
  495. "Enable fake exec for the first run. In fake exec mode, some "
  496. "initialization job would be done, but no actual computing is "
  497. "performed.");
  498. DEFINE_bool(no_sanity_check, false, "Disable var sanity check on the first run");
  499. DEFINE_bool(
  500. record_comp_seq, false,
  501. "Record the computing sequence, in level 1 . It reduces overhead of API"
  502. "calls of some asynchronous computing devices");
  503. DEFINE_bool(
  504. record_comp_seq2, false,
  505. "Record the computing sequence, in level 2, the computing graph can be"
  506. "destructed to reduce memory usage");
  507. DEFINE_bool(disable_mem_opt, false, "disable memory optimization!!");
  508. DEFINE_uint64(workspace_limit, SIZE_MAX, "set workspace upbound limit");
  509. ///////////////////////// other options for optimization /////////////////
  510. DEFINE_bool(
  511. enable_jit, false,
  512. " Execute supported operators with JIT(now only support NVRTC). "
  513. "Can only be used on Nvidia GPUs");
  514. #if MGB_ENABLE_TENSOR_RT
  515. DEFINE_bool(
  516. tensorrt, false,
  517. " Execute supported operators with TensorRT. Can only be used on "
  518. "Nvidia GPUs,i.e. comp node is xpu or gpu.");
  519. DEFINE_string(
  520. tensorrt_cache, "",
  521. "Set the TensorRT engine cache path for serialized prebuilt "
  522. "ICudaEngine");
  523. #endif
  524. REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option);
  525. REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option);
  526. REGIST_OPTION_CREATOR(
  527. fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option);
  528. REGIST_OPTION_CREATOR(
  529. fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option);
  530. REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option);
  531. REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option);
  532. REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option);
  533. #if MGB_ENABLE_TENSOR_RT
  534. REGIST_OPTION_CREATOR(tensorRT, lar::TensorRTOption::create_option);
  535. #endif