You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor.cpp 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
#include "lite/tensor.h"

#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "../../src/tensor_impl_base.h"
#include "common.h"
#include "lite-c/tensor_c.h"
  8. const LiteLayout default_layout = {
  9. .shapes = {0, 0, 0, 0, 0}, .ndim = 0, .data_type = LiteDataType::LITE_FLOAT};
  10. const LiteTensorDesc default_desc = {
  11. .is_pinned_host = false,
  12. .layout = default_layout,
  13. .device_type = LiteDeviceType::LITE_CPU,
  14. .device_id = 0};
  15. namespace {
  16. static LITE_MUTEX mtx_tensor;
  17. std::unordered_map<void*, std::shared_ptr<lite::Tensor>>& get_global_tensor_holder() {
  18. static std::unordered_map<void*, std::shared_ptr<lite::Tensor>> global_holder;
  19. return global_holder;
  20. }
  21. static LITE_MUTEX mtx_attr;
  22. std::unordered_map<std::string, lite::LiteAny>& get_global_tensor_attr_holder() {
  23. static std::unordered_map<std::string, lite::LiteAny> global_holder;
  24. return global_holder;
  25. }
  26. } // namespace
  27. //! convert the lite::Layout to Layout
  28. LiteLayout convert_to_clayout(const lite::Layout& layout) {
  29. LiteLayout clayout;
  30. clayout.ndim = layout.ndim;
  31. LITE_ASSERT(layout.ndim < LAYOUT_MAX_DIM, "layout ndim is to large");
  32. for (size_t i = 0; i < layout.ndim; i++) {
  33. clayout.shapes[i] = layout.shapes[i];
  34. }
  35. clayout.data_type = layout.data_type;
  36. return clayout;
  37. }
  38. //! convert the C Layout to lite::Layout
  39. lite::Layout convert_to_layout(const LiteLayout& clayout) {
  40. lite::Layout layout;
  41. layout.ndim = clayout.ndim;
  42. LITE_ASSERT(layout.ndim < LAYOUT_MAX_DIM, "clayout ndim is to large");
  43. for (size_t i = 0; i < layout.ndim; i++) {
  44. layout.shapes[i] = clayout.shapes[i];
  45. }
  46. layout.data_type = clayout.data_type;
  47. return layout;
  48. }
  49. int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) {
  50. LITE_CAPI_BEGIN();
  51. LITE_ASSERT(tensor, "The tensor pass to LITE_make_tensor is null");
  52. lite::Layout layout = convert_to_layout(tensor_describe.layout);
  53. auto lite_tensor = std::make_shared<lite::Tensor>(
  54. tensor_describe.device_id, tensor_describe.device_type, layout,
  55. tensor_describe.is_pinned_host);
  56. LITE_LOCK_GUARD(mtx_tensor);
  57. get_global_tensor_holder()[lite_tensor.get()] = lite_tensor;
  58. *tensor = lite_tensor.get();
  59. LITE_CAPI_END();
  60. }
  61. int LITE_destroy_tensor(LiteTensor tensor) {
  62. LITE_CAPI_BEGIN();
  63. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  64. LITE_LOCK_GUARD(mtx_tensor);
  65. get_global_tensor_holder().erase(tensor);
  66. LITE_CAPI_END();
  67. }
  68. int LITE_set_tensor_layout(LiteTensor tensor, const LiteLayout layout) {
  69. LITE_CAPI_BEGIN();
  70. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  71. auto tensor_ptr = static_cast<lite::Tensor*>(tensor);
  72. tensor_ptr->set_layout(convert_to_layout(layout));
  73. LITE_CAPI_END();
  74. }
  75. int LITE_reset_tensor_memory(
  76. LiteTensor tensor, void* prepared_data, size_t data_length_in_byte) {
  77. LITE_CAPI_BEGIN();
  78. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  79. LITE_ASSERT(prepared_data, "The prepared_data pass to LITE c_api is null");
  80. static_cast<lite::Tensor*>(tensor)->reset(prepared_data, data_length_in_byte);
  81. LITE_CAPI_END();
  82. }
  83. int LITE_reset_tensor(LiteTensor tensor, const LiteLayout layout, void* prepared_data) {
  84. LITE_CAPI_BEGIN();
  85. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  86. LITE_ASSERT(prepared_data, "The prepared_data pass to LITE c_api is null");
  87. static_cast<lite::Tensor*>(tensor)->reset(prepared_data, convert_to_layout(layout));
  88. LITE_CAPI_END();
  89. }
  90. int LITE_tensor_reshape(LiteTensor tensor, const int* shape, int size) {
  91. LITE_CAPI_BEGIN();
  92. LITE_ASSERT(tensor && shape, "The tensor pass to LITE c_api is null");
  93. std::vector<int> shapes;
  94. for (int i = 0; i < size; i++) {
  95. shapes.push_back(shape[i]);
  96. }
  97. static_cast<lite::Tensor*>(tensor)->reshape(shapes);
  98. LITE_CAPI_END();
  99. }
  100. int LITE_tensor_slice(
  101. const LiteTensor tensor, const size_t* start, const size_t* end,
  102. const size_t* step, size_t size, LiteTensor* slice_tensor) {
  103. LITE_CAPI_BEGIN();
  104. LITE_ASSERT(
  105. tensor && start && end && slice_tensor,
  106. "The tensor pass to LITE c_api is null");
  107. std::vector<size_t> starts, ends, steps;
  108. for (size_t i = 0; i < size; i++) {
  109. starts.push_back(start[i]);
  110. ends.push_back(end[i]);
  111. if (step) {
  112. steps.push_back(step[i]);
  113. }
  114. }
  115. auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps);
  116. LITE_LOCK_GUARD(mtx_tensor);
  117. get_global_tensor_holder()[ret_tensor.get()] = ret_tensor;
  118. *slice_tensor = ret_tensor.get();
  119. LITE_CAPI_END();
  120. }
  121. int LITE_tensor_fill_zero(LiteTensor tensor) {
  122. LITE_CAPI_BEGIN();
  123. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  124. static_cast<lite::Tensor*>(tensor)->fill_zero();
  125. LITE_CAPI_END();
  126. }
  127. int LITE_tensor_copy(LiteTensor dst_tensor, const LiteTensor src_tensor) {
  128. LITE_CAPI_BEGIN();
  129. LITE_ASSERT(dst_tensor && src_tensor, "The tensor pass to LITE c_api is null");
  130. static_cast<lite::Tensor*>(dst_tensor)
  131. ->copy_from(*static_cast<lite::Tensor*>(src_tensor));
  132. LITE_CAPI_END();
  133. }
  134. int LITE_tensor_share_memory_with(LiteTensor dst_tensor, const LiteTensor src_tensor) {
  135. LITE_CAPI_BEGIN();
  136. LITE_ASSERT(dst_tensor && src_tensor, "The tensor pass to LITE c_api is null");
  137. static_cast<lite::Tensor*>(dst_tensor)
  138. ->share_memory_with(*static_cast<lite::Tensor*>(src_tensor));
  139. LITE_CAPI_END();
  140. }
  141. int LITE_get_tensor_memory(const LiteTensor tensor, void** data) {
  142. LITE_CAPI_BEGIN();
  143. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  144. LITE_ASSERT(data, "The data ptr pass to LITE c_api is null");
  145. *data = static_cast<lite::Tensor*>(tensor)->get_memory_ptr();
  146. LITE_CAPI_END();
  147. }
  148. int LITE_get_tensor_memory_with_index(
  149. const LiteTensor tensor, const size_t* index, size_t size, void** data) {
  150. LITE_CAPI_BEGIN();
  151. LITE_ASSERT(tensor && index && data, "The tensor pass to LITE c_api is null");
  152. std::vector<size_t> index_v;
  153. for (size_t i = 0; i < size; i++) {
  154. index_v.push_back(index[i]);
  155. }
  156. *data = static_cast<lite::Tensor*>(tensor)->get_memory_ptr(index_v);
  157. LITE_CAPI_END();
  158. }
  159. int LITE_get_tensor_total_size_in_byte(const LiteTensor tensor, size_t* size) {
  160. LITE_CAPI_BEGIN();
  161. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  162. LITE_ASSERT(size, "The size ptr pass to LITE c_api is null");
  163. *size = static_cast<lite::Tensor*>(tensor)->get_tensor_total_size_in_byte();
  164. LITE_CAPI_END();
  165. }
  166. int LITE_get_tensor_layout(const LiteTensor tensor, LiteLayout* layout) {
  167. LITE_CAPI_BEGIN();
  168. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  169. LITE_ASSERT(layout, "The layout ptr pass to LITE c_api is null");
  170. *layout = convert_to_clayout(static_cast<lite::Tensor*>(tensor)->get_layout());
  171. LITE_CAPI_END();
  172. }
  173. int LITE_get_tensor_device_type(const LiteTensor tensor, LiteDeviceType* device_type) {
  174. LITE_CAPI_BEGIN();
  175. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  176. LITE_ASSERT(device_type, "The device ptr pass to LITE c_api is null");
  177. *device_type = static_cast<lite::Tensor*>(tensor)->get_device_type();
  178. LITE_CAPI_END();
  179. }
  180. int LITE_get_tensor_device_id(const LiteTensor tensor, int* device_id) {
  181. LITE_CAPI_BEGIN();
  182. LITE_ASSERT(tensor && device_id, "The tensor pass to LITE c_api is null");
  183. *device_id = static_cast<lite::Tensor*>(tensor)->get_device_id();
  184. LITE_CAPI_END();
  185. }
  186. int LITE_is_pinned_host(const LiteTensor tensor, int* is_pinned_host) {
  187. LITE_CAPI_BEGIN();
  188. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  189. LITE_ASSERT(is_pinned_host, "The is_pinned_host ptr pass to LITE c_api is null");
  190. *is_pinned_host = static_cast<lite::Tensor*>(tensor)->is_pinned_host();
  191. LITE_CAPI_END();
  192. }
  193. int LITE_is_memory_continue(const LiteTensor tensor, int* is_continue) {
  194. LITE_CAPI_BEGIN();
  195. LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null");
  196. LITE_ASSERT(is_continue, "The is_continue ptr pass to LITE c_api is null");
  197. *is_continue = static_cast<lite::Tensor*>(tensor)->is_continue_memory();
  198. LITE_CAPI_END();
  199. }
  200. int LITE_tensor_concat(
  201. LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device,
  202. int device_id, LiteTensor* result_tensor) {
  203. LITE_CAPI_BEGIN();
  204. std::vector<lite::Tensor> v_tensors;
  205. for (int i = 0; i < nr_tensor; i++) {
  206. v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i]));
  207. }
  208. auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id);
  209. get_global_tensor_holder()[tensor.get()] = tensor;
  210. *result_tensor = tensor.get();
  211. LITE_CAPI_END()
  212. }
  213. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}