You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

network.cpp 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. /**
  2. * \file lite-c/src/network.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite/network.h"
  12. #include "common.h"
  13. #include "lite-c/network_c.h"
  14. #include "../../src/network_impl_base.h"
  15. #include <string.h>
  16. #include <memory>
  17. #include <mutex>
  18. #include <unordered_map>
  19. //! define a default Options
  20. const LiteOptions default_option = {
  21. .weight_preprocess = false,
  22. .fuse_preprocess = false,
  23. .fake_next_exec = false,
  24. .var_sanity_check_first_run = true,
  25. .const_shape = false,
  26. .force_dynamic_alloc = false,
  27. .force_output_dynamic_alloc = false,
  28. .force_output_use_user_specified_memory = false,
  29. .no_profiling_on_shape_change = false,
  30. .jit_level = 0,
  31. .comp_node_seq_record_level = 0,
  32. .graph_opt_level = 2,
  33. .async_exec_level = 1,
  34. //! layout transform options
  35. .enable_nchw44 = 0,
  36. .enable_nchw44_dot = 0,
  37. .enable_nchw88 = 0,
  38. .enable_nhwcd4 = 0,
  39. .enable_nchw4 = 0,
  40. .enable_nchw32 = 0,
  41. .enable_nchw64 = 0,
  42. };
  43. //! define a default config
  44. LiteConfig default_config_t = {
  45. .has_compression = false,
  46. .device_id = -1,
  47. .device_type = LiteDeviceType::LITE_CPU,
  48. .backend = LiteBackend::LITE_DEFAULT,
  49. .bare_model_cryption_name = nullptr,
  50. .options = default_option};
  51. LiteConfig* default_config() {
  52. return &default_config_t;
  53. }
  54. //! define a default IO
  55. const LiteIO default_io = {
  56. .name = nullptr,
  57. .is_host = true,
  58. .io_type = LiteIOType::LITE_IO_VALUE,
  59. .config_layout = default_layout};
  60. //! define a default NetworkIO
  61. LiteNetworkIO default_network_io_t = {
  62. .inputs = nullptr, .outputs = nullptr, .input_size = 0, .output_size = 0};
  63. LiteNetworkIO* default_network_io() {
  64. return &default_network_io_t;
  65. }
  66. namespace {
  67. static LITE_MUTEX mtx_network;
  68. std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_gloabl_network_holder() {
  69. static std::unordered_map<void*, std::shared_ptr<lite::Network>> network_holder;
  70. return network_holder;
  71. }
  72. /*!
  73. * \brief A user-implemented allocator interface
  74. */
  75. class UserAllocator : public lite::Allocator {
  76. public:
  77. UserAllocator(LiteAllocate allocate_func, LiteFree free_func)
  78. : m_allocator(allocate_func), m_free(free_func) {
  79. LITE_ASSERT(m_allocator && m_free);
  80. }
  81. //! allocate memory of size in the given device with the given align
  82. void* allocate(LiteDeviceType device_type, int device_id, size_t size, size_t align)
  83. override {
  84. return m_allocator(device_type, device_id, size, align);
  85. }
  86. //! free the memory pointed by ptr in the given device
  87. void free(LiteDeviceType device_type, int device_id, void* ptr) override {
  88. m_free(device_type, device_id, ptr);
  89. }
  90. private:
  91. LiteAllocate m_allocator;
  92. LiteFree m_free;
  93. };
  94. } // namespace
  95. //! convert c config to lite::config
  96. lite::Config convert_to_lite_config(const LiteConfig c_config) {
  97. lite::Config lite_config;
  98. lite_config.device_type = c_config.device_type;
  99. if (c_config.bare_model_cryption_name) {
  100. lite_config.bare_model_cryption_name = c_config.bare_model_cryption_name;
  101. }
  102. lite_config.backend = c_config.backend;
  103. lite_config.has_compression = c_config.has_compression;
  104. lite_config.device_id = c_config.device_id;
  105. lite_config.options.weight_preprocess = c_config.options.weight_preprocess;
  106. lite_config.options.fuse_preprocess = c_config.options.fuse_preprocess;
  107. lite_config.options.fake_next_exec = c_config.options.fake_next_exec;
  108. lite_config.options.var_sanity_check_first_run =
  109. c_config.options.var_sanity_check_first_run;
  110. lite_config.options.const_shape = c_config.options.const_shape;
  111. lite_config.options.force_dynamic_alloc = c_config.options.force_dynamic_alloc;
  112. lite_config.options.force_output_use_user_specified_memory =
  113. c_config.options.force_output_use_user_specified_memory;
  114. lite_config.options.force_output_dynamic_alloc =
  115. c_config.options.force_output_dynamic_alloc;
  116. lite_config.options.no_profiling_on_shape_change =
  117. c_config.options.no_profiling_on_shape_change;
  118. lite_config.options.jit_level = c_config.options.jit_level;
  119. lite_config.options.comp_node_seq_record_level =
  120. c_config.options.comp_node_seq_record_level;
  121. lite_config.options.graph_opt_level = c_config.options.graph_opt_level;
  122. lite_config.options.async_exec_level = c_config.options.async_exec_level;
  123. lite_config.options.enable_nchw44 = c_config.options.enable_nchw44;
  124. lite_config.options.enable_nchw44_dot = c_config.options.enable_nchw44_dot;
  125. lite_config.options.enable_nchw88 = c_config.options.enable_nchw88;
  126. lite_config.options.enable_nchw4 = c_config.options.enable_nchw4;
  127. lite_config.options.enable_nhwcd4 = c_config.options.enable_nhwcd4;
  128. lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
  129. lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
  130. return lite_config;
  131. }
  132. //! convert C NetworkIO io to lite::NetworkIO
  133. lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
  134. lite::NetworkIO network_io;
  135. for (size_t i = 0; i < c_network_io.input_size; i++) {
  136. LiteIO* c_io = c_network_io.inputs + i;
  137. LITE_ASSERT(c_io->name, "input name of io tensor must set.");
  138. network_io.inputs.push_back(
  139. {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
  140. convert_to_layout(c_io->config_layout)});
  141. }
  142. for (size_t i = 0; i < c_network_io.output_size; i++) {
  143. LiteIO* c_io = c_network_io.outputs + i;
  144. LITE_ASSERT(c_io->name, "output name of io tensor must set.");
  145. network_io.outputs.push_back(
  146. {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
  147. convert_to_layout(c_io->config_layout)});
  148. }
  149. return network_io;
  150. }
  151. struct InnerIO {
  152. std::vector<std::string> names;
  153. std::vector<LiteIO> inputs;
  154. std::vector<LiteIO> outputs;
  155. };
  156. InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
  157. InnerIO innner_io;
  158. for (size_t i = 0; i < network_io.inputs.size(); i++) {
  159. lite::IO io = network_io.inputs[i];
  160. innner_io.names.push_back(io.name);
  161. innner_io.inputs.push_back(
  162. {innner_io.names.back().c_str(), io.is_host, io.io_type,
  163. convert_to_clayout(io.config_layout)});
  164. }
  165. for (size_t i = 0; i < network_io.outputs.size(); i++) {
  166. lite::IO io = network_io.outputs[i];
  167. innner_io.names.push_back(io.name);
  168. innner_io.outputs.push_back(
  169. {innner_io.names.back().c_str(), io.is_host, io.io_type,
  170. convert_to_clayout(io.config_layout)});
  171. }
  172. return innner_io;
  173. }
  174. int LITE_make_default_network(LiteNetwork* network) {
  175. LITE_CAPI_BEGIN();
  176. LITE_ASSERT(network, "The network pass to LITE api is null");
  177. auto lite_network = std::make_shared<lite::Network>();
  178. LITE_LOCK_GUARD(mtx_network);
  179. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  180. *network = lite_network.get();
  181. LITE_CAPI_END();
  182. }
  183. int LITE_make_network(
  184. LiteNetwork* network, const LiteConfig config, const LiteNetworkIO network_io) {
  185. LITE_CAPI_BEGIN();
  186. LITE_ASSERT(network, "The network pass to LITE api is null");
  187. auto lite_network = std::make_shared<lite::Network>(
  188. convert_to_lite_config(config), convert_to_lite_io(network_io));
  189. LITE_LOCK_GUARD(mtx_network);
  190. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  191. *network = lite_network.get();
  192. LITE_CAPI_END();
  193. }
  194. int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
  195. LITE_CAPI_BEGIN();
  196. LITE_ASSERT(network, "The network pass to LITE api is null");
  197. auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config));
  198. LITE_LOCK_GUARD(mtx_network);
  199. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  200. *network = lite_network.get();
  201. LITE_CAPI_END();
  202. }
  203. int LITE_load_model_from_mem(LiteNetwork network, void* model_mem, size_t size) {
  204. LITE_CAPI_BEGIN();
  205. LITE_ASSERT(network, "The network pass to LITE api is null");
  206. LITE_ASSERT(model_mem, "The model memory pass to LITE api is null");
  207. static_cast<lite::Network*>(network)->load_model(model_mem, size);
  208. LITE_CAPI_END();
  209. }
  210. int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
  211. LITE_CAPI_BEGIN();
  212. LITE_ASSERT(network, "The network pass to LITE api is null");
  213. LITE_ASSERT(model_path, "The model path pass to LITE api is null");
  214. static_cast<lite::Network*>(network)->load_model(model_path);
  215. LITE_CAPI_END();
  216. }
  217. int LITE_destroy_network(LiteNetwork network) {
  218. LITE_CAPI_BEGIN();
  219. LITE_ASSERT(network, "The network pass to LITE api is null");
  220. LITE_LOCK_GUARD(mtx_network);
  221. get_gloabl_network_holder().erase(network);
  222. LITE_CAPI_END();
  223. }
  224. int LITE_forward(const LiteNetwork network) {
  225. LITE_CAPI_BEGIN();
  226. LITE_ASSERT(network, "The network pass to LITE api is null");
  227. static_cast<lite::Network*>(network)->forward();
  228. LITE_CAPI_END();
  229. }
  230. int LITE_wait(const LiteNetwork network) {
  231. LITE_CAPI_BEGIN();
  232. LITE_ASSERT(network, "The network pass to LITE api is null");
  233. static_cast<lite::Network*>(network)->wait();
  234. LITE_CAPI_END();
  235. }
  236. int LITE_get_io_tensor(
  237. LiteNetwork network, const char* io_name, LiteTensorPhase phase,
  238. LiteTensor* tensor) {
  239. LITE_CAPI_BEGIN();
  240. LITE_ASSERT(network, "The network pass to LITE api is null");
  241. auto io_tensor =
  242. static_cast<lite::Network*>(network)->get_io_tensor(io_name, phase);
  243. *tensor = io_tensor.get();
  244. LITE_CAPI_END();
  245. }
  246. int LITE_get_input_name(const LiteNetwork network, size_t index, const char** name) {
  247. LITE_CAPI_BEGIN();
  248. LITE_ASSERT(network && name, "The network pass to LITE api is null");
  249. *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  250. ->get_input_name(index);
  251. LITE_CAPI_END();
  252. }
  253. int LITE_get_output_name(const LiteNetwork network, size_t index, const char** name) {
  254. LITE_CAPI_BEGIN();
  255. LITE_ASSERT(network, "The network pass to LITE api is null");
  256. LITE_ASSERT(name, "The name ptr pass to LITE api is null");
  257. *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  258. ->get_output_name(index);
  259. LITE_CAPI_END();
  260. }
  261. int LITE_get_all_input_name(
  262. const LiteNetwork network, size_t* size, const char** name) {
  263. LITE_CAPI_BEGIN();
  264. LITE_ASSERT(network, "The network pass to LITE api is null");
  265. auto&& names = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  266. ->get_all_input_name();
  267. if (size)
  268. *size = names.size();
  269. if (name) {
  270. for (auto in_name : names) {
  271. *name = in_name;
  272. name++;
  273. }
  274. }
  275. LITE_CAPI_END();
  276. }
  277. int LITE_get_all_output_name(
  278. const LiteNetwork network, size_t* size, const char** name) {
  279. LITE_CAPI_BEGIN();
  280. LITE_ASSERT(network, "The network pass to LITE api is null");
  281. auto&& names = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  282. ->get_all_output_name();
  283. if (size)
  284. *size = names.size();
  285. if (name) {
  286. for (auto in_name : names) {
  287. *name = in_name;
  288. name++;
  289. }
  290. }
  291. LITE_CAPI_END();
  292. }
  293. int LITE_set_device_id(LiteNetwork network, int device_id) {
  294. LITE_CAPI_BEGIN();
  295. LITE_ASSERT(network, "The network pass to LITE api is null");
  296. static_cast<lite::Network*>(network)->set_device_id(device_id);
  297. LITE_CAPI_END();
  298. }
  299. int LITE_get_device_id(const LiteNetwork network, int* device_id) {
  300. LITE_CAPI_BEGIN();
  301. LITE_ASSERT(network, "The network pass to LITE api is null");
  302. LITE_ASSERT(device_id, "The device_id pass to LITE api is null");
  303. *device_id = static_cast<lite::Network*>(network)->get_device_id();
  304. LITE_CAPI_END();
  305. }
  306. int LITE_set_stream_id(LiteNetwork network, int stream_id) {
  307. LITE_CAPI_BEGIN();
  308. LITE_ASSERT(network, "The network pass to LITE api is null");
  309. static_cast<lite::Network*>(network)->set_stream_id(stream_id);
  310. LITE_CAPI_END();
  311. }
  312. int LITE_get_stream_id(const LiteNetwork network, int* stream_id) {
  313. LITE_CAPI_BEGIN();
  314. LITE_ASSERT(network, "The network pass to LITE api is null");
  315. LITE_ASSERT(stream_id, "The stream_id pass to LITE api is null");
  316. *stream_id = static_cast<lite::Network*>(network)->get_stream_id();
  317. LITE_CAPI_END();
  318. }
  319. int LITE_get_model_extra_info(
  320. const LiteNetwork network, const char** info, int* info_size) {
  321. LITE_CAPI_BEGIN();
  322. LITE_ASSERT(network, "The network pass to LITE api is null");
  323. LITE_ASSERT(info_size, "The info and info_size are all null");
  324. auto& extra_info = static_cast<lite::Network*>(network)->get_model_extra_info();
  325. *info_size = extra_info.size();
  326. *info = extra_info.c_str();
  327. LITE_MARK_USED_VAR(info);
  328. LITE_CAPI_END();
  329. }
  330. int LITE_get_device_type(const LiteNetwork network, LiteDeviceType* device_type) {
  331. LITE_CAPI_BEGIN();
  332. LITE_ASSERT(network, "The network pass to LITE api is null");
  333. LITE_ASSERT(device_type, "The device_type pass to LITE api is null");
  334. *device_type = static_cast<lite::Network*>(network)->get_device_type();
  335. LITE_CAPI_END();
  336. }
  337. int LITE_set_async_callback(
  338. LiteNetwork network, const LiteAsyncCallback async_callback) {
  339. LITE_CAPI_BEGIN();
  340. LITE_ASSERT(network, "The network pass to LITE api is null");
  341. LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
  342. static_cast<lite::Network*>(network)->set_async_callback(std::move(async_callback));
  343. LITE_CAPI_END();
  344. }
  345. int LITE_set_async_callback_with_userdata(
  346. LiteNetwork network, LiteAsyncCallbackWithData async_callback,
  347. void* user_data) {
  348. LITE_CAPI_BEGIN();
  349. LITE_ASSERT(network, "The network pass to LITE api is null");
  350. LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
  351. auto lite_async_callback = [async_callback, user_data]() -> void {
  352. async_callback(user_data);
  353. };
  354. static_cast<lite::Network*>(network)->set_async_callback(
  355. std::move(lite_async_callback));
  356. LITE_CAPI_END();
  357. }
  358. int LITE_set_start_callback(
  359. LiteNetwork network, const LiteStartCallback start_callback) {
  360. LITE_CAPI_BEGIN();
  361. LITE_ASSERT(network, "The network pass to LITE api is null");
  362. auto lite_start_callback =
  363. [start_callback](const std::unordered_map<
  364. std::string,
  365. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  366. inputs_map) -> void {
  367. std::vector<LiteIO> ios;
  368. std::vector<LiteTensor> io_tensors;
  369. size_t nr_io = 0;
  370. for (const auto& io : inputs_map) {
  371. nr_io++;
  372. auto&& lite_io = io.second.first;
  373. ios.push_back(
  374. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  375. convert_to_clayout(lite_io.config_layout)});
  376. io_tensors.push_back(io.second.second.get());
  377. }
  378. start_callback(ios.data(), io_tensors.data(), nr_io);
  379. };
  380. static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
  381. LITE_CAPI_END();
  382. }
  383. int LITE_set_start_callback_with_userdata(
  384. LiteNetwork network, const LiteStartCallbackWithData start_callback,
  385. void* user_data) {
  386. LITE_CAPI_BEGIN();
  387. LITE_ASSERT(network, "The network pass to LITE api is null");
  388. auto lite_start_callback =
  389. [start_callback,
  390. user_data](const std::unordered_map<
  391. std::string,
  392. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& inputs_map)
  393. -> void {
  394. std::vector<LiteIO> ios;
  395. std::vector<LiteTensor> io_tensors;
  396. size_t nr_io = 0;
  397. for (const auto& io : inputs_map) {
  398. nr_io++;
  399. auto&& lite_io = io.second.first;
  400. ios.push_back(
  401. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  402. convert_to_clayout(lite_io.config_layout)});
  403. io_tensors.push_back(io.second.second.get());
  404. }
  405. start_callback(ios.data(), io_tensors.data(), nr_io, user_data);
  406. };
  407. static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
  408. LITE_CAPI_END();
  409. }
  410. int LITE_set_finish_callback(
  411. LiteNetwork network, const LiteFinishCallback finish_callback) {
  412. LITE_CAPI_BEGIN();
  413. LITE_ASSERT(network, "The network pass to LITE api is null");
  414. auto lite_finish_callback =
  415. [finish_callback](const std::unordered_map<
  416. std::string,
  417. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  418. outputs_map) -> void {
  419. std::vector<LiteIO> ios;
  420. std::vector<LiteTensor> io_tensors;
  421. size_t nr_io = 0;
  422. for (const auto& io : outputs_map) {
  423. nr_io++;
  424. auto&& lite_io = io.second.first;
  425. ios.push_back(
  426. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  427. convert_to_clayout(lite_io.config_layout)});
  428. io_tensors.push_back(io.second.second.get());
  429. }
  430. finish_callback(ios.data(), io_tensors.data(), nr_io);
  431. };
  432. static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
  433. LITE_CAPI_END();
  434. }
  435. int LITE_set_finish_callback_with_userdata(
  436. LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
  437. void* user_data) {
  438. LITE_CAPI_BEGIN();
  439. LITE_ASSERT(network, "The network pass to LITE api is null");
  440. auto lite_finish_callback =
  441. [finish_callback,
  442. user_data](const std::unordered_map<
  443. std::string,
  444. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  445. outputs_map) -> void {
  446. std::vector<LiteIO> ios;
  447. std::vector<LiteTensor> io_tensors;
  448. size_t nr_io = 0;
  449. for (const auto& io : outputs_map) {
  450. nr_io++;
  451. auto&& lite_io = io.second.first;
  452. ios.push_back(
  453. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  454. convert_to_clayout(lite_io.config_layout)});
  455. io_tensors.push_back(io.second.second.get());
  456. }
  457. finish_callback(ios.data(), io_tensors.data(), nr_io, user_data);
  458. };
  459. static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
  460. LITE_CAPI_END();
  461. }
  462. int LITE_enable_profile_performance(
  463. LiteNetwork network, const char* profile_json_file_path) {
  464. LITE_CAPI_BEGIN();
  465. LITE_ASSERT(network, "The network pass to LITE api is null");
  466. static_cast<lite::Network*>(network)->enable_profile_performance(
  467. profile_json_file_path);
  468. LITE_CAPI_END();
  469. }
  470. int LITE_is_cpu_inplace_mode(const LiteNetwork network, int* is_cpu_inplace_mode) {
  471. LITE_CAPI_BEGIN();
  472. LITE_ASSERT(network && is_cpu_inplace_mode, "The network pass to LITE api is null");
  473. std::shared_ptr<lite::Network> network_shared{
  474. static_cast<lite::Network*>(network), [](void*) {}};
  475. *is_cpu_inplace_mode = lite::Runtime::is_cpu_inplace_mode(network_shared);
  476. LITE_CAPI_END();
  477. }
  478. int LITE_get_cpu_threads_number(const LiteNetwork network, size_t* nr_threads) {
  479. LITE_CAPI_BEGIN();
  480. LITE_ASSERT(network, "The network pass to LITE api is null");
  481. LITE_ASSERT(nr_threads, "The ptr pass to LITE api is null");
  482. std::shared_ptr<lite::Network> network_shared{
  483. static_cast<lite::Network*>(network), [](void*) {}};
  484. *nr_threads = lite::Runtime::get_cpu_threads_number(network_shared);
  485. LITE_CAPI_END();
  486. }
  487. int LITE_set_cpu_inplace_mode(LiteNetwork network) {
  488. LITE_CAPI_BEGIN();
  489. LITE_ASSERT(network, "The network pass to LITE api is null");
  490. std::shared_ptr<lite::Network> network_shared{
  491. static_cast<lite::Network*>(network), [](void*) {}};
  492. lite::Runtime::set_cpu_inplace_mode(network_shared);
  493. LITE_CAPI_END();
  494. }
  495. int LITE_use_tensorrt(LiteNetwork network) {
  496. LITE_CAPI_BEGIN();
  497. LITE_ASSERT(network, "The network pass to LITE api is null");
  498. std::shared_ptr<lite::Network> network_shared{
  499. static_cast<lite::Network*>(network), [](void*) {}};
  500. lite::Runtime::use_tensorrt(network_shared);
  501. LITE_CAPI_END();
  502. }
  503. int LITE_set_cpu_threads_number(LiteNetwork network, size_t nr_threads) {
  504. LITE_CAPI_BEGIN();
  505. LITE_ASSERT(network, "The network pass to LITE api is null");
  506. std::shared_ptr<lite::Network> network_shared{
  507. static_cast<lite::Network*>(network), [](void*) {}};
  508. lite::Runtime::set_cpu_threads_number(network_shared, nr_threads);
  509. LITE_CAPI_END();
  510. }
  511. int LITE_set_network_algo_policy(LiteNetwork network, LiteAlgoSelectStrategy strategy) {
  512. LITE_CAPI_BEGIN();
  513. LITE_ASSERT(network, "The network pass to LITE api is null");
  514. std::shared_ptr<lite::Network> network_shared{
  515. static_cast<lite::Network*>(network), [](void*) {}};
  516. lite::Runtime::set_network_algo_policy(network_shared, strategy);
  517. LITE_CAPI_END();
  518. }
  519. int LITE_set_network_algo_fastrun_config(
  520. LiteNetwork network, unsigned int shared_batch_size,
  521. int binary_equal_between_batch) {
  522. LITE_CAPI_BEGIN();
  523. LITE_ASSERT(network, "The network pass to LITE api is null");
  524. std::shared_ptr<lite::Network> network_shared{
  525. static_cast<lite::Network*>(network), [](void*) {}};
  526. lite::Runtime::set_network_algo_policy(
  527. network_shared, LiteAlgoSelectStrategy(0), shared_batch_size,
  528. binary_equal_between_batch);
  529. LITE_CAPI_END();
  530. }
  531. int LITE_set_network_algo_workspace_limit(LiteNetwork network, size_t workspace_limit) {
  532. LITE_CAPI_BEGIN();
  533. LITE_ASSERT(network, "The network pass to LITE api is null");
  534. std::shared_ptr<lite::Network> network_shared{
  535. static_cast<lite::Network*>(network), [](void*) {}};
  536. lite::Runtime::set_network_algo_workspace_limit(network_shared, workspace_limit);
  537. LITE_CAPI_END();
  538. }
  539. int LITE_set_runtime_thread_affinity(
  540. LiteNetwork network,
  541. const LiteThreadAffinityCallback thread_affinity_callback) {
  542. LITE_CAPI_BEGIN();
  543. LITE_ASSERT(network, "The network pass to LITE api is null");
  544. std::shared_ptr<lite::Network> network_shared{
  545. static_cast<lite::Network*>(network), [](void*) {}};
  546. lite::Runtime::set_runtime_thread_affinity(
  547. network_shared, std::move(thread_affinity_callback));
  548. LITE_CAPI_END();
  549. }
  550. int LITE_set_memory_allocator(
  551. LiteNetwork network, const LiteAllocate allocate_fun, const LiteFree free_fun) {
  552. LITE_CAPI_BEGIN();
  553. LITE_ASSERT(
  554. network && allocate_fun && free_fun, "The ptr pass to LITE api is null");
  555. std::shared_ptr<lite::Network> network_shared{
  556. static_cast<lite::Network*>(network), [](void*) {}};
  557. lite::Runtime::set_memory_allocator(
  558. network_shared, std::make_shared<UserAllocator>(allocate_fun, free_fun));
  559. LITE_CAPI_END();
  560. }
  561. int LITE_enable_io_txt_dump(LiteNetwork network, const char* io_txt_out_file) {
  562. LITE_CAPI_BEGIN();
  563. LITE_ASSERT(network, "The network pass to LITE api is null");
  564. std::shared_ptr<lite::Network> network_shared{
  565. static_cast<lite::Network*>(network), [](void*) {}};
  566. lite::Runtime::enable_io_txt_dump(network_shared, io_txt_out_file);
  567. LITE_CAPI_END();
  568. }
  569. int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out_dir) {
  570. LITE_CAPI_BEGIN();
  571. LITE_ASSERT(network, "The network pass to LITE api is null");
  572. std::shared_ptr<lite::Network> network_shared{
  573. static_cast<lite::Network*>(network), [](void*) {}};
  574. lite::Runtime::enable_io_bin_dump(network_shared, io_bin_out_dir);
  575. LITE_CAPI_END();
  576. }
  577. int LITE_shared_weight_with_network(
  578. LiteNetwork dst_network, const LiteNetwork src_network) {
  579. LITE_CAPI_BEGIN();
  580. LITE_ASSERT(dst_network && src_network, "The network pass to LITE api is null");
  581. const std::shared_ptr<lite::Network> src_shared_net{
  582. static_cast<lite::Network*>(src_network), [](void*) {}};
  583. std::shared_ptr<lite::Network> dst_shared_net{
  584. static_cast<lite::Network*>(dst_network), [](void*) {}};
  585. lite::Runtime::shared_weight_with_network(dst_shared_net, src_shared_net);
  586. LITE_CAPI_END();
  587. }
  588. int LITE_share_runtime_memroy(LiteNetwork dst_network, LiteNetwork src_network) {
  589. LITE_CAPI_BEGIN();
  590. LITE_ASSERT(src_network && dst_network, "The network pass to LITE api is null");
  591. std::shared_ptr<lite::Network> src_shared{
  592. static_cast<lite::Network*>(src_network), [](void*) {}};
  593. std::shared_ptr<lite::Network> dst_shared{
  594. static_cast<lite::Network*>(dst_network), [](void*) {}};
  595. lite::Runtime::share_runtime_memory_with(dst_shared, src_shared);
  596. LITE_CAPI_END();
  597. }
  598. int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir) {
  599. LITE_CAPI_BEGIN();
  600. #ifndef __IN_TEE_ENV__
  601. #if MGB_ENABLE_JSON
  602. LITE_ASSERT(network, "The network pass to LITE api is null");
  603. static_cast<lite::Network*>(network)->get_static_memory_alloc_info(log_dir);
  604. return 0;
  605. #endif
  606. #endif
  607. LITE_MARK_USED_VAR(network);
  608. LITE_MARK_USED_VAR(log_dir);
  609. LITE_THROW("Doesn't support get_static_memory_alloc_info().Please check macro.");
  610. LITE_CAPI_END();
  611. }
  612. int LITE_enable_global_layout_transform(LiteNetwork network) {
  613. LITE_CAPI_BEGIN();
  614. LITE_ASSERT(network, "The network pass to LITE api is null");
  615. std::shared_ptr<lite::Network> network_shared{
  616. static_cast<lite::Network*>(network), [](void*) {}};
  617. lite::Runtime::enable_global_layout_transform(network_shared);
  618. LITE_CAPI_END();
  619. }
  620. int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) {
  621. LITE_CAPI_BEGIN();
  622. LITE_ASSERT(network, "The network pass to LITE api is null");
  623. std::shared_ptr<lite::Network> network_shared{
  624. static_cast<lite::Network*>(network), [](void*) {}};
  625. lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
  626. LITE_CAPI_END();
  627. }
  628. namespace {
  629. static LITE_MUTEX mtx_io;
  630. static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
  631. static std::unordered_map<const void*, InnerIO> global_holder;
  632. return global_holder;
  633. }
  634. int write_ios_from_cpp_io(
  635. const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
  636. LITE_CAPI_BEGIN();
  637. LITE_LOCK_GUARD(mtx_io);
  638. get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
  639. auto&& inner_io = get_global_io_holder()[key];
  640. ios->input_size = inner_io.inputs.size();
  641. ios->output_size = inner_io.outputs.size();
  642. ios->inputs = inner_io.inputs.data();
  643. ios->outputs = inner_io.outputs.data();
  644. size_t i = 0;
  645. for (; i < ios->input_size; i++) {
  646. auto io_ptr = ios->inputs + i;
  647. io_ptr->name = inner_io.names[i].c_str();
  648. }
  649. for (; i < ios->output_size; i++) {
  650. auto io_ptr = ios->outputs + i;
  651. io_ptr->name = inner_io.names[i].c_str();
  652. }
  653. LITE_CAPI_END();
  654. }
  655. } // namespace
  656. int LITE_get_model_io_info_by_path(
  657. const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
  658. LITE_CAPI_BEGIN();
  659. LITE_ASSERT(model_path, "The model_path pass to LITE api is null");
  660. auto&& cpp_ios = lite::Runtime::get_model_io_info(
  661. std::string{model_path}, convert_to_lite_config(config));
  662. return write_ios_from_cpp_io(
  663. cpp_ios, ios, reinterpret_cast<const void*>(model_path));
  664. LITE_CAPI_END();
  665. }
  666. int LITE_get_model_io_info_by_memory(
  667. const void* model_mem, size_t size, const LiteConfig config,
  668. LiteNetworkIO* ios) {
  669. LITE_CAPI_BEGIN();
  670. LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null");
  671. auto&& cpp_ios = lite::Runtime::get_model_io_info(
  672. model_mem, size, convert_to_lite_config(config));
  673. return write_ios_from_cpp_io(
  674. cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
  675. LITE_CAPI_END();
  676. }
  677. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}