
network_impl.cpp

/**
 * \file src/mge/network_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

  11. #include "lite_build_config.h"
  12. #if LITE_BUILD_WITH_MGE
  13. #include "common.h"
  14. #include "lite/network.h"
  15. #include "memory_allocator.h"
  16. #include "network_impl.h"
  17. #include "parse_info/parse_info_base.h"
  18. #include "parse_model/model_parser.h"
  19. #include "megbrain/common.h"
  20. #include "megbrain/comp_node.h"
  21. #include "megbrain/comp_node_env.h"
  22. #include "megbrain/gopt/inference.h"
  23. #include "megbrain/graph.h"
  24. #include "megbrain/graph/cg.h"
  25. #include "megbrain/opr/io.h"
  26. #include "megbrain/tensor.h"
  27. #if MGB_OPENCL
  28. #include "megcore_opencl.h"
  29. #endif
  30. #include <fstream>
  31. #include <memory>
  32. #include <set>
  33. using namespace lite;
  34. using namespace mgb;
  35. LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);
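//! NOTE: NetworkImplDft is the MegEngine ("dft" = default) backend behind the
//! public lite::Network facade declared in lite/network.h. A minimal usage
//! sketch of that facade (illustrative only; the model path is a placeholder
//! and front-end signatures may differ slightly between versions):
//!
//!     lite::Config config;
//!     config.device_type = LiteDeviceType::LITE_CPU;
//!     auto network = std::make_shared<lite::Network>(config);
//!     network->load_model("model.mge");             // parse and compile the graph
//!     auto input = network->get_io_tensor("data");  // fill the input tensor
//!     network->forward();                           // launch execution
//!     network->wait();                              // block until finished
//!     auto output = network->get_output_tensor(0);  // read the results
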
void NetworkImplDft::set_config(const Config& config) {
    m_user_config = std::make_unique<Config>();
    *m_user_config = config;
    m_compnode_locator = to_compnode_locator(m_user_config->device_type);
    m_compnode_locator.device = config.device_id;
}

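//! initialize this network from an already-loaded source network so that the
//! constant weights are shared between the two instances: the model is
//! reloaded through the source network's GraphLoader, which reuses the weight
//! tensors it has already created.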
void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    application_config();
    const auto& src_impl = src_network->cast_final_safe<NetworkImplDft>();
    LITE_ASSERT(src_impl.m_loader, "Clone network must after the network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);

    //! flag whether the model is a cross-compnode model
    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}

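//! apply the user configuration to the megbrain computing graph: copy the
//! option flags into the graph options, enable the requested layout
//! transforms, and install a comp_node_mapper that rewrites the comp nodes
//! stored in the model to the device/stream selected by the user.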
void NetworkImplDft::application_config() {
    auto device_type = m_user_config->device_type;
    m_compnode_locator.type = to_compnode_locator(device_type).type;
    m_compnode_locator.device = m_user_config->device_id;
    if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.device = m_user_config->device_id;
    }
    //! model options
#define ConfigOption(mge_name, lite_name) \
    options.mge_name = m_user_config->options.lite_name;

    auto&& options = m_load_config.comp_graph->options();
    ConfigOption(graph_opt.weight_preprocess, weight_preprocess);
    ConfigOption(graph_opt.fuse_preprocess, fuse_preprocess);
    ConfigOption(fake_next_exec, fake_next_exec);
    ConfigOption(var_sanity_check_first_run, var_sanity_check_first_run);
    m_load_config.const_var_shape = m_user_config->options.const_shape;
    ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
    ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
    ConfigOption(
            force_output_use_user_specified_memory,
            force_output_use_user_specified_memory);
    ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
    LITE_ASSERT(
            m_user_config->options.jit_level == 0 ||
                    (m_user_config->options.jit_level > 0 &&
                     device_type == LiteDeviceType::LITE_CUDA),
            "jit only support in cuda device.");
    ConfigOption(graph_opt.jit, jit_level);
    ConfigOption(comp_node_seq_record_level, comp_node_seq_record_level);
    ConfigOption(graph_opt_level, graph_opt_level);
    ConfigOption(async_exec_level, async_exec_level);
#undef ConfigOption

#define ConfigOptionLayoutTransform(name) \
    if (m_user_config->options.name) {    \
        options.graph_opt.name();         \
    }
    ConfigOptionLayoutTransform(enable_nchw44);
    ConfigOptionLayoutTransform(enable_nchw44_dot);
    ConfigOptionLayoutTransform(enable_nchw88);
    ConfigOptionLayoutTransform(enable_nhwcd4);
    ConfigOptionLayoutTransform(enable_nchw4);
    ConfigOptionLayoutTransform(enable_nchw32);
    ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform

    if (m_user_config->has_compression) {
        m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }

    //! if device is LITE_DEVICE_DEFAULT, the compnode information is stored in
    //! the model
    if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
        //! currently not set Locator type because an atlas mgb model is a
        //! cross-compnode graph
        if (device_type == LiteDeviceType::LITE_ATLAS) {
            m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
                if (loc.type == mgb::CompNode::DeviceType::ATLAS) {
                    loc.device = m_compnode_locator.device;
                    loc.stream = m_compnode_locator.stream;
                } else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
                    loc.stream = m_nr_threads;
                }
            };
        } else {
            m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
                loc = m_compnode_locator;
            };
        }
    }
}

void NetworkImplDft::set_memory_allocator(std::shared_ptr<Allocator> user_allocator) {
    auto allocator = std::make_shared<UserStaticMemAlloc>(user_allocator);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->set_device_memory_allocator(allocator);
}

//! share the runtime memory with another network; the weights are not shared
void NetworkImplDft::share_runtime_memory_with(Network::NetworkImplBase* network_impl) {
    LITE_ASSERT(network_impl);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->share_device_memory_with(*(
            network_impl->cast_final_safe<NetworkImplDft>().m_load_config.comp_graph));
}

void NetworkImplDft::set_cpu_inplace_mode() {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "cpu inplace mode is only avaliable in CPU.");
    m_is_cpu_inplace_mode = true;
    if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
    } else {
        LITE_ASSERT(
                m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                "cpu inplace mode is only avaliable in CPU.");
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
    }
}

void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi threads mode is only avaliable in CPU.");
    if (nr_threads > 1) {
        m_nr_threads = nr_threads;
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.nr_threads = nr_threads;
    }
}

void NetworkImplDft::set_runtime_thread_affinity(
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi threads mode is only avaliable in CPU.");
    mgb::CompNode::Locator loc;
    m_load_config.comp_node_mapper(loc);
    auto cn = mgb::CompNode::load(loc);
    if (m_nr_threads > 1) {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().set_affinity(
                thread_affinity_callback);
    } else {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [thread_affinity_callback](void) { thread_affinity_callback(0); });
    }
}

void NetworkImplDft::set_device_id(int device_id) {
    m_compnode_locator.device = device_id;
    m_user_config->device_id = device_id;
}

void NetworkImplDft::set_stream_id(int stream_id) {
    m_compnode_locator.stream = stream_id;
}

void NetworkImplDft::use_tensorrt() {
    auto&& options = m_load_config.comp_graph->options();
    options.graph_opt.tensorrt = true;
}

//! set the callback in async mode
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
    LITE_ASSERT(!m_is_cpu_inplace_mode, "cpu inplace mode not support async mode");
    LITE_ASSERT(
            m_user_config->options.comp_node_seq_record_level == 0,
            "record mode not support async mode");
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU ||
                    m_user_config->device_type == LiteDeviceType::LITE_CUDA,
            "Now only cpu and cuda>10.0 support async mode");
    m_async = true;
    m_async_callback = std::move(callback);
}

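//! build the ComputingGraph output spec: for every requested output a
//! callback copies the computed device tensor (or just its shape, for
//! LITE_IO_SHAPE outputs) into the corresponding lite tensor. In async mode
//! the callback of the last output to synchronize additionally schedules
//! finish() on its comp node. When output memory is user specified or
//! dynamically allocated, no callback is registered.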
void NetworkImplDft::make_output_spec() {
    m_output_spec.clear();
    for (auto&& out : m_network_io->outputs) {
        if (m_load_result.output_var_map.count(out.name)) {
            auto&& load_out = m_load_result.output_var_map[out.name];
            auto cb = [&out, this](const mgb::DeviceTensorND& dv) mutable {
                mgb::CompNode comp_node = dv.comp_node();
                if (out.io_type == LiteIOType::LITE_IO_SHAPE) {
                    auto mgb_layout = dv.layout();
                    out.lite_tensor->set_layout(to_lite_layout(mgb_layout));
                } else {
                    TensorHelper::implement(out.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .copy_from_mge_tensor(dv);
                    out.lite_tensor->update_from_implement();
                }
                if (m_async) {
                    out.have_sync = true;
                    bool need_exec_cb = true;
                    for (auto&& j : m_network_io->outputs) {
                        if (!j.have_sync) {
                            need_exec_cb = false;
                        }
                    }
                    if (need_exec_cb) {
                        for (auto&& j : m_network_io->outputs) {
                            j.have_sync = false;
                        }
                        comp_node.add_callback([this]() { finish(); });
                    }
                }
            };
            //! if write to user-specified memory, the CallbackCaller must be nullptr.
            if (m_user_config->options.force_output_use_user_specified_memory ||
                m_user_config->options.force_output_dynamic_alloc) {
                m_output_spec.emplace_back(load_out, nullptr);
            } else {
                m_output_spec.emplace_back(load_out, std::move(cb));
            }
        } else {
            LITE_THROW(ssprintf("no output named : %s in the mode", out.name.c_str()));
        }
    }
}

void NetworkImplDft::replace_dev_input_pass() {
    mgb::CompNode::Locator locator;
    m_load_config.comp_node_mapper(locator);
    //! CPU does not need to use device input
    if (locator.type == mgb::CompNode::DeviceType::CPU) {
        return;
    }
    //! replace the H2D with VolatileSharedDeviceTensor, and keep the dev tensor
    //! in m_network_io.input, so the user can directly change the dev tensor
    //! storage through m_network_io.input.lite_tensor->reset() before forward
    using DeviceTensorMap =
            std::unordered_map<std::string, std::shared_ptr<mgb::DeviceTensorND>>;
    DeviceTensorMap name2dev_tensor;

    mgb::ThinHashMap<mgb::HostTensorND*, mgb::SymbolVar> host_val2var;

    //! construct host_val2var that maps from host tensor to corresponding var
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        if (opr->same_type<mgb::opr::Host2DeviceCopy>()) {
            mgb::HostTensorND* tensor =
                    opr->cast_final<mgb::opr::Host2DeviceCopy>().host_data().get();
            host_val2var[tensor] = opr->output(0);
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }

    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> inp_var_map, out_var_map;

    mgb::SmallVector<std::string> to_clear;
    for (auto&& config_in : m_network_io->inputs) {
        if (!config_in.is_host) {
            auto host_val = m_load_result.tensor_map[config_in.name];
            auto dev_val = TensorHelper::implement(config_in.lite_tensor)
                                   ->cast_final_safe<TensorImplDft>()
                                   .m_dev_tensor;
            auto dev_var = mgb::opr::VolatileSharedDeviceTensor::make(
                    *m_load_result.graph, dev_val, {config_in.name});
            inp_var_map[host_val2var.at(host_val.get())] = dev_var;
            name2dev_tensor[config_in.name] = dev_val;
        }
    }
    auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}

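//! walk all operators that the outputs depend on and collect the device type
//! of every comp node they touch; m_nr_device_type > 1 later marks the model
//! as a cross-compnode model (e.g. an atlas model with parts of the graph on
//! the CPU).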
void NetworkImplDft::cross_compnode_model_detect() {
    mgb::ThinHashSet<LiteDeviceType> nr_used_device_type;
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        for (auto j : opr->output()) {
            if (j->comp_node() != mgb::CompNode::default_cpu()) {
                nr_used_device_type.insert(
                        get_device_from_locator(j->comp_node().locator()));
            }
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    m_nr_device_type = nr_used_device_type.size();
}

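//! load a model from memory: create the GraphLoader on first use, apply the
//! user configuration, honor the per-model overrides parsed from the json
//! info file (device_id / number_threads / enable_inplace_model /
//! use_tensorrt), then load the graph, detect cross-compnode usage, update
//! the IO tensors and compile the execution function.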
void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
    if (!m_loader) {
        m_input_file =
                mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
        auto format = mgb::serialization::GraphLoader::identify_graph_dump_format(
                *m_input_file);
        if (!format.valid()) {
            LITE_THROW("invalid model format");
        }
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(m_input_file), format.val());
    }

    //! apply the user configuration to the mge model
    application_config();

    //! configure some flags read from the json config file
    if (separate_config_map.find("device_id") != separate_config_map.end()) {
        set_device_id(separate_config_map["device_id"].safe_cast<int>());
    }
    if (separate_config_map.find("number_threads") != separate_config_map.end() &&
        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
        set_cpu_threads_number(
                separate_config_map["number_threads"].safe_cast<uint32_t>());
    }
    if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
        set_cpu_inplace_mode();
    }
    if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
        use_tensorrt();
    }

    m_load_result = m_loader->load(m_load_config, true);

    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}

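//! apply the pending execution policy, rewrite device inputs, build the
//! output spec and compile the graph into the executable m_execute_func.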
void NetworkImplDft::compile_graph() {
    modify_exection_policy();
    replace_dev_input_pass();
    make_output_spec();
    m_execute_func = m_load_result.graph_compile(m_output_spec);
}

void NetworkImplDft::start() const {
    if (m_start_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                input_io_map;
        for (auto&& io_inner : m_network_io->inputs) {
            input_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_start_callback(input_io_map);
    }
}

void NetworkImplDft::forward() {
    start();
    LITE_ASSERT(m_execute_func, "forward must be called after network loaded.");
    m_execute_func->execute();
}

void NetworkImplDft::wait() {
    if (!m_async) {
        m_execute_func->wait();
    }
    finish();
}

void NetworkImplDft::finish() const {
    if (m_async) {
        LITE_ASSERT(m_async_callback, "The callback func must set when async mode.");
        m_async_callback();
    }
    if (m_finish_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                output_io_map;
        for (auto&& io_inner : m_network_io->outputs) {
            output_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_finish_callback(output_io_map);
    }
    output_plugin_result();
}

void NetworkImplDft::set_io(const NetworkIO& network_io) {
    m_network_io = std::make_unique<NetworkIOInner>();
    for (auto&& in : network_io.inputs) {
        m_network_io->inputs.emplace_back(in);
    }
    for (auto&& out : network_io.outputs) {
        m_network_io->outputs.emplace_back(out);
    }
}

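//! pre-set the layout of an output lite tensor from the statically inferred
//! shape of its var; if the shape cannot be inferred (dynamic-shape model)
//! only warn, unless the output memory is user specified, in which case a
//! statically known shape is mandatory.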
void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
    if (var.node()->capable_shape_infer()) {
        auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
        if (!shape) {
            LITE_WARN(
                    "Lite infer output shape failed, maybe the model is "
                    "dynamic shape.\n");
            LITE_ASSERT(
                    !m_user_config->options.force_output_use_user_specified_memory,
                    "force_output_use_user_specified_memory can't be used when output "
                    "shape can't be derived.");
            return;
        }
        Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
        tensor->set_layout(layout);
    }
}

void NetworkImplDft::update_io() {
    update_input();
    update_output();
}

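//! bind the host tensors produced by the loader to the configured input lite
//! tensors: CPU networks force host inputs, cross-compnode models downgrade
//! invalid device inputs to host inputs, inputs the user did not configure
//! are added automatically, and configured inputs that do not exist in the
//! model are dropped.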
void NetworkImplDft::update_input() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    //! if cpu, all inputs and outputs are host
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& in : m_network_io->inputs) {
            in.is_host = true;
        }
    }
    //! if cross compnode model, modify the device input if it is not valid
    if (m_nr_device_type > 1) {
        for (auto&& in_tensor_iter : m_load_result.tensor_map) {
            for (auto&& config_in : m_network_io->inputs) {
                //! if tensor is set to device input
                if (in_tensor_iter.first == config_in.name && !config_in.is_host) {
                    //! if the origin compnode of the tensor is not the device,
                    //! set the input to host
                    if (get_device_from_locator(
                                in_tensor_iter.second->comp_node().locator()) ==
                        LiteDeviceType::LITE_CPU) {
                        config_in.is_host = true;
                        LITE_WARN(
                                "The input tensor %s of the cross device model "
                                "should not from device.",
                                config_in.name.c_str());
                    }
                }
            }
        }
    }
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        bool found = false;
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                found = true;
                if (config_in.is_host) {
                    config_in.lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                    TensorHelper::implement(config_in.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .m_host_tensor = in_tensor_iter.second;
                    config_in.lite_tensor->update_from_implement();
                } else {
                    config_in.lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                    config_in.lite_tensor->set_layout(
                            to_lite_layout(in_tensor_iter.second->layout()));
                }
                if (config_in.config_layout.ndim &&
                    !(config_in.config_layout == config_in.lite_tensor->get_layout())) {
                    config_in.lite_tensor->set_layout(config_in.config_layout);
                }
            }
        }
        if (!found) {
            IOInner io_in;
            io_in.name = in_tensor_iter.first;
            io_in.lite_tensor =
                    std::make_shared<Tensor>(device_id, stream_id, device_type, true);
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_host_tensor = in_tensor_iter.second;
            io_in.lite_tensor->update_from_implement();
            m_network_io->inputs.push_back(io_in);
        }
    }
    //! delete the IO that is not in the network
    for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
        if (it->lite_tensor == nullptr) {
            LITE_LOG("%s is not the network input, ignore it.", it->name.c_str());
            it = m_network_io->inputs.erase(it);
        } else {
            it++;
        }
    }
}

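//! bind the graph output vars to lite tensors. When the user asked to compute
//! only the configured outputs, only those entries get tensors (pinned host
//! memory for host outputs); otherwise every output var of the model gets a
//! lite tensor, reusing the user-configured entry when one exists.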
void NetworkImplDft::update_output() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& out : m_network_io->outputs) {
            out.is_host = true;
        }
    }
    //! delete the output that is not in the network
    for (auto out_it = m_network_io->outputs.begin();
         out_it != m_network_io->outputs.end();) {
        if (std::find_if(
                    m_load_result.output_var_list.begin(),
                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                        return var.node()->name() == out_it->name;
                    }) == m_load_result.output_var_list.end()) {
            LITE_LOG("%s is not the network output, ignore it.", out_it->name.c_str());
            out_it = m_network_io->outputs.erase(out_it);
        } else {
            out_it++;
        }
    }
    //! the user configured the output tensors, so only compute the configured outputs
    if (m_compute_configured_output_only) {
        LITE_ASSERT(
                m_network_io->outputs.size() > 0,
                "compute configured output only with no configure output.");
        for (auto out_it = m_network_io->outputs.begin();
             out_it != m_network_io->outputs.end(); out_it++) {
            //! use pinned memory to copy from device
            if (out_it->is_host) {
                out_it->lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
            } else {
                out_it->lite_tensor =
                        std::make_shared<Tensor>(device_id, stream_id, device_type);
            }
            SymbolVar var;
            for (auto&& out_var : m_load_result.output_var_list) {
                if (out_var.node()->name() == out_it->name) {
                    var = out_var;
                    break;
                }
            }
            try_infer_tensor_layout(out_it->lite_tensor, var);
            output_tensor_copy_optimize(var, out_it->lite_tensor);
        }
        //! the user did not set outputs, use the default outputs
    } else {
        for (auto&& out : m_load_result.output_var_list) {
            std::shared_ptr<Tensor> lite_tensor = nullptr;
            auto it = std::find_if(
                    m_network_io->outputs.begin(), m_network_io->outputs.end(),
                    [&out](const IOInner io) { return io.name == out.node()->name(); });
            if (it != m_network_io->outputs.end()) {
                if (it->is_host) {
                    it->lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                } else {
                    it->lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                }
                try_infer_tensor_layout(it->lite_tensor, out);
                lite_tensor = it->lite_tensor;
            } else {
                IOInner output;
                output.name = out.node()->name();
                output.lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
                m_network_io->outputs.push_back({output});
                try_infer_tensor_layout(output.lite_tensor, out);
                lite_tensor = output.lite_tensor;
            }
            output_tensor_copy_optimize(out, lite_tensor);
        }
    }
}

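//! avoid an extra copy of the output data:
//! - force_output_use_user_specified_memory: when the user resets the output
//!   tensor storage, the var's device tensor is re-initialized to write
//!   directly into the user buffer;
//! - force_output_dynamic_alloc: the lite tensor becomes a proxy of the var's
//!   own device tensor, so reading the output needs no copy.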
void NetworkImplDft::output_tensor_copy_optimize(
        Var var, std::shared_ptr<Tensor> tensor) {
    LITE_ASSERT(
            !(m_user_config->options.force_output_use_user_specified_memory &&
              m_user_config->options.force_output_dynamic_alloc),
            "Can't set force_output_use_user_specified_memory and "
            "force_output_dynamic_alloc at the same time.");
    if (m_user_config->options.force_output_use_user_specified_memory) {
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_reset_callback([var](TensorImplDft* dft_tensor) {
                    dft_tensor->device_share_host_memory();
                    auto dv = dft_tensor->dev_tensor().get();
                    dv->comp_node(var.node()->comp_node(), true);
                    var.node()->init_mem_plan(dv);
                    var.node()->reset_dev_tensor_from_tensor(*dv);
                });
    }
    if (m_user_config->options.force_output_dynamic_alloc) {
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
                    if (dft_tensor->is_host()) {
                        auto host_tensor = dft_tensor->m_host_tensor;
                        *host_tensor =
                                HostTensorND::make_proxy(var.node()->dev_tensor());
                    } else {
                        auto dev_tensor = dft_tensor->m_dev_tensor;
                        *dev_tensor = var.node()->dev_tensor();
                    }
                });
    }
}

std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name) {
                return config_in.lite_tensor;
            }
        }
    }
    if (phase == LiteTensorPhase::LITE_OUTPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_out : m_network_io->outputs) {
            if (io_name == config_out.name) {
                config_out.lite_tensor->update_from_implement();
                return config_out.lite_tensor;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "tensor name must be %s input tensor name or the registered "
            "output tensor name if NetworkIO is set, if NetworkIO is not set, "
            "the output tensor is all the network output tensor, or the output "
            "tensor is only the registered tensor.",
            io_name.c_str()));
    return nullptr;
}

std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
    return get_io_tensor(get_input_name(index));
}

std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
    return get_io_tensor(get_output_name(index));
}

//! set opr algorithm selection strategy in the network
void NetworkImplDft::set_network_algo_policy(
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    using S = megdnn::param::ExecutionPolicy::Strategy;
    auto dst_strategy = static_cast<S>(0);
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_HEURISTIC) {
        dst_strategy = dst_strategy | S::HEURISTIC;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_PROFILE) {
        dst_strategy = dst_strategy | S::PROFILE;
    }
    if (static_cast<uint32_t>(strategy) &
        LiteAlgoSelectStrategy::LITE_ALGO_REPRODUCIBLE) {
        dst_strategy = dst_strategy | S::REPRODUCIBLE;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
        dst_strategy = dst_strategy | S::OPTIMIZED;
    }
    m_execution_policy = dst_strategy;

    auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
    fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
    fast_run_config.shared_batch_size = shared_batch_size;

    if (m_execute_func) {
        LITE_WARN(
                "set_network_algo_policy maybe cause error after loaded "
                "network!!!!");
        modify_exection_policy();
    }
}

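//! apply the recorded execution policy to every operator that the current
//! output spec depends on; a zero policy leaves the operators untouched.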
void NetworkImplDft::modify_exection_policy() {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    if (static_cast<uint32_t>(m_execution_policy) != 0)
        mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
}

//! set the opr algorithm workspace limit in the network
void NetworkImplDft::set_network_algo_workspace_limit(size_t workspace_limit) {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
}

//! get all the output tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_output_name() const {
    std::vector<const char*> output_names;
    for (auto& output : m_network_io->outputs) {
        output_names.push_back(output.name.c_str());
    }
    return output_names;
}

//! get all the input tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_input_name() const {
    std::vector<const char*> input_names;
    for (auto& input : m_load_result.tensor_map) {
        input_names.push_back(input.first.c_str());
    }
    return input_names;
}

//! get the output tensor name in the order of graph
const char* NetworkImplDft::get_output_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.output_var_list.size(),
            "The output tensor index is large than the total outputs number.");
    return m_load_result.output_var_list[index].node()->name().c_str();
}

//! get the input tensor name in the order of graph
const char* NetworkImplDft::get_input_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.tensor_map.size(),
            "The input tensor index is large than the total inputs number.");
    size_t i = 0;
    for (auto& input : m_load_result.tensor_map) {
        if (i == index) {
            return input.first.c_str();
        }
        i++;
    }
    LITE_THROW(ssprintf("no input tensor of index %zu.", index));
}

//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
#if MGB_OPENCL
    mgb::CompNode::enable_opencl_profile(true);
#endif
    m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
    m_profiler_output_file = profile_json_file;
#else
    LITE_MARK_USED_VAR(profile_json_file);
    LITE_THROW("JSON is disable at compile time.");
#endif
}

void NetworkImplDft::enable_io_txt_dump(std::string io_txt_out_file) {
    auto iodump = std::make_unique<mgb::TextOprIODump>(
            m_load_config.comp_graph.get(), io_txt_out_file.c_str());
    iodump->print_addr(false);
    m_iodump = std::move(iodump);
}

void NetworkImplDft::enable_io_bin_dump(std::string io_bin_out_dir) {
    m_iodump = std::make_unique<mgb::BinaryOprIODump>(
            m_load_config.comp_graph.get(), io_bin_out_dir.c_str());
}

void inline NetworkImplDft::output_plugin_result() const {
#if MGB_ENABLE_JSON
    if (m_profiler && m_execute_func) {
        m_profiler->to_json_full(m_execute_func.get())
                ->writeto_fpath(m_profiler_output_file);
    }
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and the driver is properly installed. If you would like to try deep-learning development on a cloud GPU platform, visit the MegStudio platform.