
network.cpp 29 kB

#include "lite/network.h"
#include "common.h"
#include "lite-c/network_c.h"
#include "../../src/network_impl_base.h"
#include <string.h>
#include <memory>
#include <mutex>
#include <unordered_map>
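
//! C wrapper of lite::Network. Every LITE_* entry point is bracketed by the
//! LITE_CAPI_BEGIN/LITE_CAPI_END macros (defined in common.h), so errors are
//! reported through the returned int status rather than by letting C++
//! exceptions escape across the C boundary.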
//! define a default Options
const LiteOptions default_option = {
        .weight_preprocess = false,
        .fuse_preprocess = false,
        .fake_next_exec = false,
        .var_sanity_check_first_run = true,
        .const_shape = false,
        .force_dynamic_alloc = false,
        .force_output_dynamic_alloc = false,
        .force_output_use_user_specified_memory = false,
        .no_profiling_on_shape_change = false,
        .jit_level = 0,
        .comp_node_seq_record_level = 0,
        .graph_opt_level = 2,
        .async_exec_level = 1,
        //! layout transform options
        .enable_nchw44 = 0,
        .enable_nchw44_dot = 0,
        .enable_nchw88 = 0,
        .enable_nhwcd4 = 0,
        .enable_nchw4 = 0,
        .enable_nchw32 = 0,
        .enable_nchw64 = 0,
};

//! define a default config
LiteConfig default_config_t = {
        .has_compression = false,
        .device_id = -1,
        .device_type = LiteDeviceType::LITE_CPU,
        .backend = LiteBackend::LITE_DEFAULT,
        .bare_model_cryption_name = nullptr,
        .options = default_option,
        .auto_optimize_inference = false};
LiteConfig* default_config() {
    return &default_config_t;
}

//! define a default IO
const LiteIO default_io = {
        .name = nullptr,
        .is_host = true,
        .io_type = LiteIOType::LITE_IO_VALUE,
        .config_layout = default_layout};

//! define a default NetworkIO
LiteNetworkIO default_network_io_t = {
        .inputs = nullptr, .outputs = nullptr, .input_size = 0, .output_size = 0};
LiteNetworkIO* default_network_io() {
    return &default_network_io_t;
}
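
//! The opaque LiteNetwork handle handed back to users is the raw lite::Network
//! pointer; the owning shared_ptr is kept alive in the global holder below,
//! guarded by mtx_network, until LITE_destroy_network erases it.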
namespace {
static LITE_MUTEX mtx_network;
std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_global_network_holder() {
    static std::unordered_map<void*, std::shared_ptr<lite::Network>> network_holder;
    return network_holder;
}

/*!
 * \brief A user-implemented allocator interface
 */
class UserAllocator : public lite::Allocator {
public:
    UserAllocator(LiteAllocate allocate_func, LiteFree free_func)
            : m_allocator(allocate_func), m_free(free_func) {
        LITE_ASSERT(m_allocator && m_free);
    }

    //! allocate memory of size in the given device with the given align
    void* allocate(LiteDeviceType device_type, int device_id, size_t size, size_t align)
            override {
        return m_allocator(device_type, device_id, size, align);
    }

    //! free the memory pointed by ptr in the given device
    void free(LiteDeviceType device_type, int device_id, void* ptr) override {
        m_free(device_type, device_id, ptr);
    }

private:
    LiteAllocate m_allocator;
    LiteFree m_free;
};
}  // namespace
//! convert c config to lite::config
lite::Config convert_to_lite_config(const LiteConfig c_config) {
    lite::Config lite_config;
    lite_config.device_type = c_config.device_type;
    if (c_config.bare_model_cryption_name) {
        lite_config.bare_model_cryption_name = c_config.bare_model_cryption_name;
    }
    lite_config.backend = c_config.backend;
    lite_config.has_compression = c_config.has_compression;
    lite_config.device_id = c_config.device_id;

    lite_config.options.weight_preprocess = c_config.options.weight_preprocess;
    lite_config.options.fuse_preprocess = c_config.options.fuse_preprocess;
    lite_config.options.fake_next_exec = c_config.options.fake_next_exec;
    lite_config.options.var_sanity_check_first_run =
            c_config.options.var_sanity_check_first_run;
    lite_config.options.const_shape = c_config.options.const_shape;
    lite_config.options.force_dynamic_alloc = c_config.options.force_dynamic_alloc;
    lite_config.options.force_output_use_user_specified_memory =
            c_config.options.force_output_use_user_specified_memory;
    lite_config.options.force_output_dynamic_alloc =
            c_config.options.force_output_dynamic_alloc;
    lite_config.options.no_profiling_on_shape_change =
            c_config.options.no_profiling_on_shape_change;
    lite_config.options.jit_level = c_config.options.jit_level;
    lite_config.options.comp_node_seq_record_level =
            c_config.options.comp_node_seq_record_level;
    lite_config.options.graph_opt_level = c_config.options.graph_opt_level;
    lite_config.options.async_exec_level = c_config.options.async_exec_level;

    lite_config.options.enable_nchw44 = c_config.options.enable_nchw44;
    lite_config.options.enable_nchw44_dot = c_config.options.enable_nchw44_dot;
    lite_config.options.enable_nchw88 = c_config.options.enable_nchw88;
    lite_config.options.enable_nchw4 = c_config.options.enable_nchw4;
    lite_config.options.enable_nhwcd4 = c_config.options.enable_nhwcd4;
    lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
    lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;

    lite_config.auto_optimize_inference = c_config.auto_optimize_inference;

    return lite_config;
}
//! convert C NetworkIO io to lite::NetworkIO
lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
    lite::NetworkIO network_io;
    for (size_t i = 0; i < c_network_io.input_size; i++) {
        LiteIO* c_io = c_network_io.inputs + i;
        LITE_ASSERT(c_io->name, "input name of io tensor must be set.");
        network_io.inputs.push_back(
                {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
                 convert_to_layout(c_io->config_layout)});
    }
    for (size_t i = 0; i < c_network_io.output_size; i++) {
        LiteIO* c_io = c_network_io.outputs + i;
        LITE_ASSERT(c_io->name, "output name of io tensor must be set.");
        network_io.outputs.push_back(
                {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
                 convert_to_layout(c_io->config_layout)});
    }
    return network_io;
}
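
//! InnerIO owns the name strings referenced by the LiteIO entries it produces:
//! each LiteIO::name points into `names`, so an InnerIO instance must outlive
//! any LiteNetworkIO built from it (see the global io holder used by the
//! LITE_get_model_io_info_* functions near the end of this file).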
struct InnerIO {
    std::vector<std::string> names;
    std::vector<LiteIO> inputs;
    std::vector<LiteIO> outputs;
};

InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
    InnerIO inner_io;
    for (size_t i = 0; i < network_io.inputs.size(); i++) {
        lite::IO io = network_io.inputs[i];
        inner_io.names.push_back(io.name);
        inner_io.inputs.push_back(
                {inner_io.names.back().c_str(), io.is_host, io.io_type,
                 convert_to_clayout(io.config_layout)});
    }
    for (size_t i = 0; i < network_io.outputs.size(); i++) {
        lite::IO io = network_io.outputs[i];
        inner_io.names.push_back(io.name);
        inner_io.outputs.push_back(
                {inner_io.names.back().c_str(), io.is_host, io.io_type,
                 convert_to_clayout(io.config_layout)});
    }
    return inner_io;
}
lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
    lite::ExtraConfig ret;
    ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
    return ret;
}
int LITE_make_default_network(LiteNetwork* network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_network = std::make_shared<lite::Network>();
    LITE_LOCK_GUARD(mtx_network);
    get_global_network_holder()[lite_network.get()] = lite_network;
    *network = lite_network.get();
    LITE_CAPI_END();
}

int LITE_make_network(
        LiteNetwork* network, const LiteConfig config, const LiteNetworkIO network_io) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_network = std::make_shared<lite::Network>(
            convert_to_lite_config(config), convert_to_lite_io(network_io));
    LITE_LOCK_GUARD(mtx_network);
    get_global_network_holder()[lite_network.get()] = lite_network;
    *network = lite_network.get();
    LITE_CAPI_END();
}

int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config));
    LITE_LOCK_GUARD(mtx_network);
    get_global_network_holder()[lite_network.get()] = lite_network;
    *network = lite_network.get();
    LITE_CAPI_END();
}

int LITE_load_model_from_mem(LiteNetwork network, void* model_mem, size_t size) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(model_mem, "The model memory pass to LITE api is null");
    static_cast<lite::Network*>(network)->load_model(model_mem, size);
    LITE_CAPI_END();
}

int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(model_path, "The model path pass to LITE api is null");
    static_cast<lite::Network*>(network)->load_model(model_path);
    LITE_CAPI_END();
}

int LITE_destroy_network(LiteNetwork network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_LOCK_GUARD(mtx_network);
    get_global_network_holder().erase(network);
    LITE_CAPI_END();
}

int LITE_forward(const LiteNetwork network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->forward();
    LITE_CAPI_END();
}

int LITE_wait(const LiteNetwork network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->wait();
    LITE_CAPI_END();
}

int LITE_get_io_tensor(
        LiteNetwork network, const char* io_name, LiteTensorPhase phase,
        LiteTensor* tensor) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto io_tensor =
            static_cast<lite::Network*>(network)->get_io_tensor(io_name, phase);
    *tensor = io_tensor.get();
    LITE_CAPI_END();
}
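
//! A minimal caller-side sketch of the C API defined above, not part of this
//! file's implementation. It assumes the usual convention that each LITE_*
//! entry point returns 0 on success (as the explicit `return 0;` in
//! LITE_get_static_memory_alloc_info suggests), and that the input phase
//! enumerator of LiteTensorPhase is spelled LITE_INPUT in the lite-c headers;
//! "model.mge" and the io name "data" are hypothetical:
//!
//!     LiteNetwork net;
//!     if (LITE_make_default_network(&net) != 0)
//!         return -1;
//!     LITE_load_model_from_path(net, "model.mge");
//!     LiteTensor input;
//!     LITE_get_io_tensor(net, "data", LITE_INPUT, &input);
//!     //! ... fill the input tensor, then run and wait for completion
//!     LITE_forward(net);
//!     LITE_wait(net);
//!     LITE_destroy_network(net);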
int LITE_get_input_name(const LiteNetwork network, size_t index, const char** name) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network && name, "The network pass to LITE api is null");
    *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
                    ->get_input_name(index);
    LITE_CAPI_END();
}

int LITE_get_output_name(const LiteNetwork network, size_t index, const char** name) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(name, "The name ptr pass to LITE api is null");
    *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
                    ->get_output_name(index);
    LITE_CAPI_END();
}

int LITE_get_all_input_name(
        const LiteNetwork network, size_t* size, const char** name) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto&& names = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
                           ->get_all_input_name();
    if (size)
        *size = names.size();
    if (name) {
        for (auto in_name : names) {
            *name = in_name;
            name++;
        }
    }
    LITE_CAPI_END();
}

int LITE_get_all_output_name(
        const LiteNetwork network, size_t* size, const char** name) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto&& names = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
                           ->get_all_output_name();
    if (size)
        *size = names.size();
    if (name) {
        for (auto in_name : names) {
            *name = in_name;
            name++;
        }
    }
    LITE_CAPI_END();
}
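
//! LITE_get_all_input_name/LITE_get_all_output_name support a two-call
//! pattern: both `size` and `name` may independently be null, so callers
//! usually query the count first and then pass a caller-allocated array of
//! const char*. A caller-side sketch (variable names are illustrative):
//!
//!     size_t n = 0;
//!     LITE_get_all_input_name(net, &n, nullptr);
//!     std::vector<const char*> names(n);
//!     LITE_get_all_input_name(net, nullptr, names.data());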
int LITE_set_device_id(LiteNetwork network, int device_id) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->set_device_id(device_id);
    LITE_CAPI_END();
}

int LITE_get_device_id(const LiteNetwork network, int* device_id) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(device_id, "The device_id pass to LITE api is null");
    *device_id = static_cast<lite::Network*>(network)->get_device_id();
    LITE_CAPI_END();
}

int LITE_set_stream_id(LiteNetwork network, int stream_id) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->set_stream_id(stream_id);
    LITE_CAPI_END();
}

int LITE_get_stream_id(const LiteNetwork network, int* stream_id) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(stream_id, "The stream_id pass to LITE api is null");
    *stream_id = static_cast<lite::Network*>(network)->get_stream_id();
    LITE_CAPI_END();
}
int LITE_get_model_extra_info(
        const LiteNetwork network, const char** info, int* info_size) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(info && info_size, "The info and info_size are all null");
    auto& extra_info = static_cast<lite::Network*>(network)->get_model_extra_info();
    *info_size = extra_info.size();
    *info = extra_info.c_str();
    LITE_CAPI_END();
}
int LITE_get_device_type(const LiteNetwork network, LiteDeviceType* device_type) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(device_type, "The device_type pass to LITE api is null");
    *device_type = static_cast<lite::Network*>(network)->get_device_type();
    LITE_CAPI_END();
}

int LITE_set_async_callback(
        LiteNetwork network, const LiteAsyncCallback async_callback) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
    static_cast<lite::Network*>(network)->set_async_callback(std::move(async_callback));
    LITE_CAPI_END();
}

int LITE_set_async_callback_with_userdata(
        LiteNetwork network, LiteAsyncCallbackWithData async_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
    auto lite_async_callback = [async_callback, user_data]() -> void {
        async_callback(user_data);
    };
    static_cast<lite::Network*>(network)->set_async_callback(
            std::move(lite_async_callback));
    LITE_CAPI_END();
}
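
//! Note: the *_with_userdata setters adapt a C function pointer plus an opaque
//! user_data pointer to the std::function-based C++ interface by capturing
//! both in a lambda, as above; the exact callback typedefs (return type and
//! parameters) are declared in lite-c/network_c.h. Only the raw user_data
//! pointer is captured, so the object it points to must stay alive until the
//! callback has fired.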
int LITE_set_start_callback(
        LiteNetwork network, const LiteStartCallback start_callback) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_start_callback =
            [start_callback](const std::unordered_map<
                             std::string,
                             std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
                                     inputs_map) -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        size_t nr_io = 0;
        for (const auto& io : inputs_map) {
            nr_io++;
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        start_callback(ios.data(), io_tensors.data(), nr_io);
    };
    static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
    LITE_CAPI_END();
}
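
//! The start/finish callback wrappers below all follow the same shape: the C++
//! side hands over a map from io name to (lite::IO, tensor), which is
//! flattened into two parallel arrays (LiteIO descriptors and LiteTensor
//! handles) of length nr_io before the C callback is invoked. The arrays live
//! on the lambda's stack, so they are only valid for the duration of the call.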
int LITE_set_start_callback_with_userdata(
        LiteNetwork network, const LiteStartCallbackWithData start_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_start_callback =
            [start_callback,
             user_data](const std::unordered_map<
                        std::string,
                        std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& inputs_map)
            -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        size_t nr_io = 0;
        for (const auto& io : inputs_map) {
            nr_io++;
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        start_callback(ios.data(), io_tensors.data(), nr_io, user_data);
    };
    static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
    LITE_CAPI_END();
}

int LITE_set_finish_callback(
        LiteNetwork network, const LiteFinishCallback finish_callback) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_finish_callback =
            [finish_callback](const std::unordered_map<
                              std::string,
                              std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
                                      outputs_map) -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        size_t nr_io = 0;
        for (const auto& io : outputs_map) {
            nr_io++;
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        finish_callback(ios.data(), io_tensors.data(), nr_io);
    };
    static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
    LITE_CAPI_END();
}

int LITE_set_finish_callback_with_userdata(
        LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
        void* user_data) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    auto lite_finish_callback =
            [finish_callback,
             user_data](const std::unordered_map<
                        std::string,
                        std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
                                outputs_map) -> void {
        std::vector<LiteIO> ios;
        std::vector<LiteTensor> io_tensors;
        size_t nr_io = 0;
        for (const auto& io : outputs_map) {
            nr_io++;
            auto&& lite_io = io.second.first;
            ios.push_back(
                    {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
                     convert_to_clayout(lite_io.config_layout)});
            io_tensors.push_back(io.second.second.get());
        }
        finish_callback(ios.data(), io_tensors.data(), nr_io, user_data);
    };
    static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
    LITE_CAPI_END();
}
int LITE_enable_profile_performance(
        LiteNetwork network, const char* profile_json_file_path) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->enable_profile_performance(
            profile_json_file_path);
    LITE_CAPI_END();
}

int LITE_is_cpu_inplace_mode(const LiteNetwork network, int* is_cpu_inplace_mode) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network && is_cpu_inplace_mode, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    *is_cpu_inplace_mode = lite::Runtime::is_cpu_inplace_mode(network_shared);
    LITE_CAPI_END();
}
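
//! The lite::Runtime wrappers in this file all build a temporary
//! std::shared_ptr<lite::Network> around the raw handle with a no-op deleter:
//! Runtime's interface takes a shared_ptr, but ownership must stay with the
//! global network holder filled in by LITE_make_*, so the temporary never
//! frees the network when it goes out of scope.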
int LITE_get_cpu_threads_number(const LiteNetwork network, size_t* nr_threads) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    LITE_ASSERT(nr_threads, "The ptr pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    *nr_threads = lite::Runtime::get_cpu_threads_number(network_shared);
    LITE_CAPI_END();
}

int LITE_set_cpu_inplace_mode(LiteNetwork network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_cpu_inplace_mode(network_shared);
    LITE_CAPI_END();
}

int LITE_use_tensorrt(LiteNetwork network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::use_tensorrt(network_shared);
    LITE_CAPI_END();
}

int LITE_set_cpu_threads_number(LiteNetwork network, size_t nr_threads) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_cpu_threads_number(network_shared, nr_threads);
    LITE_CAPI_END();
}

int LITE_set_network_algo_policy(LiteNetwork network, LiteAlgoSelectStrategy strategy) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_network_algo_policy(network_shared, strategy);
    LITE_CAPI_END();
}

int LITE_set_network_algo_fastrun_config(
        LiteNetwork network, unsigned int shared_batch_size,
        int binary_equal_between_batch) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_network_algo_policy(
            network_shared, LiteAlgoSelectStrategy(0), shared_batch_size,
            binary_equal_between_batch);
    LITE_CAPI_END();
}

int LITE_set_network_algo_workspace_limit(LiteNetwork network, size_t workspace_limit) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_network_algo_workspace_limit(network_shared, workspace_limit);
    LITE_CAPI_END();
}

int LITE_set_runtime_thread_affinity(
        LiteNetwork network,
        const LiteThreadAffinityCallback thread_affinity_callback) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_runtime_thread_affinity(
            network_shared, std::move(thread_affinity_callback));
    LITE_CAPI_END();
}

int LITE_set_memory_allocator(
        LiteNetwork network, const LiteAllocate allocate_fun, const LiteFree free_fun) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(
            network && allocate_fun && free_fun, "The ptr pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::set_memory_allocator(
            network_shared, std::make_shared<UserAllocator>(allocate_fun, free_fun));
    LITE_CAPI_END();
}
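
//! LITE_set_memory_allocator plugs a user allocate/free pair into the runtime
//! through the UserAllocator adapter defined at the top of this file. A
//! hypothetical CPU-only pair, assuming the LiteAllocate/LiteFree typedefs
//! take the same arguments UserAllocator forwards to them (see the lite-c
//! headers for the exact declarations); my_alloc/my_free are illustrative:
//!
//!     void* my_alloc(LiteDeviceType dev, int dev_id, size_t size, size_t align) {
//!         return ::aligned_alloc(align, size);  //! size must be a multiple of align
//!     }
//!     void my_free(LiteDeviceType dev, int dev_id, void* ptr) { ::free(ptr); }
//!     //! ...
//!     LITE_set_memory_allocator(net, my_alloc, my_free);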
int LITE_enable_io_txt_dump(LiteNetwork network, const char* io_txt_out_file) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::enable_io_txt_dump(network_shared, io_txt_out_file);
    LITE_CAPI_END();
}

int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out_dir) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::enable_io_bin_dump(network_shared, io_bin_out_dir);
    LITE_CAPI_END();
}

int LITE_shared_weight_with_network(
        LiteNetwork dst_network, const LiteNetwork src_network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(dst_network && src_network, "The network pass to LITE api is null");
    const std::shared_ptr<lite::Network> src_shared_net{
            static_cast<lite::Network*>(src_network), [](void*) {}};
    std::shared_ptr<lite::Network> dst_shared_net{
            static_cast<lite::Network*>(dst_network), [](void*) {}};
    lite::Runtime::shared_weight_with_network(dst_shared_net, src_shared_net);
    LITE_CAPI_END();
}

int LITE_share_runtime_memroy(LiteNetwork dst_network, LiteNetwork src_network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(src_network && dst_network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> src_shared{
            static_cast<lite::Network*>(src_network), [](void*) {}};
    std::shared_ptr<lite::Network> dst_shared{
            static_cast<lite::Network*>(dst_network), [](void*) {}};
    lite::Runtime::share_runtime_memory_with(dst_shared, src_shared);
    LITE_CAPI_END();
}
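
//! Only available in builds with MGB_ENABLE_JSON and outside the TEE
//! environment; otherwise the call falls through to the LITE_THROW below.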
int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir) {
    LITE_CAPI_BEGIN();
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->get_static_memory_alloc_info(log_dir);
    return 0;
#endif
#endif
    LITE_MARK_USED_VAR(network);
    LITE_MARK_USED_VAR(log_dir);
    LITE_THROW("Doesn't support get_static_memory_alloc_info(). Please check macro.");
    LITE_CAPI_END();
}
int LITE_enable_global_layout_transform(LiteNetwork network) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::enable_global_layout_transform(network_shared);
    LITE_CAPI_END();
}

int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    std::shared_ptr<lite::Network> network_shared{
            static_cast<lite::Network*>(network), [](void*) {}};
    lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
    LITE_CAPI_END();
}

namespace {
static LITE_MUTEX mtx_io;
static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
    static std::unordered_map<const void*, InnerIO> global_holder;
    return global_holder;
}
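
//! write_ios_from_cpp_io converts the C++ NetworkIO into an InnerIO stored in
//! the global holder above (keyed by the model path / memory pointer), then
//! points the returned LiteNetworkIO arrays and name strings at that stored
//! copy. Keeping the InnerIO alive in the holder is what keeps those pointers
//! valid after the call returns.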
int write_ios_from_cpp_io(
        const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
    LITE_CAPI_BEGIN();
    LITE_LOCK_GUARD(mtx_io);
    get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
    auto&& inner_io = get_global_io_holder()[key];
    ios->input_size = inner_io.inputs.size();
    ios->output_size = inner_io.outputs.size();
    ios->inputs = inner_io.inputs.data();
    ios->outputs = inner_io.outputs.data();
    //! re-point the io names at the strings owned by the stored InnerIO copy;
    //! `names` holds all input names followed by all output names
    size_t i = 0;
    for (; i < ios->input_size; i++) {
        auto io_ptr = ios->inputs + i;
        io_ptr->name = inner_io.names[i].c_str();
    }
    for (; i < ios->input_size + ios->output_size; i++) {
        auto io_ptr = ios->outputs + (i - ios->input_size);
        io_ptr->name = inner_io.names[i].c_str();
    }
    LITE_CAPI_END();
}
}  // namespace
int LITE_get_model_io_info_by_path(
        const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(model_path, "The model_path pass to LITE api is null");
    auto&& cpp_ios = lite::Runtime::get_model_io_info(
            std::string{model_path}, convert_to_lite_config(config));
    return write_ios_from_cpp_io(
            cpp_ios, ios, reinterpret_cast<const void*>(model_path));
    LITE_CAPI_END();
}

int LITE_get_model_io_info_by_memory(
        const void* model_mem, size_t size, const LiteConfig config,
        LiteNetworkIO* ios) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null");
    auto&& cpp_ios = lite::Runtime::get_model_io_info(
            model_mem, size, convert_to_lite_config(config));
    return write_ios_from_cpp_io(
            cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
    LITE_CAPI_END();
}
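
//! A hypothetical caller-side sketch for the io-info queries above, assuming
//! the 0-on-success convention, that *default_config() is an acceptable
//! configuration, and that "model.mge" exists; the returned name pointers stay
//! valid because the InnerIO is retained in the global io holder:
//!
//!     LiteNetworkIO ios = *default_network_io();
//!     if (LITE_get_model_io_info_by_path("model.mge", *default_config(), &ios) == 0) {
//!         for (size_t i = 0; i < ios.input_size; ++i)
//!             printf("input %zu: %s\n", i, ios.inputs[i].name);
//!     }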
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(network, "The network pass to LITE api is null");
    static_cast<lite::Network*>(network)->extra_configure(
            convert_extra_config(extra_config));
    LITE_CAPI_END();
}
  681. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}