You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

network.cpp 30 kB


  1. #include "lite/network.h"
  2. #include "common.h"
  3. #include "lite-c/network_c.h"
  4. #include "../../src/network_impl_base.h"
  5. #include <string.h>
  6. #include <memory>
  7. #include <mutex>
  8. #include <unordered_map>
  9. //! define a default Options
  10. const LiteOptions default_option = {
  11. .weight_preprocess = false,
  12. .fuse_preprocess = false,
  13. .fake_next_exec = false,
  14. .var_sanity_check_first_run = true,
  15. .const_shape = false,
  16. .force_dynamic_alloc = false,
  17. .force_output_dynamic_alloc = false,
  18. .force_output_use_user_specified_memory = false,
  19. .no_profiling_on_shape_change = false,
  20. .jit_level = 0,
  21. .comp_node_seq_record_level = 0,
  22. .graph_opt_level = 2,
  23. .async_exec_level = 1,
  24. //! layout transform options
  25. .enable_nchw44 = 0,
  26. .enable_nchw44_dot = 0,
  27. .enable_nchw88 = 0,
  28. .enable_nhwcd4 = 0,
  29. .enable_nchw4 = 0,
  30. .enable_nchw32 = 0,
  31. .enable_nchw64 = 0,
  32. };
  33. //! define a default config
  34. LiteConfig default_config_t = {
  35. .has_compression = false,
  36. .device_id = -1,
  37. .device_type = LiteDeviceType::LITE_CPU,
  38. .backend = LiteBackend::LITE_DEFAULT,
  39. .bare_model_cryption_name = nullptr,
  40. .options = default_option,
  41. .auto_optimize_inference = false,
  42. .discrete_input_name = nullptr};
  43. LiteConfig* default_config() {
  44. return &default_config_t;
  45. }
  46. //! define a default IO
  47. const LiteIO default_io = {
  48. .name = nullptr,
  49. .is_host = true,
  50. .io_type = LiteIOType::LITE_IO_VALUE,
  51. .config_layout = default_layout};
  52. //! define a default NetworkIO
  53. LiteNetworkIO default_network_io_t = {
  54. .inputs = nullptr, .outputs = nullptr, .input_size = 0, .output_size = 0};
  55. LiteNetworkIO* default_network_io() {
  56. return &default_network_io_t;
  57. }
  58. namespace {
  59. static LITE_MUTEX mtx_network;
  60. std::unordered_map<void*, std::shared_ptr<lite::Network>>& get_gloabl_network_holder() {
  61. static std::unordered_map<void*, std::shared_ptr<lite::Network>> network_holder;
  62. return network_holder;
  63. }
  64. /*!
  65. * \brief A user-implemented allocator interface
  66. */
  67. class UserAllocator : public lite::Allocator {
  68. public:
  69. UserAllocator(LiteAllocate allocate_func, LiteFree free_func)
  70. : m_allocator(allocate_func), m_free(free_func) {
  71. LITE_ASSERT(m_allocator && m_free);
  72. }
  73. //! allocate memory of size in the given device with the given align
  74. void* allocate(LiteDeviceType device_type, int device_id, size_t size, size_t align)
  75. override {
  76. return m_allocator(device_type, device_id, size, align);
  77. }
  78. //! free the memory pointed by ptr in the given device
  79. void free(LiteDeviceType device_type, int device_id, void* ptr) override {
  80. m_free(device_type, device_id, ptr);
  81. }
  82. private:
  83. LiteAllocate m_allocator;
  84. LiteFree m_free;
  85. };
  86. } // namespace
  87. //! convert c config to lite::config
  88. lite::Config convert_to_lite_config(const LiteConfig c_config) {
  89. lite::Config lite_config;
  90. lite_config.device_type = c_config.device_type;
  91. if (c_config.bare_model_cryption_name) {
  92. lite_config.bare_model_cryption_name = c_config.bare_model_cryption_name;
  93. }
  94. lite_config.backend = c_config.backend;
  95. lite_config.has_compression = c_config.has_compression;
  96. lite_config.device_id = c_config.device_id;
  97. lite_config.options.weight_preprocess = c_config.options.weight_preprocess;
  98. lite_config.options.fuse_preprocess = c_config.options.fuse_preprocess;
  99. lite_config.options.fake_next_exec = c_config.options.fake_next_exec;
  100. lite_config.options.var_sanity_check_first_run =
  101. c_config.options.var_sanity_check_first_run;
  102. lite_config.options.const_shape = c_config.options.const_shape;
  103. lite_config.options.force_dynamic_alloc = c_config.options.force_dynamic_alloc;
  104. lite_config.options.force_output_use_user_specified_memory =
  105. c_config.options.force_output_use_user_specified_memory;
  106. lite_config.options.force_output_dynamic_alloc =
  107. c_config.options.force_output_dynamic_alloc;
  108. lite_config.options.no_profiling_on_shape_change =
  109. c_config.options.no_profiling_on_shape_change;
  110. lite_config.options.jit_level = c_config.options.jit_level;
  111. lite_config.options.comp_node_seq_record_level =
  112. c_config.options.comp_node_seq_record_level;
  113. lite_config.options.graph_opt_level = c_config.options.graph_opt_level;
  114. lite_config.options.async_exec_level = c_config.options.async_exec_level;
  115. lite_config.options.enable_nchw44 = c_config.options.enable_nchw44;
  116. lite_config.options.enable_nchw44_dot = c_config.options.enable_nchw44_dot;
  117. lite_config.options.enable_nchw88 = c_config.options.enable_nchw88;
  118. lite_config.options.enable_nchw4 = c_config.options.enable_nchw4;
  119. lite_config.options.enable_nhwcd4 = c_config.options.enable_nhwcd4;
  120. lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
  121. lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
  122. lite_config.auto_optimize_inference = c_config.auto_optimize_inference;
  123. if (c_config.discrete_input_name) {
  124. lite_config.discrete_input_name = c_config.discrete_input_name;
  125. }
  126. return lite_config;
  127. }
  128. //! convert C NetworkIO io to lite::NetworkIO
  129. lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
  130. lite::NetworkIO network_io;
  131. for (size_t i = 0; i < c_network_io.input_size; i++) {
  132. LiteIO* c_io = c_network_io.inputs + i;
  133. LITE_ASSERT(c_io->name, "input name of io tensor must set.");
  134. network_io.inputs.push_back(
  135. {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
  136. convert_to_layout(c_io->config_layout)});
  137. }
  138. for (size_t i = 0; i < c_network_io.output_size; i++) {
  139. LiteIO* c_io = c_network_io.outputs + i;
  140. LITE_ASSERT(c_io->name, "output name of io tensor must set.");
  141. network_io.outputs.push_back(
  142. {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
  143. convert_to_layout(c_io->config_layout)});
  144. }
  145. return network_io;
  146. }
  147. struct InnerIO {
  148. std::vector<std::string> names;
  149. std::vector<LiteIO> inputs;
  150. std::vector<LiteIO> outputs;
  151. };
  152. InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
  153. InnerIO innner_io;
  154. for (size_t i = 0; i < network_io.inputs.size(); i++) {
  155. lite::IO io = network_io.inputs[i];
  156. innner_io.names.push_back(io.name);
  157. innner_io.inputs.push_back(
  158. {innner_io.names.back().c_str(), io.is_host, io.io_type,
  159. convert_to_clayout(io.config_layout)});
  160. }
  161. for (size_t i = 0; i < network_io.outputs.size(); i++) {
  162. lite::IO io = network_io.outputs[i];
  163. innner_io.names.push_back(io.name);
  164. innner_io.outputs.push_back(
  165. {innner_io.names.back().c_str(), io.is_host, io.io_type,
  166. convert_to_clayout(io.config_layout)});
  167. }
  168. return innner_io;
  169. }
  170. lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
  171. lite::ExtraConfig ret;
  172. ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
  173. return ret;
  174. }
  175. int LITE_make_default_network(LiteNetwork* network) {
  176. LITE_CAPI_BEGIN();
  177. LITE_ASSERT(network, "The network pass to LITE api is null");
  178. auto lite_network = std::make_shared<lite::Network>();
  179. LITE_LOCK_GUARD(mtx_network);
  180. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  181. *network = lite_network.get();
  182. LITE_CAPI_END();
  183. }
  184. int LITE_make_network(
  185. LiteNetwork* network, const LiteConfig config, const LiteNetworkIO network_io) {
  186. LITE_CAPI_BEGIN();
  187. LITE_ASSERT(network, "The network pass to LITE api is null");
  188. auto lite_network = std::make_shared<lite::Network>(
  189. convert_to_lite_config(config), convert_to_lite_io(network_io));
  190. LITE_LOCK_GUARD(mtx_network);
  191. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  192. *network = lite_network.get();
  193. LITE_CAPI_END();
  194. }
  195. int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
  196. LITE_CAPI_BEGIN();
  197. LITE_ASSERT(network, "The network pass to LITE api is null");
  198. auto lite_network = std::make_shared<lite::Network>(convert_to_lite_config(config));
  199. LITE_LOCK_GUARD(mtx_network);
  200. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  201. *network = lite_network.get();
  202. LITE_CAPI_END();
  203. }
  204. int LITE_load_model_from_mem(LiteNetwork network, void* model_mem, size_t size) {
  205. LITE_CAPI_BEGIN();
  206. LITE_ASSERT(network, "The network pass to LITE api is null");
  207. LITE_ASSERT(model_mem, "The model memory pass to LITE api is null");
  208. static_cast<lite::Network*>(network)->load_model(model_mem, size);
  209. LITE_CAPI_END();
  210. }
  211. int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
  212. LITE_CAPI_BEGIN();
  213. LITE_ASSERT(network, "The network pass to LITE api is null");
  214. LITE_ASSERT(model_path, "The model path pass to LITE api is null");
  215. static_cast<lite::Network*>(network)->load_model(model_path);
  216. LITE_CAPI_END();
  217. }
  218. int LITE_destroy_network(LiteNetwork network) {
  219. LITE_CAPI_BEGIN();
  220. LITE_ASSERT(network, "The network pass to LITE api is null");
  221. LITE_LOCK_GUARD(mtx_network);
  222. auto& global_holder = get_gloabl_network_holder();
  223. if (global_holder.find(network) != global_holder.end()) {
  224. global_holder.erase(network);
  225. }
  226. LITE_CAPI_END();
  227. }
  228. int LITE_forward(const LiteNetwork network) {
  229. LITE_CAPI_BEGIN();
  230. LITE_ASSERT(network, "The network pass to LITE api is null");
  231. static_cast<lite::Network*>(network)->forward();
  232. LITE_CAPI_END();
  233. }
  234. int LITE_wait(const LiteNetwork network) {
  235. LITE_CAPI_BEGIN();
  236. LITE_ASSERT(network, "The network pass to LITE api is null");
  237. static_cast<lite::Network*>(network)->wait();
  238. LITE_CAPI_END();
  239. }
  240. int LITE_get_io_tensor(
  241. LiteNetwork network, const char* io_name, LiteTensorPhase phase,
  242. LiteTensor* tensor) {
  243. LITE_CAPI_BEGIN();
  244. LITE_ASSERT(network, "The network pass to LITE api is null");
  245. auto io_tensor =
  246. static_cast<lite::Network*>(network)->get_io_tensor(io_name, phase);
  247. *tensor = io_tensor.get();
  248. LITE_CAPI_END();
  249. }
  250. int LITE_get_discrete_tensor(
  251. LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase,
  252. LiteTensor* tensor) {
  253. LITE_CAPI_BEGIN();
  254. LITE_ASSERT(network, "The network pass to LITE api is null");
  255. auto io_tensors =
  256. static_cast<lite::Network*>(network)->get_discrete_tensors(io_name, phase);
  257. LITE_ASSERT(
  258. n_idx < io_tensors.size(), "n_idx should be less than %zu",
  259. io_tensors.size());
  260. *tensor = io_tensors[n_idx].get();
  261. LITE_CAPI_END();
  262. }
  263. int LITE_get_input_name(const LiteNetwork network, size_t index, const char** name) {
  264. LITE_CAPI_BEGIN();
  265. LITE_ASSERT(network && name, "The network pass to LITE api is null");
  266. *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  267. ->get_input_name(index);
  268. LITE_CAPI_END();
  269. }
  270. int LITE_get_output_name(const LiteNetwork network, size_t index, const char** name) {
  271. LITE_CAPI_BEGIN();
  272. LITE_ASSERT(network, "The network pass to LITE api is null");
  273. LITE_ASSERT(name, "The name ptr pass to LITE api is null");
  274. *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  275. ->get_output_name(index);
  276. LITE_CAPI_END();
  277. }
  278. int LITE_get_all_input_name(
  279. const LiteNetwork network, size_t* size, const char** name) {
  280. LITE_CAPI_BEGIN();
  281. LITE_ASSERT(network, "The network pass to LITE api is null");
  282. auto&& names = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  283. ->get_all_input_name();
  284. if (size)
  285. *size = names.size();
  286. if (name) {
  287. for (auto in_name : names) {
  288. *name = in_name;
  289. name++;
  290. }
  291. }
  292. LITE_CAPI_END();
  293. }
  294. int LITE_get_all_output_name(
  295. const LiteNetwork network, size_t* size, const char** name) {
  296. LITE_CAPI_BEGIN();
  297. LITE_ASSERT(network, "The network pass to LITE api is null");
  298. auto&& names = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  299. ->get_all_output_name();
  300. if (size)
  301. *size = names.size();
  302. if (name) {
  303. for (auto in_name : names) {
  304. *name = in_name;
  305. name++;
  306. }
  307. }
  308. LITE_CAPI_END();
  309. }
  310. int LITE_set_device_id(LiteNetwork network, int device_id) {
  311. LITE_CAPI_BEGIN();
  312. LITE_ASSERT(network, "The network pass to LITE api is null");
  313. static_cast<lite::Network*>(network)->set_device_id(device_id);
  314. LITE_CAPI_END();
  315. }
  316. int LITE_get_device_id(const LiteNetwork network, int* device_id) {
  317. LITE_CAPI_BEGIN();
  318. LITE_ASSERT(network, "The network pass to LITE api is null");
  319. LITE_ASSERT(device_id, "The device_id pass to LITE api is null");
  320. *device_id = static_cast<lite::Network*>(network)->get_device_id();
  321. LITE_CAPI_END();
  322. }
  323. int LITE_set_stream_id(LiteNetwork network, int stream_id) {
  324. LITE_CAPI_BEGIN();
  325. LITE_ASSERT(network, "The network pass to LITE api is null");
  326. static_cast<lite::Network*>(network)->set_stream_id(stream_id);
  327. LITE_CAPI_END();
  328. }
  329. int LITE_get_stream_id(const LiteNetwork network, int* stream_id) {
  330. LITE_CAPI_BEGIN();
  331. LITE_ASSERT(network, "The network pass to LITE api is null");
  332. LITE_ASSERT(stream_id, "The stream_id pass to LITE api is null");
  333. *stream_id = static_cast<lite::Network*>(network)->get_stream_id();
  334. LITE_CAPI_END();
  335. }
  336. int LITE_get_model_extra_info(
  337. const LiteNetwork network, const char** info, int* info_size) {
  338. LITE_CAPI_BEGIN();
  339. LITE_ASSERT(network, "The network pass to LITE api is null");
  340. LITE_ASSERT(info_size, "The info and info_size are all null");
  341. auto& extra_info = static_cast<lite::Network*>(network)->get_model_extra_info();
  342. *info_size = extra_info.size();
  343. *info = extra_info.c_str();
  344. LITE_MARK_USED_VAR(info);
  345. LITE_CAPI_END();
  346. }
  347. int LITE_get_device_type(const LiteNetwork network, LiteDeviceType* device_type) {
  348. LITE_CAPI_BEGIN();
  349. LITE_ASSERT(network, "The network pass to LITE api is null");
  350. LITE_ASSERT(device_type, "The device_type pass to LITE api is null");
  351. *device_type = static_cast<lite::Network*>(network)->get_device_type();
  352. LITE_CAPI_END();
  353. }
  354. int LITE_set_async_callback(
  355. LiteNetwork network, const LiteAsyncCallback async_callback) {
  356. LITE_CAPI_BEGIN();
  357. LITE_ASSERT(network, "The network pass to LITE api is null");
  358. LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
  359. static_cast<lite::Network*>(network)->set_async_callback(std::move(async_callback));
  360. LITE_CAPI_END();
  361. }
  362. int LITE_set_async_callback_with_userdata(
  363. LiteNetwork network, LiteAsyncCallbackWithData async_callback,
  364. void* user_data) {
  365. LITE_CAPI_BEGIN();
  366. LITE_ASSERT(network, "The network pass to LITE api is null");
  367. LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
  368. auto lite_async_callback = [async_callback, user_data]() -> void {
  369. async_callback(user_data);
  370. };
  371. static_cast<lite::Network*>(network)->set_async_callback(
  372. std::move(lite_async_callback));
  373. LITE_CAPI_END();
  374. }
  375. int LITE_set_start_callback(
  376. LiteNetwork network, const LiteStartCallback start_callback) {
  377. LITE_CAPI_BEGIN();
  378. LITE_ASSERT(network, "The network pass to LITE api is null");
  379. auto lite_start_callback =
  380. [start_callback](const std::unordered_map<
  381. std::string,
  382. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  383. inputs_map) -> void {
  384. std::vector<LiteIO> ios;
  385. std::vector<LiteTensor> io_tensors;
  386. size_t nr_io = 0;
  387. for (const auto& io : inputs_map) {
  388. nr_io++;
  389. auto&& lite_io = io.second.first;
  390. ios.push_back(
  391. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  392. convert_to_clayout(lite_io.config_layout)});
  393. io_tensors.push_back(io.second.second.get());
  394. }
  395. start_callback(ios.data(), io_tensors.data(), nr_io);
  396. };
  397. static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
  398. LITE_CAPI_END();
  399. }
  400. int LITE_set_start_callback_with_userdata(
  401. LiteNetwork network, const LiteStartCallbackWithData start_callback,
  402. void* user_data) {
  403. LITE_CAPI_BEGIN();
  404. LITE_ASSERT(network, "The network pass to LITE api is null");
  405. auto lite_start_callback =
  406. [start_callback,
  407. user_data](const std::unordered_map<
  408. std::string,
  409. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& inputs_map)
  410. -> void {
  411. std::vector<LiteIO> ios;
  412. std::vector<LiteTensor> io_tensors;
  413. size_t nr_io = 0;
  414. for (const auto& io : inputs_map) {
  415. nr_io++;
  416. auto&& lite_io = io.second.first;
  417. ios.push_back(
  418. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  419. convert_to_clayout(lite_io.config_layout)});
  420. io_tensors.push_back(io.second.second.get());
  421. }
  422. start_callback(ios.data(), io_tensors.data(), nr_io, user_data);
  423. };
  424. static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback);
  425. LITE_CAPI_END();
  426. }
  427. int LITE_set_finish_callback(
  428. LiteNetwork network, const LiteFinishCallback finish_callback) {
  429. LITE_CAPI_BEGIN();
  430. LITE_ASSERT(network, "The network pass to LITE api is null");
  431. auto lite_finish_callback =
  432. [finish_callback](const std::unordered_map<
  433. std::string,
  434. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  435. outputs_map) -> void {
  436. std::vector<LiteIO> ios;
  437. std::vector<LiteTensor> io_tensors;
  438. size_t nr_io = 0;
  439. for (const auto& io : outputs_map) {
  440. nr_io++;
  441. auto&& lite_io = io.second.first;
  442. ios.push_back(
  443. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  444. convert_to_clayout(lite_io.config_layout)});
  445. io_tensors.push_back(io.second.second.get());
  446. }
  447. finish_callback(ios.data(), io_tensors.data(), nr_io);
  448. };
  449. static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
  450. LITE_CAPI_END();
  451. }
  452. int LITE_set_finish_callback_with_userdata(
  453. LiteNetwork network, const LiteFinishCallbackWithData finish_callback,
  454. void* user_data) {
  455. LITE_CAPI_BEGIN();
  456. LITE_ASSERT(network, "The network pass to LITE api is null");
  457. auto lite_finish_callback =
  458. [finish_callback,
  459. user_data](const std::unordered_map<
  460. std::string,
  461. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  462. outputs_map) -> void {
  463. std::vector<LiteIO> ios;
  464. std::vector<LiteTensor> io_tensors;
  465. size_t nr_io = 0;
  466. for (const auto& io : outputs_map) {
  467. nr_io++;
  468. auto&& lite_io = io.second.first;
  469. ios.push_back(
  470. {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type,
  471. convert_to_clayout(lite_io.config_layout)});
  472. io_tensors.push_back(io.second.second.get());
  473. }
  474. finish_callback(ios.data(), io_tensors.data(), nr_io, user_data);
  475. };
  476. static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback);
  477. LITE_CAPI_END();
  478. }
  479. int LITE_enable_profile_performance(
  480. LiteNetwork network, const char* profile_json_file_path) {
  481. LITE_CAPI_BEGIN();
  482. LITE_ASSERT(network, "The network pass to LITE api is null");
  483. static_cast<lite::Network*>(network)->enable_profile_performance(
  484. profile_json_file_path);
  485. LITE_CAPI_END();
  486. }
  487. int LITE_is_cpu_inplace_mode(const LiteNetwork network, int* is_cpu_inplace_mode) {
  488. LITE_CAPI_BEGIN();
  489. LITE_ASSERT(network && is_cpu_inplace_mode, "The network pass to LITE api is null");
  490. std::shared_ptr<lite::Network> network_shared{
  491. static_cast<lite::Network*>(network), [](void*) {}};
  492. *is_cpu_inplace_mode = lite::Runtime::is_cpu_inplace_mode(network_shared);
  493. LITE_CAPI_END();
  494. }
  495. int LITE_get_cpu_threads_number(const LiteNetwork network, size_t* nr_threads) {
  496. LITE_CAPI_BEGIN();
  497. LITE_ASSERT(network, "The network pass to LITE api is null");
  498. LITE_ASSERT(nr_threads, "The ptr pass to LITE api is null");
  499. std::shared_ptr<lite::Network> network_shared{
  500. static_cast<lite::Network*>(network), [](void*) {}};
  501. *nr_threads = lite::Runtime::get_cpu_threads_number(network_shared);
  502. LITE_CAPI_END();
  503. }
  504. int LITE_set_cpu_inplace_mode(LiteNetwork network) {
  505. LITE_CAPI_BEGIN();
  506. LITE_ASSERT(network, "The network pass to LITE api is null");
  507. std::shared_ptr<lite::Network> network_shared{
  508. static_cast<lite::Network*>(network), [](void*) {}};
  509. lite::Runtime::set_cpu_inplace_mode(network_shared);
  510. LITE_CAPI_END();
  511. }
  512. int LITE_use_tensorrt(LiteNetwork network) {
  513. LITE_CAPI_BEGIN();
  514. LITE_ASSERT(network, "The network pass to LITE api is null");
  515. std::shared_ptr<lite::Network> network_shared{
  516. static_cast<lite::Network*>(network), [](void*) {}};
  517. lite::Runtime::use_tensorrt(network_shared);
  518. LITE_CAPI_END();
  519. }
  520. int LITE_set_cpu_threads_number(LiteNetwork network, size_t nr_threads) {
  521. LITE_CAPI_BEGIN();
  522. LITE_ASSERT(network, "The network pass to LITE api is null");
  523. std::shared_ptr<lite::Network> network_shared{
  524. static_cast<lite::Network*>(network), [](void*) {}};
  525. lite::Runtime::set_cpu_threads_number(network_shared, nr_threads);
  526. LITE_CAPI_END();
  527. }
  528. int LITE_set_network_algo_policy(LiteNetwork network, LiteAlgoSelectStrategy strategy) {
  529. LITE_CAPI_BEGIN();
  530. LITE_ASSERT(network, "The network pass to LITE api is null");
  531. std::shared_ptr<lite::Network> network_shared{
  532. static_cast<lite::Network*>(network), [](void*) {}};
  533. lite::Runtime::set_network_algo_policy(network_shared, strategy);
  534. LITE_CAPI_END();
  535. }
  536. int LITE_set_network_algo_fastrun_config(
  537. LiteNetwork network, unsigned int shared_batch_size,
  538. int binary_equal_between_batch) {
  539. LITE_CAPI_BEGIN();
  540. LITE_ASSERT(network, "The network pass to LITE api is null");
  541. std::shared_ptr<lite::Network> network_shared{
  542. static_cast<lite::Network*>(network), [](void*) {}};
  543. lite::Runtime::set_network_algo_policy(
  544. network_shared, LiteAlgoSelectStrategy(0), shared_batch_size,
  545. binary_equal_between_batch);
  546. LITE_CAPI_END();
  547. }
  548. int LITE_set_network_algo_workspace_limit(LiteNetwork network, size_t workspace_limit) {
  549. LITE_CAPI_BEGIN();
  550. LITE_ASSERT(network, "The network pass to LITE api is null");
  551. std::shared_ptr<lite::Network> network_shared{
  552. static_cast<lite::Network*>(network), [](void*) {}};
  553. lite::Runtime::set_network_algo_workspace_limit(network_shared, workspace_limit);
  554. LITE_CAPI_END();
  555. }
  556. int LITE_set_runtime_thread_affinity(
  557. LiteNetwork network,
  558. const LiteThreadAffinityCallback thread_affinity_callback) {
  559. LITE_CAPI_BEGIN();
  560. LITE_ASSERT(network, "The network pass to LITE api is null");
  561. std::shared_ptr<lite::Network> network_shared{
  562. static_cast<lite::Network*>(network), [](void*) {}};
  563. lite::Runtime::set_runtime_thread_affinity(
  564. network_shared, std::move(thread_affinity_callback));
  565. LITE_CAPI_END();
  566. }
  567. int LITE_set_memory_allocator(
  568. LiteNetwork network, const LiteAllocate allocate_fun, const LiteFree free_fun) {
  569. LITE_CAPI_BEGIN();
  570. LITE_ASSERT(
  571. network && allocate_fun && free_fun, "The ptr pass to LITE api is null");
  572. std::shared_ptr<lite::Network> network_shared{
  573. static_cast<lite::Network*>(network), [](void*) {}};
  574. lite::Runtime::set_memory_allocator(
  575. network_shared, std::make_shared<UserAllocator>(allocate_fun, free_fun));
  576. LITE_CAPI_END();
  577. }
  578. int LITE_enable_io_txt_dump(LiteNetwork network, const char* io_txt_out_file) {
  579. LITE_CAPI_BEGIN();
  580. LITE_ASSERT(network, "The network pass to LITE api is null");
  581. std::shared_ptr<lite::Network> network_shared{
  582. static_cast<lite::Network*>(network), [](void*) {}};
  583. lite::Runtime::enable_io_txt_dump(network_shared, io_txt_out_file);
  584. LITE_CAPI_END();
  585. }
  586. int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out_dir) {
  587. LITE_CAPI_BEGIN();
  588. LITE_ASSERT(network, "The network pass to LITE api is null");
  589. std::shared_ptr<lite::Network> network_shared{
  590. static_cast<lite::Network*>(network), [](void*) {}};
  591. lite::Runtime::enable_io_bin_dump(network_shared, io_bin_out_dir);
  592. LITE_CAPI_END();
  593. }
  594. int LITE_shared_weight_with_network(
  595. LiteNetwork dst_network, const LiteNetwork src_network) {
  596. LITE_CAPI_BEGIN();
  597. LITE_ASSERT(dst_network && src_network, "The network pass to LITE api is null");
  598. const std::shared_ptr<lite::Network> src_shared_net{
  599. static_cast<lite::Network*>(src_network), [](void*) {}};
  600. std::shared_ptr<lite::Network> dst_shared_net{
  601. static_cast<lite::Network*>(dst_network), [](void*) {}};
  602. lite::Runtime::shared_weight_with_network(dst_shared_net, src_shared_net);
  603. LITE_CAPI_END();
  604. }
  605. int LITE_share_runtime_memroy(LiteNetwork dst_network, LiteNetwork src_network) {
  606. LITE_CAPI_BEGIN();
  607. LITE_ASSERT(src_network && dst_network, "The network pass to LITE api is null");
  608. std::shared_ptr<lite::Network> src_shared{
  609. static_cast<lite::Network*>(src_network), [](void*) {}};
  610. std::shared_ptr<lite::Network> dst_shared{
  611. static_cast<lite::Network*>(dst_network), [](void*) {}};
  612. lite::Runtime::share_runtime_memory_with(dst_shared, src_shared);
  613. LITE_CAPI_END();
  614. }
  615. int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir) {
  616. LITE_CAPI_BEGIN();
  617. #ifndef __IN_TEE_ENV__
  618. #if MGB_ENABLE_JSON
  619. LITE_ASSERT(network, "The network pass to LITE api is null");
  620. static_cast<lite::Network*>(network)->get_static_memory_alloc_info(log_dir);
  621. return 0;
  622. #endif
  623. #endif
  624. LITE_MARK_USED_VAR(network);
  625. LITE_MARK_USED_VAR(log_dir);
  626. LITE_THROW("Doesn't support get_static_memory_alloc_info().Please check macro.");
  627. LITE_CAPI_END();
  628. }
  629. int LITE_enable_global_layout_transform(LiteNetwork network) {
  630. LITE_CAPI_BEGIN();
  631. LITE_ASSERT(network, "The network pass to LITE api is null");
  632. std::shared_ptr<lite::Network> network_shared{
  633. static_cast<lite::Network*>(network), [](void*) {}};
  634. lite::Runtime::enable_global_layout_transform(network_shared);
  635. LITE_CAPI_END();
  636. }
  637. int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) {
  638. LITE_CAPI_BEGIN();
  639. LITE_ASSERT(network, "The network pass to LITE api is null");
  640. std::shared_ptr<lite::Network> network_shared{
  641. static_cast<lite::Network*>(network), [](void*) {}};
  642. lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
  643. LITE_CAPI_END();
  644. }
  645. namespace {
  646. static LITE_MUTEX mtx_io;
  647. static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
  648. static std::unordered_map<const void*, InnerIO> global_holder;
  649. return global_holder;
  650. }
  651. int write_ios_from_cpp_io(
  652. const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
  653. LITE_CAPI_BEGIN();
  654. LITE_LOCK_GUARD(mtx_io);
  655. get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
  656. auto&& inner_io = get_global_io_holder()[key];
  657. ios->input_size = inner_io.inputs.size();
  658. ios->output_size = inner_io.outputs.size();
  659. ios->inputs = inner_io.inputs.data();
  660. ios->outputs = inner_io.outputs.data();
  661. size_t i = 0;
  662. for (; i < ios->input_size; i++) {
  663. auto io_ptr = ios->inputs + i;
  664. io_ptr->name = inner_io.names[i].c_str();
  665. }
  666. for (; i < ios->output_size; i++) {
  667. auto io_ptr = ios->outputs + i;
  668. io_ptr->name = inner_io.names[i].c_str();
  669. }
  670. LITE_CAPI_END();
  671. }
  672. } // namespace
  673. int LITE_get_model_io_info_by_path(
  674. const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
  675. LITE_CAPI_BEGIN();
  676. LITE_ASSERT(model_path, "The model_path pass to LITE api is null");
  677. auto&& cpp_ios = lite::Runtime::get_model_io_info(
  678. std::string{model_path}, convert_to_lite_config(config));
  679. return write_ios_from_cpp_io(
  680. cpp_ios, ios, reinterpret_cast<const void*>(model_path));
  681. LITE_CAPI_END();
  682. }
  683. int LITE_get_model_io_info_by_memory(
  684. const void* model_mem, size_t size, const LiteConfig config,
  685. LiteNetworkIO* ios) {
  686. LITE_CAPI_BEGIN();
  687. LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null");
  688. auto&& cpp_ios = lite::Runtime::get_model_io_info(
  689. model_mem, size, convert_to_lite_config(config));
  690. return write_ios_from_cpp_io(
  691. cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
  692. LITE_CAPI_END();
  693. }
  694. LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
  695. LITE_CAPI_BEGIN();
  696. LITE_ASSERT(network, "The network pass to LITE api is null");
  697. static_cast<lite::Network*>(network)->extra_configure(
  698. convert_extra_config(extra_config));
  699. LITE_CAPI_END();
  700. }
  701. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}