
layout_trans_options.cpp

#include "layout_trans_options.h"
#include <gflags/gflags.h>
#include "megbrain/serialization/serializer.h"
#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"

namespace lar {

template <>
void GoptLayoutOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
        if (m_layout_transform) {
            LITE_LOG("using global layout transform optimization\n");
            if (m_layout_transform_target ==
                mgb::gopt::GraphTuningOptions::Target::CPU) {
                model->get_config().device_type = LiteDeviceType::LITE_CPU;
            }
#if LITE_WITH_CUDA
            else if (
                    m_layout_transform_target ==
                    mgb::gopt::GraphTuningOptions::Target::CUDA) {
                model->get_config().device_type = LiteDeviceType::LITE_CUDA;
            }
#endif
            LITE_LOG("enable layout transform while loading model for lite");
            auto&& lite_network = model->get_lite_network();
            lite::Runtime::enable_global_layout_transform(lite_network);
        }
    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
        if (m_layout_transform) {
            auto&& network = model->get_lite_network();
            if (!m_layout_transform_dump_file.empty()) {
                lite::Runtime::dump_layout_transform_model(
                        network, m_layout_transform_dump_file);
            }
        }
    }
}

template <>
void GoptLayoutOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (m_layout_transform) {
            mgb_log_debug("update input shape for global layout transform\n");
            auto&& load_result = model->get_mdl_load_result();
            if (m_force_batch_size > 0) {
                //! overwrite the batch dim of every input tensor with the forced
                //! batch size and fill it with dummy data
                for (auto&& i : load_result.tensor_map) {
                    auto& in = i.second;
                    mgb::TensorShape new_shape = in->shape();
                    new_shape[0] = m_force_batch_size;
                    mgb::HostTensorND new_tensor;
                    new_tensor.comp_node(mgb::CompNode::default_cpu(), true)
                            .dtype(in->dtype())
                            .resize(new_shape);
                    mgb::dt_byte* raw_ptr = new_tensor.raw_ptr();
                    memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems());
                    in->copy_from(new_tensor);
                }
            }
            for (auto&& item : load_result.output_var_list) {
                if (item.shape()[0] > 1) {
                    mgb_log_warn(
                            "model may be dumped with multiple batches and will cost "
                            "lots of time to profile during global layout "
                            "transform!\n");
                }
            }
            //! update the output var list when the input shape may change (some pass
            //! execution time depends on the shape of the initial input)
            mgb::thin_hash_table::ThinHashMap<mgb::cg::SymbolVar, mgb::cg::SymbolVar>
                    varmap;
            mgb::cg::DepOprIter dep([&](mgb::cg::OperatorNodeBase* opr) {
                if (auto h2d = opr->try_cast_final<mgb::opr::Host2DeviceCopy>()) {
                    auto param = h2d->param();
                    mgb::TensorShape new_shape = h2d->host_data()->shape();
                    std::shared_ptr<mgb::HostTensorND> new_tensor =
                            std::make_shared<mgb::HostTensorND>(
                                    h2d->host_data()->comp_node(), new_shape,
                                    h2d->host_data()->dtype());
                    new_tensor->only_reset_raw_storage(h2d->host_data()->storage());
                    auto h2d_opr = mgb::opr::Host2DeviceCopy::make(
                            *h2d->owner_graph(), new_tensor, param, h2d->config());
                    varmap[h2d->output(0)] = h2d_opr;
                }
            });
            for (auto&& i : load_result.output_var_list)
                dep.add(i);
            if (!varmap.empty()) {
                auto output_vars =
                        mgb::cg::replace_vars(load_result.output_var_list, varmap);
                for (size_t i = 0; i < load_result.output_var_list.size(); ++i) {
                    output_vars[i].rename(
                            load_result.output_var_list[i].node()->name());
                }
                load_result.output_var_list = output_vars;
            }
        }
    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
        if (m_layout_transform) {
            mgb_log("using global layout transform optimization\n");
            auto&& load_result = model->get_mdl_load_result();
            load_result.output_var_list = mgb::gopt::layout_transform(
                    load_result.output_var_list, m_layout_transform_target);
            if (!m_layout_transform_dump_file.empty()) {
                auto out_file = mgb::serialization::OutputFile::make_fs(
                        m_layout_transform_dump_file.c_str(), 'w');
                auto testcase_num = model->get_testcase_num();
                if (testcase_num) {
                    const char* magic = "mgbtest0";
                    constexpr size_t len = sizeof(magic);
                    out_file->write(magic, len);
                    out_file->write(&testcase_num, sizeof(testcase_num));
                }
                using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
                DumpConfig config{1, false, false};
                auto dumper = model->get_dumper(std::move(out_file));
                dumper->dump(load_result.output_var_list, config);
                if (testcase_num) {
                    //! re-dump the test cases appended after the transformed graph
                    auto input_file = model->get_loader()->reset_file();
                    auto current_offset = input_file->tell();
                    auto loader = model->reset_loader(std::move(input_file));
                    auto testcase = loader->load(model->get_mdl_config(), false);
                    mgb::serialization::GraphDumper::DumpConfig config{1, false, false};
                    for (size_t i = 0; i < testcase_num; ++i) {
                        auto casefile = mgb::serialization::OutputFile::make_fs(
                                m_layout_transform_dump_file.c_str(), 'a');
                        auto casedumper = model->get_dumper(std::move(casefile));
                        casedumper->dump(testcase.output_var_list, config);
                        if (i != testcase_num - 1) {
                            loader = model->reset_loader();
                            testcase = loader->load(model->get_mdl_config(), false);
                        }
                    }
                    input_file = model->get_loader()->reset_file();
                    input_file->rewind();
                    input_file->skip(current_offset);
                    model->reset_loader(std::move(input_file));
                }
            }
        }
    }
}
}  // namespace lar

using namespace lar;

bool GoptLayoutOption::m_valid;

void GoptLayoutOption::update() {
    m_option_name = "gopt_layout";
    if (FLAGS_layout_transform != "cpu"
#if LITE_WITH_CUDA
        && FLAGS_layout_transform != "cuda"
#endif
    ) {
        m_layout_transform = false;
        m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
    } else {
        m_layout_transform = true;
        if (FLAGS_layout_transform == "cpu") {
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
        }
#if LITE_WITH_CUDA
        else if (FLAGS_layout_transform == "cuda") {
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
        }
#endif
    }
    m_layout_transform_dump_file = FLAGS_layout_transform_dump;
    m_force_batch_size = FLAGS_layout_transform_batch_size;
    m_option = {
            {"layout_transform", lar::String::make("")},
    };
    std::static_pointer_cast<lar::String>(m_option["layout_transform"])
            ->set_value(FLAGS_layout_transform);
}

bool GoptLayoutOption::is_valid() {
    bool ret = false;
    if (!FLAGS_layout_transform.empty()) {
        if (FLAGS_layout_transform != "cpu"
#if LITE_WITH_CUDA
            && FLAGS_layout_transform != "cuda"
#endif
        ) {
            mgb_assert(
                    false,
                    "unsupported target (got: %s) for global layout transform",
                    FLAGS_layout_transform.c_str());
            ret = false;
        } else {
            ret = true;
        }
    }
    ret = ret || !FLAGS_layout_transform_dump.empty();
    if (FLAGS_layout_transform_batch_size > 0) {
        mgb_assert(
                FLAGS_layout_transform_batch_size > 0 &&
                        !FLAGS_layout_transform.empty(),
                "\"layout-transform-batch-size\" should be set together with "
                "\"layout-transform\"");
        ret = ret || FLAGS_layout_transform_batch_size > 0;
    }
    return ret || m_valid;
}

std::shared_ptr<OptionBase> GoptLayoutOption::create_option() {
    static std::shared_ptr<GoptLayoutOption> option(new GoptLayoutOption);
    if (GoptLayoutOption::is_valid()) {
        option->update();
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void GoptLayoutOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    auto value = std::static_pointer_cast<lar::String>(m_option["layout_transform"])
                         ->get_value();
    if (value.empty()) {
        return;
    }
    if (value == "cpu") {
        m_layout_transform = true;
        m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
    }
#if LITE_WITH_CUDA
    else if (value == "cuda") {
        m_layout_transform = true;
        m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
    }
#endif
    else {
        mgb_throw(
                mgb::AssertionError, "invalid option for global layout transform: %s",
                value.c_str());
    }
    CONFIG_MODEL_FUN;
}

DEFINE_string(
        layout_transform, "",
        "Enable global layout transform optimization for the computing graph. The "
        "user should specify the device target for the optimization, and a series of "
        "passes will be applied to the computing graph. The passes benchmark the "
        "elapsed time of operators on different tensor layouts and select the "
        "fastest implementation for each operator, so the optimization process takes "
        "some time. The default target is unspec, in which case all available "
        "layouts for the operators are profiled and the optimization takes longer.");
DEFINE_string(
        layout_transform_dump, "",
        "The computing graph after global layout transform will be dumped to the "
        "given file path.");
DEFINE_int32(
        layout_transform_batch_size, -1,
        "The input batch size that the global layout transform optimization works "
        "on.");

REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);
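
// Usage sketch (assumption: these gflags are consumed by MegEngine's load_and_run
// tool; the executable name and model file below are illustrative and not defined
// in this file):
//
//   load_and_run model.mge --layout-transform cuda \
//       --layout-transform-dump optimized.mge \
//       --layout-transform-batch-size 1
//
// gflags treats '-' and '_' in flag names interchangeably, so --layout_transform
// works as well. Passing only --layout-transform-dump still makes is_valid() return
// true, while an unsupported target such as --layout-transform opencl trips the
// mgb_assert in is_valid() above.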