You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

global.cpp 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. /**
  2. * \file lite-c/src/tensor.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite/global.h"
  12. #include "common.h"
  13. #include "lite-c/global_c.h"
  14. #include <exception>
  15. #include <mutex>
  16. namespace {
  17. class ErrorMsg {
  18. public:
  19. std::string& get_error_msg() { return error_msg; }
  20. void set_error_msg(const std::string& msg) { error_msg = msg; }
  21. private:
  22. std::string error_msg;
  23. };
  24. ErrorMsg& get_global_error() {
  25. static thread_local ErrorMsg error_msg;
  26. return error_msg;
  27. }
  28. } // namespace
  29. int LiteHandleException(const std::exception& e) {
  30. get_global_error().set_error_msg(e.what());
  31. return -1;
  32. }
  33. const char* LITE_get_last_error() {
  34. return get_global_error().get_error_msg().c_str();
  35. }
  36. int LITE_get_version(int* major, int* minor, int* patch) {
  37. LITE_ASSERT(major && minor && patch, "The ptr pass to LITE api is null");
  38. lite::get_version(*major, *minor, *patch);
  39. return 0;
  40. }
/*!
 * \brief C wrapper around lite::get_device_count().
 * \param device_type  which device backend to query
 * \param[out] count   receives the number of devices of that type
 * \return 0 on success, non-zero on error (the LITE_CAPI_BEGIN/END pair
 *         presumably converts exceptions into an error return — see common.h).
 */
int LITE_get_device_count(LiteDeviceType device_type, size_t* count) {
    LITE_CAPI_BEGIN();
    // Reject a null output pointer before dereferencing it below.
    LITE_ASSERT(count, "The ptr pass to LITE api is null");
    *count = lite::get_device_count(device_type);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::try_coalesce_all_free_memory().
 *
 * No arguments to validate; simply forwards to the C++ implementation.
 * \return 0 on success, non-zero on error.
 */
int LITE_try_coalesce_all_free_memory(){
    LITE_CAPI_BEGIN();
    lite::try_coalesce_all_free_memory();
    LITE_CAPI_END();
}
  52. int LITE_register_decryption_and_key(const char* decrypt_name,
  53. const LiteDecryptionFunc func,
  54. const uint8_t* key_data, size_t key_size) {
  55. LITE_CAPI_BEGIN();
  56. LITE_ASSERT(decrypt_name && key_data && func,
  57. "The ptr pass to LITE api is null");
  58. std::vector<uint8_t> key;
  59. for (size_t i = 0; i < key_size; i++) {
  60. key.push_back(key_data[i]);
  61. }
  62. auto decrypt_func = [func](const void* input_data, size_t input_size,
  63. const std::vector<uint8_t>& key) {
  64. auto size =
  65. func(input_data, input_size, key.data(), key.size(), nullptr);
  66. std::vector<uint8_t> output(size, 0);
  67. func(input_data, input_size, key.data(), key.size(), output.data());
  68. return output;
  69. };
  70. lite::register_decryption_and_key(decrypt_name, decrypt_func, key);
  71. LITE_CAPI_END();
  72. }
  73. int LITE_update_decryption_or_key(const char* decrypt_name,
  74. const LiteDecryptionFunc func,
  75. const uint8_t* key_data, size_t key_size) {
  76. LITE_CAPI_BEGIN();
  77. std::vector<uint8_t> key;
  78. for (size_t i = 0; i < key_size; i++) {
  79. key.push_back(key_data[i]);
  80. }
  81. lite::DecryptionFunc decrypt_func = nullptr;
  82. if (func) {
  83. decrypt_func = [func](const void* input_data, size_t input_size,
  84. const std::vector<uint8_t>& key) {
  85. auto size = func(input_data, input_size, key.data(), key.size(),
  86. nullptr);
  87. std::vector<uint8_t> output(size, 0);
  88. func(input_data, input_size, key.data(), key.size(), output.data());
  89. return output;
  90. };
  91. }
  92. lite::update_decryption_or_key(decrypt_name, decrypt_func, key);
  93. LITE_CAPI_END();
  94. }
/*!
 * \brief Register a C model-info parser under the given type name.
 *
 * Wraps the C callback in a lambda matching the C++ ParseInfoFunc signature:
 * the lambda calls the C parser with C-struct out-params, then converts the
 * results back into the C++ config/IO types and a separate-config map.
 * \param info_type   identifier the parser is registered under
 * \param parse_func  C callback that parses the raw info blob
 * \return 0 on success, non-zero on error.
 */
int LITE_register_parse_info_func(const char* info_type,
                                  const LiteParseInfoFunc parse_func) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(info_type && parse_func, "The ptr pass to LITE api is null");
    auto lite_func = [parse_func](
                             const void* info_data, size_t info_size,
                             const std::string model_name, lite::Config& config,
                             lite::NetworkIO& network_io,
                             std::unordered_map<std::string, lite::LiteAny>&
                                     separate_config_map,
                             std::string& extra_info) {
        // extra_info is part of the C++ signature but has no C counterpart.
        LITE_MARK_USED_VAR(extra_info);
        // Defaults chosen so "unchanged" values can be detected below and
        // omitted from separate_config_map.
        size_t nr_threads = 1;
        // NOTE(review): the two flags are ints used as booleans — presumably
        // because the C callback takes int*; confirm against global_c.h.
        int device_id = 0, is_cpu_inplace_mode = false, use_tensorrt = false;
        LiteNetworkIO c_io;
        LiteConfig c_config;
        auto ret = parse_func(info_data, info_size, model_name.c_str(),
                              &c_config, &c_io, &device_id, &nr_threads,
                              &is_cpu_inplace_mode, &use_tensorrt);
        // Convert the C out-structs into the C++ types the caller expects.
        config = convert_to_lite_config(c_config);
        network_io = convert_to_lite_io(c_io);
        // Only record settings the parser actually changed from the defaults.
        if (device_id != 0) {
            separate_config_map["device_id"] = device_id;
        }
        if (nr_threads != 1) {
            separate_config_map["nr_threads"] = nr_threads;
        }
        if (is_cpu_inplace_mode != false) {
            separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
        }
        if (use_tensorrt != false) {
            separate_config_map["use_tensorrt"] = use_tensorrt;
        }
        return ret;
    };
    lite::register_parse_info_func(info_type, lite_func);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::set_loader_lib_path().
 * \param loader_path  path to the loader library (must be non-null)
 * \return 0 on success, non-zero on error.
 */
int LITE_set_loader_lib_path(const char* loader_path) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(loader_path, "The ptr pass to LITE api is null");
    lite::set_loader_lib_path(loader_path);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::set_persistent_cache().
 * \param cache_path   path of the persistent cache file (must be non-null)
 * \param always_sync  forwarded to the C++ API; int used as a boolean flag
 * \return 0 on success, non-zero on error.
 */
int LITE_set_persistent_cache(const char* cache_path, int always_sync) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
    lite::set_persistent_cache(cache_path, always_sync);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::set_tensor_rt_cache().
 * \param cache_path  path of the TensorRT cache file (must be non-null)
 * \return 0 on success, non-zero on error.
 */
int LITE_set_tensor_rt_cache(const char* cache_path) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
    lite::set_tensor_rt_cache(cache_path);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::set_log_level().
 * \param level  new global log level (enum passed by value, nothing to check)
 * \return 0 on success, non-zero on error.
 */
int LITE_set_log_level(LiteLogLevel level) {
    LITE_CAPI_BEGIN();
    lite::set_log_level(level);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::get_log_level().
 * \param[out] level  receives the current global log level
 * \return 0 on success, non-zero on error.
 */
int LITE_get_log_level(LiteLogLevel* level) {
    LITE_CAPI_BEGIN();
    // Reject a null output pointer before dereferencing it below.
    LITE_ASSERT(level, "The ptr pass to LITE api is null");
    *level = lite::get_log_level();
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::dump_persistent_cache().
 * \param cache_path  destination path for the dump (must be non-null)
 * \return 0 on success, non-zero on error.
 */
int LITE_dump_persistent_cache(const char* cache_path) {
    LITE_CAPI_BEGIN();
    LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
    lite::dump_persistent_cache(cache_path);
    LITE_CAPI_END();
}
/*!
 * \brief C wrapper around lite::dump_tensor_rt_cache().
 *
 * No arguments to validate; simply forwards to the C++ implementation.
 * \return 0 on success, non-zero on error.
 */
int LITE_dump_tensor_rt_cache() {
    LITE_CAPI_BEGIN();
    lite::dump_tensor_rt_cache();
    LITE_CAPI_END();
}
  173. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台