You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.cpp 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. #include "lite/tensor.h"
  2. #include <set>
  3. #include <string>
  4. #include <unordered_map>
  5. #include "../../src/tensor_impl_base.h"
  6. #include "common.h"
  7. #include "lite-c/tensor_c.h"
  //! Default C layout: no dimensions (ndim = 0), LITE_FLOAT element type, and
  //! zero shape entries (aggregate initialization zero-fills any shape slots
  //! beyond the five listed explicitly).
  //! NOTE(review): designated initializers are only standard from C++20 and must
  //! follow member declaration order. A namespace-scope `const` also has internal
  //! linkage in C++ unless an `extern` declaration precedes it (presumably in
  //! lite-c/tensor_c.h) — confirm if other translation units reference these.
  const LiteLayout default_layout = {
          .shapes = {0, 0, 0, 0, 0}, .ndim = 0, .data_type = LiteDataType::LITE_FLOAT};
  //! Default C tensor description: not pinned host memory, default_layout as the
  //! layout, LITE_CPU device type and device_id 0.
  const LiteTensorDesc default_desc = {
          .is_pinned_host = false,
          .layout = default_layout,
          .device_type = LiteDeviceType::LITE_CPU,
          .device_id = 0};
  15. namespace {
  16. static LITE_MUTEX mtx_tensor;
  17. std::unordered_map<void*, std::shared_ptr<lite::Tensor>>& get_global_tensor_holder() {
  18. static std::unordered_map<void*, std::shared_ptr<lite::Tensor>> global_holder;
  19. return global_holder;
  20. }
  21. static LITE_MUTEX mtx_attr;
  22. std::unordered_map<std::string, lite::LiteAny>& get_global_tensor_attr_holder() {
  23. static std::unordered_map<std::string, lite::LiteAny> global_holder;
  24. return global_holder;
  25. }
  26. } // namespace
  27. //! convert the lite::Layout to Layout
  28. LiteLayout convert_to_clayout(const lite::Layout& layout) {
  29. LiteLayout clayout;
  30. clayout.ndim = layout.ndim;
  31. LITE_ASSERT(layout.ndim < LAYOUT_MAX_DIM, "layout ndim is to large");
  32. for (size_t i = 0; i < layout.ndim; i++) {
  33. clayout.shapes[i] = layout.shapes[i];
  34. }
  35. clayout.data_type = layout.data_type;
  36. return clayout;
  37. }
  38. //! convert the C Layout to lite::Layout
  39. lite::Layout convert_to_layout(const LiteLayout& clayout) {
  40. lite::Layout layout;
  41. layout.ndim = clayout.ndim;
  42. LITE_ASSERT(layout.ndim < LAYOUT_MAX_DIM, "clayout ndim is to large");
  43. for (size_t i = 0; i < layout.ndim; i++) {
  44. layout.shapes[i] = clayout.shapes[i];
  45. }
  46. layout.data_type = clayout.data_type;
  47. return layout;
  48. }
  49. int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
  50. LITE_CAPI_BEGIN();
  51. LITE_ASSERT(tensor, "The tensor pass to LITE_make_tensor is null");
  52. lite::Layout layout = convert_to_layout(tensor_describe.layout);
  53. auto lite_tensor = std::make_shared<lite::Tensor>(
  54. tensor_describe.device_id, tensor_describe.device_type, layout,
  55. tensor_describe.is_pinned_host);
  56. {
  57. LITE_LOCK_GUARD(mtx_tensor);
  58. get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
  59. }
  60. *tensor = lite_tensor.get();
  61. LITE_CAPI_END();
  62. }
  63. int LITE_destroy_tensor(LiteTensor tensor) {
  64. LITE_CAPI_BEGIN();
  65. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  66. LITE_LOCK_GUARD(mtx_tensor);
  67. auto& global_holder = get_global_tensor_holder();
  68. if (global_holder.find(tensor) != global_holder.end()) {
  69. global_holder.erase(tensor);
  70. } else {
  71. //! return -1, means the tensor has been destroyed.
  72. return -1;
  73. }
  74. LITE_CAPI_END();
  75. }
  76. int LITE_set_tensor_layout(LiteTensor tensor, const LiteLayout layout) {
  77. LITE_CAPI_BEGIN();
  78. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  79. auto tensor_ptr = static_cast<lite::Tensor*>(tensor);
  80. tensor_ptr->set_layout(convert_to_layout(layout));
  81. LITE_CAPI_END();
  82. }
  83. int LITE_reset_tensor_memory(
  84. LiteTensor tensor, void* prepared_data, size_t data_length_in_byte) {
  85. LITE_CAPI_BEGIN();
  86. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  87. LITE_ASSERT(prepared_data, "The prepared_data pass to LITE c_api is null");
  88. static_cast<lite::Tensor*>(tensor)->reset(prepared_data, data_length_in_byte);
  89. LITE_CAPI_END();
  90. }
  91. int LITE_reset_tensor(LiteTensor tensor, const LiteLayout layout, void* prepared_data) {
  92. LITE_CAPI_BEGIN();
  93. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  94. LITE_ASSERT(prepared_data, "The prepared_data pass to LITE c_api is null");
  95. static_cast<lite::Tensor*>(tensor)->reset(prepared_data, convert_to_layout(layout));
  96. LITE_CAPI_END();
  97. }
  98. int LITE_tensor_reshape(LiteTensor tensor, const int* shape, int size) {
  99. LITE_CAPI_BEGIN();
  100. LITE_ASSERT(tensor && shape, "The tensor pass to LITE c_api is null");
  101. std::vector<int> shapes;
  102. for (int i = 0; i < size; i++) {
  103. shapes.push_back(shape[i]);
  104. }
  105. static_cast<lite::Tensor*>(tensor)->reshape(shapes);
  106. LITE_CAPI_END();
  107. }
  108. int LITE_tensor_slice(
  109. const LiteTensor tensor, const size_t* start, const size_t* end,
  110. const size_t* step, size_t size, LiteTensor* slice_tensor) {
  111. LITE_CAPI_BEGIN();
  112. LITE_ASSERT(
  113. tensor && start && end && slice_tensor,
  114. "The tensor pass to LITE c_api is null");
  115. std::vector<size_t> starts, ends, steps;
  116. for (size_t i = 0; i < size; i++) {
  117. starts.push_back(start[i]);
  118. ends.push_back(end[i]);
  119. if (step) {
  120. steps.push_back(step[i]);
  121. }
  122. }
  123. auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps);
  124. {
  125. LITE_LOCK_GUARD(mtx_tensor);
  126. get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
  127. }
  128. *slice_tensor = ret_tensor.get();
  129. LITE_CAPI_END();
  130. }
  131. int LITE_tensor_fill_zero(LiteTensor tensor) {
  132. LITE_CAPI_BEGIN();
  133. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  134. static_cast<lite::Tensor*>(tensor)->fill_zero();
  135. LITE_CAPI_END();
  136. }
  137. int LITE_tensor_copy(LiteTensor dst_tensor, const LiteTensor src_tensor) {
  138. LITE_CAPI_BEGIN();
  139. LITE_ASSERT(dst_tensor && src_tensor, "The tensor pass to LITE c_api is null");
  140. static_cast<lite::Tensor*>(dst_tensor)
  141. ->copy_from(*static_cast<lite::Tensor*>(src_tensor));
  142. LITE_CAPI_END();
  143. }
  144. int LITE_tensor_share_memory_with(LiteTensor dst_tensor, const LiteTensor src_tensor) {
  145. LITE_CAPI_BEGIN();
  146. LITE_ASSERT(dst_tensor && src_tensor, "The tensor pass to LITE c_api is null");
  147. static_cast<lite::Tensor*>(dst_tensor)
  148. ->share_memory_with(*static_cast<lite::Tensor*>(src_tensor));
  149. LITE_CAPI_END();
  150. }
  151. int LITE_get_tensor_memory(const LiteTensor tensor, void** data) {
  152. LITE_CAPI_BEGIN();
  153. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  154. LITE_ASSERT(data, "The data ptr pass to LITE c_api is null");
  155. *data = static_cast<lite::Tensor*>(tensor)->get_memory_ptr();
  156. LITE_CAPI_END();
  157. }
  158. int LITE_get_tensor_memory_with_index(
  159. const LiteTensor tensor, const size_t* index, size_t size, void** data) {
  160. LITE_CAPI_BEGIN();
  161. LITE_ASSERT(tensor && index && data, "The tensor pass to LITE c_api is null");
  162. std::vector<size_t> index_v;
  163. for (size_t i = 0; i < size; i++) {
  164. index_v.push_back(index[i]);
  165. }
  166. *data = static_cast<lite::Tensor*>(tensor)->get_memory_ptr(index_v);
  167. LITE_CAPI_END();
  168. }
  169. int LITE_get_tensor_total_size_in_byte(const LiteTensor tensor, size_t* size) {
  170. LITE_CAPI_BEGIN();
  171. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  172. LITE_ASSERT(size, "The size ptr pass to LITE c_api is null");
  173. *size = static_cast<lite::Tensor*>(tensor)->get_tensor_total_size_in_byte();
  174. LITE_CAPI_END();
  175. }
  176. int LITE_get_tensor_layout(const LiteTensor tensor, LiteLayout* layout) {
  177. LITE_CAPI_BEGIN();
  178. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  179. LITE_ASSERT(layout, "The layout ptr pass to LITE c_api is null");
  180. *layout = convert_to_clayout(static_cast<lite::Tensor*>(tensor)->get_layout());
  181. LITE_CAPI_END();
  182. }
  183. int LITE_get_tensor_device_type(const LiteTensor tensor, LiteDeviceType* device_type) {
  184. LITE_CAPI_BEGIN();
  185. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  186. LITE_ASSERT(device_type, "The device ptr pass to LITE c_api is null");
  187. *device_type = static_cast<lite::Tensor*>(tensor)->get_device_type();
  188. LITE_CAPI_END();
  189. }
  190. int LITE_get_tensor_device_id(const LiteTensor tensor, int* device_id) {
  191. LITE_CAPI_BEGIN();
  192. LITE_ASSERT(tensor && device_id, "The tensor pass to LITE c_api is null");
  193. *device_id = static_cast<lite::Tensor*>(tensor)->get_device_id();
  194. LITE_CAPI_END();
  195. }
  196. int LITE_is_pinned_host(const LiteTensor tensor, int* is_pinned_host) {
  197. LITE_CAPI_BEGIN();
  198. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  199. LITE_ASSERT(is_pinned_host, "The is_pinned_host ptr pass to LITE c_api is null");
  200. *is_pinned_host = static_cast<lite::Tensor*>(tensor)->is_pinned_host();
  201. LITE_CAPI_END();
  202. }
  203. int LITE_is_memory_continue(const LiteTensor tensor, int* is_continue) {
  204. LITE_CAPI_BEGIN();
  205. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  206. LITE_ASSERT(is_continue, "The is_continue ptr pass to LITE c_api is null");
  207. *is_continue = static_cast<lite::Tensor*>(tensor)->is_continue_memory();
  208. LITE_CAPI_END();
  209. }
  210. int LITE_tensor_concat(
  211. LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device,
  212. int device_id, LiteTensor* result_tensor) {
  213. LITE_CAPI_BEGIN();
  214. LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null");
  215. std::vector<lite::Tensor> v_tensors;
  216. for (int i = 0; i < nr_tensor; i++) {
  217. v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i]));
  218. }
  219. auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id);
  220. {
  221. LITE_LOCK_GUARD(mtx_tensor);
  222. get_global_tensor_holder()[tensor.get()] = tensor;
  223. }
  224. *result_tensor = tensor.get();
  225. LITE_CAPI_END()
  226. }
  227. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}