You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

global.cpp 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. #include "lite/global.h"
  2. #include "common.h"
  3. #include "lite-c/global_c.h"
  4. namespace {
  5. class ErrorMsg {
  6. public:
  7. std::string& get_error_msg() { return error_msg; }
  8. ErrorCode get_error_code() { return error_code; }
  9. void set_error_msg(const std::string& msg, ErrorCode code) {
  10. error_msg = msg + ", Error Code: " + std::to_string(code);
  11. error_code = code;
  12. }
  13. void clear_error() {
  14. error_code = ErrorCode::OK;
  15. error_msg.clear();
  16. }
  17. private:
  18. std::string error_msg;
  19. ErrorCode error_code;
  20. };
  21. static LITE_MUTEX mtx_error;
  22. ErrorMsg& get_global_error() {
  23. static ErrorMsg error_msg;
  24. return error_msg;
  25. }
  26. } // namespace
  27. int LiteHandleException(const std::exception& e) {
  28. LITE_LOCK_GUARD(mtx_error);
  29. get_global_error().set_error_msg(e.what(), ErrorCode::LITE_INTERNAL_ERROR);
  30. return -1;
  31. }
  32. ErrorCode LITE_get_last_error_code() {
  33. LITE_LOCK_GUARD(mtx_error);
  34. return get_global_error().get_error_code();
  35. }
  36. void LITE_clear_last_error() {
  37. LITE_LOCK_GUARD(mtx_error);
  38. get_global_error().clear_error();
  39. }
  40. const char* LITE_get_last_error() {
  41. LITE_LOCK_GUARD(mtx_error);
  42. return get_global_error().get_error_msg().c_str();
  43. }
  44. int LITE_get_version(int* major, int* minor, int* patch) {
  45. LITE_ASSERT(major && minor && patch, "The ptr pass to LITE api is null");
  46. lite::get_version(*major, *minor, *patch);
  47. return 0;
  48. }
  49. int LITE_get_device_count(LiteDeviceType device_type, size_t* count) {
  50. LITE_CAPI_BEGIN();
  51. LITE_ASSERT(count, "The ptr pass to LITE api is null");
  52. *count = lite::get_device_count(device_type);
  53. LITE_CAPI_END();
  54. }
  55. int LITE_try_coalesce_all_free_memory() {
  56. LITE_CAPI_BEGIN();
  57. lite::try_coalesce_all_free_memory();
  58. LITE_CAPI_END();
  59. }
  60. int LITE_register_decryption_and_key(
  61. const char* decrypt_name, const LiteDecryptionFunc func,
  62. const uint8_t* key_data, size_t key_size) {
  63. LITE_CAPI_BEGIN();
  64. LITE_ASSERT(decrypt_name && key_data && func, "The ptr pass to LITE api is null");
  65. std::vector<uint8_t> key;
  66. for (size_t i = 0; i < key_size; i++) {
  67. key.push_back(key_data[i]);
  68. }
  69. auto decrypt_func = [func](const void* input_data, size_t input_size,
  70. const std::vector<uint8_t>& key) {
  71. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  72. std::vector<uint8_t> output(size, 0);
  73. func(input_data, input_size, key.data(), key.size(), output.data());
  74. return output;
  75. };
  76. lite::register_decryption_and_key(decrypt_name, decrypt_func, key);
  77. LITE_CAPI_END();
  78. }
  79. int LITE_update_decryption_or_key(
  80. const char* decrypt_name, const LiteDecryptionFunc func,
  81. const uint8_t* key_data, size_t key_size) {
  82. LITE_CAPI_BEGIN();
  83. std::vector<uint8_t> key;
  84. for (size_t i = 0; i < key_size; i++) {
  85. key.push_back(key_data[i]);
  86. }
  87. lite::DecryptionFunc decrypt_func = nullptr;
  88. if (func) {
  89. decrypt_func = [func](const void* input_data, size_t input_size,
  90. const std::vector<uint8_t>& key) {
  91. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  92. std::vector<uint8_t> output(size, 0);
  93. func(input_data, input_size, key.data(), key.size(), output.data());
  94. return output;
  95. };
  96. }
  97. lite::update_decryption_or_key(decrypt_name, decrypt_func, key);
  98. LITE_CAPI_END();
  99. }
  100. int LITE_register_parse_info_func(
  101. const char* info_type, const LiteParseInfoFunc parse_func) {
  102. LITE_CAPI_BEGIN();
  103. LITE_ASSERT(info_type && parse_func, "The ptr pass to LITE api is null");
  104. auto lite_func =
  105. [parse_func](
  106. const void* info_data, size_t info_size,
  107. const std::string model_name, lite::Config& config,
  108. lite::NetworkIO& network_io,
  109. std::unordered_map<std::string, lite::LiteAny>& separate_config_map,
  110. std::string& extra_info) {
  111. LITE_MARK_USED_VAR(extra_info);
  112. size_t nr_threads = 1;
  113. int device_id = 0, is_cpu_inplace_mode = false, use_tensorrt = false;
  114. LiteNetworkIO c_io;
  115. LiteConfig c_config;
  116. auto ret = parse_func(
  117. info_data, info_size, model_name.c_str(), &c_config, &c_io,
  118. &device_id, &nr_threads, &is_cpu_inplace_mode, &use_tensorrt);
  119. config = convert_to_lite_config(c_config);
  120. network_io = convert_to_lite_io(c_io);
  121. if (device_id != 0) {
  122. separate_config_map["device_id"] = device_id;
  123. }
  124. if (nr_threads != 1) {
  125. separate_config_map["nr_threads"] =
  126. static_cast<uint32_t>(nr_threads);
  127. }
  128. if (is_cpu_inplace_mode != false) {
  129. separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
  130. }
  131. if (use_tensorrt != false) {
  132. separate_config_map["use_tensorrt"] = use_tensorrt;
  133. }
  134. return ret;
  135. };
  136. lite::register_parse_info_func(info_type, lite_func);
  137. LITE_CAPI_END();
  138. }
  139. int LITE_set_loader_lib_path(const char* loader_path) {
  140. LITE_CAPI_BEGIN();
  141. LITE_ASSERT(loader_path, "The ptr pass to LITE api is null");
  142. lite::set_loader_lib_path(loader_path);
  143. LITE_CAPI_END();
  144. }
  145. int LITE_set_persistent_cache(const char* cache_path, int always_sync) {
  146. LITE_CAPI_BEGIN();
  147. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  148. lite::set_persistent_cache(cache_path, always_sync);
  149. LITE_CAPI_END();
  150. }
  151. int LITE_set_tensor_rt_cache(const char* cache_path) {
  152. LITE_CAPI_BEGIN();
  153. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  154. lite::set_tensor_rt_cache(cache_path);
  155. LITE_CAPI_END();
  156. }
  157. int LITE_set_log_level(LiteLogLevel level) {
  158. LITE_CAPI_BEGIN();
  159. lite::set_log_level(level);
  160. LITE_CAPI_END();
  161. }
  162. int LITE_get_log_level(LiteLogLevel* level) {
  163. LITE_CAPI_BEGIN();
  164. LITE_ASSERT(level, "The ptr pass to LITE api is null");
  165. *level = lite::get_log_level();
  166. LITE_CAPI_END();
  167. }
  168. int LITE_dump_persistent_cache(const char* cache_path) {
  169. LITE_CAPI_BEGIN();
  170. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  171. lite::dump_persistent_cache(cache_path);
  172. LITE_CAPI_END();
  173. }
  174. int LITE_dump_tensor_rt_cache() {
  175. LITE_CAPI_BEGIN();
  176. lite::dump_tensor_rt_cache();
  177. LITE_CAPI_END();
  178. }
  179. int LITE_register_memory_pair(
  180. void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
  181. LiteBackend backend) {
  182. LITE_CAPI_BEGIN();
  183. lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend);
  184. LITE_CAPI_END();
  185. }
  186. int LITE_clear_memory_pair(
  187. void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
  188. LITE_CAPI_BEGIN();
  189. lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
  190. LITE_CAPI_END();
  191. }
  192. int LITE_lookup_physic_ptr(
  193. void* vir_ptr, void** phy_ptr, LiteDeviceType device, LiteBackend backend) {
  194. LITE_CAPI_BEGIN();
  195. LITE_ASSERT(vir_ptr && phy_ptr, "The ptr pass to vir and phy is nullptr");
  196. *phy_ptr = lite::lookup_physic_ptr(vir_ptr, device, backend);
  197. LITE_CAPI_END();
  198. }
  199. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}