
io_options.cpp

#include <map>
#include "helpers/data_parser.h"
#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"

#include "io_options.h"

namespace lar {

template <>
void InputOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& parser = model->get_input_parser();
        auto&& io = model->get_networkIO();
        for (size_t idx = 0; idx < data_path.size(); ++idx) {
            parser.feed(data_path[idx].c_str());
        }
        auto inputs = parser.inputs;
        bool is_host = true;
        for (auto& i : inputs) {
            io.inputs.push_back({i.first, is_host});
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        auto&& parser = model->get_input_parser();
        auto&& network = model->get_lite_network();
        //! data type map from lite data type to mgb data type
        std::map<LiteDataType, megdnn::DTypeEnum> type_map = {
                {LiteDataType::LITE_FLOAT, megdnn::DTypeEnum::Float32},
                {LiteDataType::LITE_INT, megdnn::DTypeEnum::Int32},
                {LiteDataType::LITE_INT8, megdnn::DTypeEnum::Int8},
                {LiteDataType::LITE_UINT8, megdnn::DTypeEnum::Uint8}};
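        //! NOTE: when the forced batch size exceeds the batch of the fed
        //! data, the branch below tiles the input batch-wise: whole copies of
        //! the original batch fill the resized tensor first, then a partial
        //! copy covers the remainder. E.g. an original batch of 2 forced to 5
        //! yields batches {0, 1, 0, 1, 0}.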
        if (m_force_batch_size > 0) {
            LITE_WARN("force set batch size to %d", m_force_batch_size);
            auto all_inputs_name = network->get_all_input_name();
            for (auto& name : all_inputs_name) {
                std::shared_ptr<lite::Tensor> input_tensor =
                        network->get_io_tensor(name);
                //! set lite layout
                lite::Layout layout;
                mgb::TensorShape new_shape;
                new_shape.ndim = input_tensor->get_layout().ndim;
                layout.ndim = input_tensor->get_layout().ndim;
                for (size_t idx = 0; idx < new_shape.ndim; idx++) {
                    new_shape.shape[idx] = input_tensor->get_layout().shapes[idx];
                    layout.shapes[idx] = new_shape.shape[idx];
                }
                new_shape.shape[0] = m_force_batch_size;
                layout.shapes[0] = m_force_batch_size;
                //! generate a tensor copied from the origin tensor
                mgb::HostTensorND hv;
                hv.comp_node(mgb::CompNode::default_cpu(), true)
                        .dtype(megdnn::DType::from_enum(
                                type_map[input_tensor->get_layout().data_type]))
                        .resize(new_shape);
                mgb::dt_byte* raw_ptr = hv.raw_ptr();
                //! single batch input size in bytes
                size_t batch_stride = hv.dtype().size() *
                                      hv.layout().total_nr_elems() /
                                      m_force_batch_size;
                size_t curr_batch_size = m_force_batch_size;
                //! copy data from the origin input_tensor
                size_t init_batch = input_tensor->get_layout().shapes[0];
                while (curr_batch_size > init_batch) {
                    memcpy((char*)raw_ptr, (char*)(input_tensor->get_memory_ptr()),
                           batch_stride * init_batch);
                    curr_batch_size -= init_batch;
                    raw_ptr += batch_stride * init_batch;
                }
                memcpy((char*)raw_ptr, (char*)(input_tensor->get_memory_ptr()),
                       batch_stride * curr_batch_size);
                input_tensor->reset(hv.raw_ptr(), layout);
                parser.inputs[name] = std::move(hv);
            }
        } else {
            for (auto& i : parser.inputs) {
                //! get tensor information from the data parser
                auto tensor = i.second;
                auto tensor_shape = tensor.shape();
                mgb::dt_byte* src = tensor.raw_ptr();
                std::shared_ptr<lite::Tensor> input_tensor =
                        network->get_io_tensor(i.first);
                //! set lite layout
                lite::Layout layout;
                layout.ndim = tensor_shape.ndim;
                for (size_t idx = 0; idx < tensor_shape.ndim; idx++) {
                    layout.shapes[idx] = tensor_shape[idx];
                }
                layout.data_type = input_tensor->get_layout().data_type;
                //! only the shape was given: fill the tensor with dummy data
                if (tensor.storage().empty()) {
                    mgb::HostTensorND hv;
                    hv.comp_node(mgb::CompNode::default_cpu(), true)
                            .dtype(megdnn::DType::from_enum(
                                    type_map[layout.data_type]))
                            .resize(tensor.shape());
                    mgb::dt_byte* raw_ptr = hv.raw_ptr();
                    //! set every byte of the tensor to 1
                    memset((char*)raw_ptr, 1,
                           hv.layout().total_nr_elems() * hv.dtype().size());
                    parser.inputs[i.first] = std::move(hv);
                    input_tensor->reset(raw_ptr, layout);
                } else {
                    //! set the network input tensor
                    input_tensor->reset(src, layout);
                }
            }
        }
    }
}
template <>
void InputOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& parser = model->get_input_parser();
        for (size_t idx = 0; idx < data_path.size(); ++idx) {
            parser.feed(data_path[idx].c_str());
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        auto&& parser = model->get_input_parser();
        auto&& network = model->get_mdl_load_result();
        auto&& tensormap = network.tensor_map;
        if (m_force_batch_size > 0) {
            mgb_log_warn("force set batch size to %d", m_force_batch_size);
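            //! NOTE: same batch tiling as the lite path above: whole copies
            //! of the original batch fill the resized tensor, then a partial
            //! copy covers the remainder.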
            for (auto& iter : tensormap) {
                auto& in = iter.second;
                mgb::HostTensorND hv;
                mgb::TensorShape new_shape = in->shape();
                new_shape[0] = m_force_batch_size;
                hv.comp_node(mgb::CompNode::default_cpu(), true)
                        .dtype(in->dtype())
                        .resize(new_shape);
                mgb::dt_byte* raw_ptr = hv.raw_ptr();
                //! copy the given batch data into the new tensor
                size_t batch_stride = in->dtype().size() *
                                      in->layout().total_nr_elems() /
                                      (in->shape()[0]);
                size_t curr_batch_size = m_force_batch_size;
                //! copy data from the origin input tensor
                size_t init_batch = in->shape()[0];
                while (curr_batch_size > init_batch) {
                    memcpy((char*)raw_ptr, (char*)(in->raw_ptr()),
                           batch_stride * init_batch);
                    curr_batch_size -= init_batch;
                    raw_ptr += batch_stride * init_batch;
                }
                memcpy((char*)raw_ptr, (char*)(in->raw_ptr()),
                       batch_stride * curr_batch_size);
                //! set the input tensor
                in->copy_from(hv);
                parser.inputs[iter.first] = std::move(hv);
            }
        } else {
            for (auto& i : parser.inputs) {
                mgb_assert(
                        tensormap.find(i.first) != tensormap.end(),
                        "can't find tensor named %s", i.first.c_str());
                auto& in = tensormap.find(i.first)->second;
                if (i.second.storage().empty()) {
                    mgb::HostTensorND hv;
                    hv.comp_node(mgb::CompNode::default_cpu(), true)
                            .dtype(in->dtype())
                            .resize(i.second.shape());
                    mgb::dt_byte* raw_ptr = hv.raw_ptr();
                    memset((char*)raw_ptr, 1,
                           hv.layout().total_nr_elems() * hv.dtype().size());
                    in->copy_from(hv);
                    parser.inputs[i.first] = std::move(hv);
                } else {
                    in->copy_from(i.second);
                }
            }
        }
    }
}
template <>
void IOdumpOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (enable_io_dump) {
            LITE_LOG("enable text io dump");
            lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path);
        }
        if (enable_bin_io_dump) {
            LITE_LOG("enable binary io dump");
            lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path);
        }
        //! FIXME: implement these once lite exposes the corresponding APIs
        if (enable_io_dump_stdout || enable_io_dump_stderr) {
            LITE_THROW("lite models don't support io dump to stdout or stderr");
        }
        if (enable_bin_out_dump) {
            LITE_THROW("lite models don't support binary output dump");
        }
        if (enable_copy_to_host) {
            LITE_LOG("lite models copy outputs to host by default");
        }
    }
}
template <>
void IOdumpOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if (enable_io_dump) {
            mgb_log("enable text io dump");
            auto iodump = std::make_unique<mgb::TextOprIODump>(
                    model->get_mdl_config().comp_graph.get(), dump_path.c_str());
            iodump->print_addr(false);
            io_dumper = std::move(iodump);
        }
        if (enable_io_dump_stdout) {
            mgb_log("enable text io dump to stdout");
            //! non-owning FILE* handle: the no-op deleter keeps stdout open
            std::shared_ptr<FILE> std_out(stdout, [](FILE*) {});
            auto iodump = std::make_unique<mgb::TextOprIODump>(
                    model->get_mdl_config().comp_graph.get(), std_out);
            iodump->print_addr(false);
            io_dumper = std::move(iodump);
        }
        if (enable_io_dump_stderr) {
            mgb_log("enable text io dump to stderr");
            std::shared_ptr<FILE> std_err(stderr, [](FILE*) {});
            auto iodump = std::make_unique<mgb::TextOprIODump>(
                    model->get_mdl_config().comp_graph.get(), std_err);
            iodump->print_addr(false);
            io_dumper = std::move(iodump);
        }
        if (enable_bin_io_dump) {
            mgb_log("enable binary io dump");
            auto iodump = std::make_unique<mgb::BinaryOprIODump>(
                    model->get_mdl_config().comp_graph.get(), dump_path);
            io_dumper = std::move(iodump);
        }
        if (enable_bin_out_dump) {
            mgb_log("enable binary output dump");
            out_dumper = std::make_unique<OutputDumper>(dump_path.c_str());
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (enable_bin_out_dump) {
            auto&& load_result = model->get_mdl_load_result();
            out_dumper->set(load_result.output_var_list);
            std::vector<mgb::ComputingGraph::Callback> cb;
            for (size_t i = 0; i < load_result.output_var_list.size(); i++) {
                cb.push_back(out_dumper->bind());
            }
            model->set_output_callback(cb);
        }
        if (enable_copy_to_host) {
            auto&& load_result = model->get_mdl_load_result();
            std::vector<mgb::ComputingGraph::Callback> cb;
            for (size_t i = 0; i < load_result.output_var_list.size(); i++) {
                //! each mutable lambda captures its own HostTensorND by
                //! value, so every output var gets a dedicated host buffer
                mgb::HostTensorND val;
                auto callback = [val](const mgb::DeviceTensorND& dv) mutable {
                    val.copy_from(dv);
                };
                cb.push_back(callback);
            }
            model->set_output_callback(cb);
        }
    } else if (runtime_param.stage == RunStage::AFTER_RUNNING_WAIT) {
        if (enable_bin_out_dump) {
            out_dumper->write_to_file();
        }
    }
}
}  // namespace lar

////////////////////// Input options ////////////////////////

using namespace lar;

void InputOption::update() {
    data_path.clear();
    m_option_name = "input";
    //! split FLAGS_input on ';' into individual data sources
    size_t start = 0;
    auto end = FLAGS_input.find(";", start);
    while (end != std::string::npos) {
        std::string path = FLAGS_input.substr(start, end - start);
        data_path.emplace_back(path);
        start = end + 1;
        end = FLAGS_input.find(";", start);
    }
    data_path.emplace_back(FLAGS_input.substr(start));
    m_force_batch_size = FLAGS_batch_size;
}
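//! create_option() hands out a single function-local static instance, so
//! every caller shares the same InputOption; update() refreshes it from the
//! current gflags values whenever the option is created.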
std::shared_ptr<lar::OptionBase> lar::InputOption::create_option() {
    static std::shared_ptr<InputOption> option(new InputOption);
    if (InputOption::is_valid()) {
        option->update();
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void InputOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
////////////////////// OprIOdump options ////////////////////////

void IOdumpOption::update() {
    m_option_name = "iodump";
    size_t valid_flag = 0;
    if (!FLAGS_io_dump.empty()) {
        dump_path = FLAGS_io_dump;
        enable_io_dump = true;
        valid_flag = valid_flag | (1 << 0);
    }
    if (!FLAGS_bin_io_dump.empty()) {
        dump_path = FLAGS_bin_io_dump;
        enable_bin_io_dump = true;
        valid_flag = valid_flag | (1 << 1);
    }
    if (!FLAGS_bin_out_dump.empty()) {
        dump_path = FLAGS_bin_out_dump;
        enable_bin_out_dump = true;
        valid_flag = valid_flag | (1 << 2);
    }
    if (FLAGS_io_dump_stdout) {
        enable_io_dump_stdout = FLAGS_io_dump_stdout;
        valid_flag = valid_flag | (1 << 3);
    }
    if (FLAGS_io_dump_stderr) {
        enable_io_dump_stderr = FLAGS_io_dump_stderr;
        valid_flag = valid_flag | (1 << 4);
    }
    //! warn when more than one dump option is set (more than one bit set in
    //! valid_flag)
    if (valid_flag && (valid_flag & (valid_flag - 1))) {
        mgb_log_warn(
                "ONLY the last io dump option is valid and the others are "
                "skipped!!!");
    }
    enable_copy_to_host = FLAGS_copy_to_host;
}
bool IOdumpOption::is_valid() {
    bool ret = !FLAGS_io_dump.empty();
    ret = ret || FLAGS_io_dump_stdout;
    ret = ret || FLAGS_io_dump_stderr;
    ret = ret || !FLAGS_bin_io_dump.empty();
    ret = ret || !FLAGS_bin_out_dump.empty();
    ret = ret || FLAGS_copy_to_host;
    return ret;
}

std::shared_ptr<OptionBase> IOdumpOption::create_option() {
    static std::shared_ptr<IOdumpOption> option(new IOdumpOption);
    if (IOdumpOption::is_valid()) {
        option->update();
        return std::static_pointer_cast<OptionBase>(option);
    } else {
        return nullptr;
    }
}

void IOdumpOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}
////////////////////// Input gflags ////////////////////////

DEFINE_string(
        input, "",
        "set up input data for the model: --input [ file_path | data_string ]");
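//! e.g. --input "img0.ppm;shape.json" feeds two sources, split on ';' by
//! InputOption::update() above (the file names here are only illustrative)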
DEFINE_int32(
        batch_size, -1,
        "set the batch size of the input (mainly for the global layout "
        "transform optimization to work on)");

////////////////////// OprIOdump gflags ////////////////////////

DEFINE_string(io_dump, "", "set the io dump file path in text format");
DEFINE_bool(io_dump_stdout, false, "dump io oprs to stdout in text format");
DEFINE_bool(io_dump_stderr, false, "dump io oprs to stderr in text format");
DEFINE_string(
        bin_io_dump, "",
        "set the io dump directory path where variables in binary format are "
        "located");
DEFINE_string(
        bin_out_dump, "",
        "set the output dump directory path where output variables in binary "
        "format are located");
DEFINE_bool(copy_to_host, false, "copy device data to host");
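//! Example invocation (binary name and paths are only illustrative):
//!   load_and_run model.mge --input "data.ppm" --batch_size 4 \
//!       --io_dump io.txt --copy_to_host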
REGIST_OPTION_CREATOR(input, lar::InputOption::create_option);
REGIST_OPTION_CREATOR(iodump, lar::IOdumpOption::create_option);