
network.cpp 17 kB

/**
 * \file src/network.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "lite/network.h"
#include "function_base.h"
#include "network_impl_base.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"
#include "type_info.h"
#if LITE_BUILD_WITH_MGE
#include "mge/function_dft.h"
#include "mge/network_impl.h"
#endif

#include <fstream>
#include <memory>

using namespace lite;
/**
 * \brief Construct the network implement.
 * The order must be:
 * 1. create the implement
 * 2. config and load
 * 3. set_io
 */
Network::Network(const Config& config, const NetworkIO& network_io) {
    LITE_ERROR_HANDLER_BEGIN
    m_config = config;
    m_network_io = network_io;
    if (config.backend == LiteBackend::LITE_DEFAULT) {
        m_impl = call_func<
                NetworkImplDft, std::unique_ptr<lite::Network::NetworkImplBase>>(
                "create_network");
    }
    m_impl->set_config(config);
    m_impl->set_io(network_io);
    LITE_ERROR_HANDLER_END
}

Network::Network(const NetworkIO& network_io, const Config& config) {
    LITE_ERROR_HANDLER_BEGIN
    m_config = config;
    m_network_io = network_io;
    if (config.backend == LiteBackend::LITE_DEFAULT) {
        m_impl = call_func<
                NetworkImplDft, std::unique_ptr<lite::Network::NetworkImplBase>>(
                "create_network");
    }
    m_impl->set_config(config);
    m_impl->set_io(network_io);
    LITE_ERROR_HANDLER_END
}
void Network::load_model(void* model_mem, size_t size) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    //! this model_mem is managed by the user
    std::shared_ptr<void> model{model_mem, [](void*) {}};
    prase_model(model, size);
    LITE_ERROR_HANDLER_END
}

void Network::load_model(std::string model_path) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    FILE* fin = fopen(model_path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    prase_model(buf, size);
    LITE_ERROR_HANDLER_END
}

void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
    std::unordered_map<std::string, LiteAny> separate_config_map;
    ModelParser model_parser(model_data, size);
    //! parse the model info
    if (model_parser.parse_model_info(
                m_config, m_network_io, separate_config_map, m_extra_info)) {
        if (m_config.backend == LiteBackend::LITE_DEFAULT &&
            m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
            m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
                    "parse_model"));
        }
        m_impl->set_config(m_config);
        m_impl->set_io(m_network_io);
    }
    //! decrypt the model
    size_t model_length;
    auto&& model_shared_ptr = model_parser.parse_model(model_length, m_config);
    m_impl->load_model(model_shared_ptr, model_length, separate_config_map);
    m_loaded = true;
    update_from_implement();
}
Network::~Network() = default;

void Network::update_from_implement() {
    m_config.device_type = m_impl->get_device_type();
}

void Network::compute_only_configured_output() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            !m_loaded,
            "compute_only_configured_output should be used before model "
            "loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->compute_only_configured_output();
    LITE_ERROR_HANDLER_END
}

std::shared_ptr<Tensor> Network::get_io_tensor(
        std::string name, LiteTensorPhase phase) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_io_tensor(name, phase);
    LITE_ERROR_HANDLER_END
}

std::shared_ptr<Tensor> Network::get_input_tensor(size_t index) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_input_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_input_tensor(index);
    LITE_ERROR_HANDLER_END
}

std::shared_ptr<Tensor> Network::get_output_tensor(size_t index) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_output_tensor should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_output_tensor(index);
    LITE_ERROR_HANDLER_END
}

Network& Network::set_async_callback(const AsyncCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_async_callback(std::move(callback));
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_start_callback(const StartCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_start_callback(std::move(callback));
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_finish_callback(const FinishCallback& callback) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_finish_callback(std::move(callback));
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_device_id(int device_id) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded, "set_device_id should be used before model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_device_id(device_id);
    return *this;
    LITE_ERROR_HANDLER_END
}

Network& Network::set_stream_id(int stream_id) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(!m_loaded, "set_stream_id should be used before model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->set_stream_id(stream_id);
    return *this;
    LITE_ERROR_HANDLER_END
}

void Network::forward() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "forward should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl.get());
    m_impl->forward();
    LITE_ERROR_HANDLER_END
}

void Network::wait() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "wait should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    m_impl->wait();
    LITE_ERROR_HANDLER_END
}

std::string Network::get_input_name(size_t index) const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_input_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_input_name(index);
    LITE_ERROR_HANDLER_END
}

std::string Network::get_output_name(size_t index) const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_output_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_output_name(index);
    LITE_ERROR_HANDLER_END
}

std::vector<std::string> Network::get_all_input_name() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_all_input_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto all_input_name = m_impl->get_all_input_name();
    std::vector<std::string> all_names;
    for (auto& name : all_input_name) {
        all_names.push_back(name);
    }
    return all_names;
    LITE_ERROR_HANDLER_END
}

std::vector<std::string> Network::get_all_output_name() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_loaded, "get_all_output_name should be used after model loaded.");
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    auto all_output_name = m_impl->get_all_output_name();
    std::vector<std::string> all_names;
    for (auto& name : all_output_name) {
        all_names.push_back(name);
    }
    return all_names;
    LITE_ERROR_HANDLER_END
}

int Network::get_device_id() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_device_id();
    LITE_ERROR_HANDLER_END
}

int Network::get_stream_id() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_CHECK_NON_NULL_POINTER(m_impl);
    return m_impl->get_stream_id();
    LITE_ERROR_HANDLER_END
}

void Network::enable_profile_performance(std::string profile_file_path) {
    LITE_ERROR_HANDLER_BEGIN
    m_impl->enable_profile_performance(profile_file_path);
    LITE_ERROR_HANDLER_END
}

const std::string& Network::get_model_extra_info() {
    LITE_ERROR_HANDLER_BEGIN
    return m_extra_info;
    LITE_ERROR_HANDLER_END
}

LiteDeviceType Network::get_device_type() const {
    LITE_ERROR_HANDLER_BEGIN
    return m_impl->get_device_type();
    LITE_ERROR_HANDLER_END
}
/*********************** MGE special network function ***************/

void Runtime::set_cpu_threads_number(
        std::shared_ptr<Network> network, size_t nr_threads) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                !NetworkHelper::loaded(network),
                "set_cpu_threads_number should be used before model loaded.");
        call_func<NetworkImplDft, void>(
                "set_cpu_threads_number", network_impl, nr_threads);
        return;
    }
    LITE_THROW("set_cpu_threads_number is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::use_tensorrt(std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                !NetworkHelper::loaded(network),
                "use_tensorrt should be used before model loaded.");
        call_func<NetworkImplDft, void>("use_tensorrt", network_impl);
        return;
    }
    LITE_THROW("use_tensorrt is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

size_t Runtime::get_cpu_threads_number(const std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        return call_func<NetworkImplDft, size_t>(
                "get_cpu_threads_number", network_impl);
    }
    LITE_THROW("get_cpu_threads_number is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::set_runtime_thread_affinity(
        std::shared_ptr<Network> network,
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                NetworkHelper::loaded(network),
                "set_runtime_thread_affinity should be used after model "
                "loaded.");
        call_func<NetworkImplDft, void>(
                "set_runtime_thread_affinity", network_impl, thread_affinity_callback);
        return;
    }
    LITE_THROW("set_runtime_thread_affinity is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::set_cpu_inplace_mode(std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                !NetworkHelper::loaded(network),
                "set_cpu_inplace_mode should be used before model loaded.");
        call_func<NetworkImplDft, void>("set_cpu_inplace_mode", network_impl);
        return;
    }
    LITE_THROW("set_cpu_inplace_mode is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

bool Runtime::is_cpu_inplace_mode(const std::shared_ptr<Network> network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        return call_func<NetworkImplDft, bool>("is_cpu_inplace_mode", network_impl);
    }
    LITE_THROW("is_cpu_inplace_mode is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

//! set opr algorithm selection strategy in the network
void Runtime::set_network_algo_policy(
        std::shared_ptr<Network> network, LiteAlgoSelectStrategy strategy,
        uint32_t shared_batch_size, bool binary_equal_between_batch) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        call_func<NetworkImplDft, void>(
                "set_network_algo_policy", network_impl, strategy, shared_batch_size,
                binary_equal_between_batch);
        return;
    }
    LITE_THROW("set_network_algo_policy is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

//! set the opr algorithm workspace limit in the network
void Runtime::set_network_algo_workspace_limit(
        std::shared_ptr<Network> network, size_t workspace_limit) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                NetworkHelper::loaded(network),
                "set_network_algo_workspace_limit should be used after model "
                "loaded.");
        call_func<NetworkImplDft, void>(
                "set_network_algo_workspace_limit", network_impl, workspace_limit);
        return;
    }
    LITE_THROW(
            "set_network_algo_workspace_limit is not available in the "
            "backend.");
    LITE_ERROR_HANDLER_END
}

//! set the network memory allocator, the allocator is defined by the user
void Runtime::set_memory_allocator(
        std::shared_ptr<Network> network, std::shared_ptr<Allocator> user_allocator) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                !NetworkHelper::loaded(network),
                "set_memory_allocator should be used before model loaded.");
        call_func<NetworkImplDft, void>(
                "set_memory_allocator", network_impl, user_allocator);
        return;
    }
    LITE_THROW("set_memory_allocator is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::share_runtime_memory_with(
        std::shared_ptr<Network> dst_network, std::shared_ptr<Network> src_network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl_dst = NetworkHelper::implement(dst_network);
    if (network_impl_dst->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                !NetworkHelper::loaded(dst_network),
                "share_runtime_memory_with should be used before model "
                "loaded.");
        call_func<NetworkImplDft, void>(
                "share_runtime_memory_with", network_impl_dst,
                NetworkHelper::implement(src_network));
        return;
    }
    LITE_THROW("share_runtime_memory_with is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::enable_io_txt_dump(
        std::shared_ptr<Network> network, std::string io_txt_out_file) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        call_func<NetworkImplDft, void>(
                "enable_io_txt_dump", network_impl, io_txt_out_file);
        return;
    }
    LITE_THROW("enable_io_txt_dump is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::enable_io_bin_dump(
        std::shared_ptr<Network> network, std::string io_bin_out_dir) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl = NetworkHelper::implement(network);
    if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        call_func<NetworkImplDft, void>(
                "enable_io_bin_dump", network_impl, io_bin_out_dir);
        return;
    }
    LITE_THROW("enable_io_bin_dump is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

void Runtime::shared_weight_with_network(
        std::shared_ptr<Network> dst_network,
        const std::shared_ptr<Network> src_network) {
    LITE_ERROR_HANDLER_BEGIN
    auto network_impl_dst = NetworkHelper::implement(dst_network);
    if (network_impl_dst->get_backend_type() == LiteBackend::LITE_DEFAULT) {
        LITE_ASSERT(
                NetworkHelper::loaded(src_network),
                "shared_weight_with_network should be used after the src "
                "network loaded.");
        auto src_implement = NetworkHelper::implement(src_network);
        call_func<NetworkImplDft, void>(
                "shared_weight_with", network_impl_dst, src_implement);
        NetworkHelper::loaded(dst_network, true);
        return;
    }
    LITE_THROW("shared_weight_with_network is not available in the backend.");
    LITE_ERROR_HANDLER_END
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
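
For reference, below is a minimal usage sketch of the API implemented in this file, following the order documented in the constructor comment (create the network, configure the runtime before loading, load the model, then run). It assumes the default MegEngine backend (LITE_BUILD_WITH_MGE); the model path "./model.mge" is a placeholder, and filling the input tensor is left as a comment because the Tensor memory API lives in lite/tensor.h, not here. This is an illustration, not part of the source file.

#include "lite/network.h"

#include <memory>

int main() {
    using namespace lite;

    // 1. create the network; the default implement is created in the constructor
    Config config;
    NetworkIO io;
    auto network = std::make_shared<Network>(config, io);

    // runtime options such as the CPU thread number must be set before load_model()
    Runtime::set_cpu_threads_number(network, 4);

    // 2. load the model from a file path (placeholder path)
    network->load_model("./model.mge");

    // 3. obtain the first input tensor and fill it with data
    auto input = network->get_input_tensor(0);
    // ... copy input data into `input` here (see lite/tensor.h) ...

    // run inference; forward() may be asynchronous, so wait() for completion
    network->forward();
    network->wait();

    // read back the first output tensor
    auto output = network->get_output_tensor(0);
    (void)output;
    return 0;
}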

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, visit the MegStudio platform.