
tensor.cpp

#include "lite/tensor.h"
#include "function_base.h"
#include "tensor_impl_base.h"
#if LITE_BUILD_WITH_MGE
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "mge/function_dft.h"
#include "mge/tensor_impl.h"
#endif
#include <memory>

using namespace lite;
size_t Layout::get_elem_size() const {
    size_t elesize = 1;
    switch (data_type) {
        case LiteDataType::LITE_INT64:
            elesize = 8;
            break;
        case LiteDataType::LITE_FLOAT:
        case LiteDataType::LITE_INT:
        case LiteDataType::LITE_UINT:
            elesize = 4;
            break;
        case LiteDataType::LITE_HALF:
        case LiteDataType::LITE_INT16:
        case LiteDataType::LITE_UINT16:
            elesize = 2;
            break;
        case LiteDataType::LITE_INT8:
        case LiteDataType::LITE_UINT8:
            elesize = 1;
            break;
        default:
            LITE_THROW("unsupported data type.");
    }
    return elesize;
}
bool Layout::operator==(const Layout& other) const {
    bool equal = true;
    equal &= (ndim == other.ndim);
    equal &= (data_type == other.data_type);
    for (size_t i = 0; i < ndim; i++) {
        equal &= (shapes[i] == other.shapes[i]);
    }
    return equal;
}
Tensor::~Tensor() = default;

Tensor::Tensor() {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor");
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteDeviceType device_type, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host), m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_type, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteDeviceType device_type, const Layout& layout, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_layout(layout),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_type, layout, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(
        int device_id, LiteDeviceType device_type, const Layout& layout,
        bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_device_id(device_id),
          m_layout(layout),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_id, device_type, layout, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(
        int device_id, int stream_id, LiteDeviceType device_type, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_device_id(device_id),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_id, stream_id, device_type, is_pinned_host);
    LITE_ERROR_HANDLER_END
}
Tensor::Tensor(
        LiteBackend backend, LiteDeviceType device_type, int device_id,
        const Layout& layout, bool is_pinned_host) {
    if (backend == LiteBackend::LITE_DEFAULT) {
        m_tensor_impl =
                call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                        "create_tensor", device_id, device_type, layout,
                        is_pinned_host);
    } else {
        LITE_MARK_USED_VAR(device_type);
        LITE_MARK_USED_VAR(is_pinned_host);
        LITE_MARK_USED_VAR(layout);
        LITE_MARK_USED_VAR(device_id);
        // The original message passed a "%d" format string with no argument;
        // format the backend id explicitly instead.
        LITE_THROW(ssprintf(
                "unknown backend, enum id is: %d.", static_cast<int>(backend)));
    }
}
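// Usage sketch (illustrative, not part of the original file): constructing a
// host tensor with an explicit layout. LiteDeviceType::LITE_CPU is assumed
// from the lite enum header, and the Layout fields below (shapes, ndim,
// data_type) are the ones this file reads from the aggregate.
//
//     Layout layout;
//     layout.ndim = 4;
//     layout.shapes[0] = 1;
//     layout.shapes[1] = 3;
//     layout.shapes[2] = 224;
//     layout.shapes[3] = 224;
//     layout.data_type = LiteDataType::LITE_FLOAT;
//     Tensor tensor(LiteDeviceType::LITE_CPU, layout, /*is_pinned_host=*/false);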
void Tensor::reshape(const std::vector<int>& shape) {
    LITE_ASSERT(m_layout.ndim > 0, "The tensor to be reshaped is empty.");
    uint32_t length = shape.size();
    LITE_ASSERT(length < Layout::MAXDIM, "The ndim of the reshape input is too large.");
    Layout new_layout = m_layout;
    new_layout.ndim = length;
    size_t total_length = get_tensor_total_size_in_byte() / m_layout.get_elem_size();
    // At most one dimension may be -1; its extent is inferred from the total
    // number of elements.
    uint32_t unfixed_number = 0;
    uint32_t unfixed_index = 0;
    for (uint32_t i = 0; i < length; i++) {
        if (shape[i] == -1) {
            unfixed_number += 1;
            unfixed_index = i;
        } else {
            LITE_ASSERT(shape[i] > 0, "The reshape inputs are invalid.");
            new_layout.shapes[i] = shape[i];
        }
    }
    LITE_ASSERT(unfixed_number <= 1, "The reshape inputs are invalid.");
    if (unfixed_number) {
        // Divide the element count by every fixed dimension; the quotient that
        // remains becomes the extent of the inferred dimension.
        size_t left = total_length;
        for (uint32_t i = 0; i < length; i++) {
            if (i == unfixed_index) {
                continue;
            } else {
                LITE_ASSERT(
                        left > 0 && (left % new_layout.shapes[i] == 0),
                        "The reshape inputs are invalid.");
                left = left / new_layout.shapes[i];
            }
        }
        LITE_ASSERT(left > 0, "The reshape inputs are invalid.");
        new_layout.shapes[unfixed_index] = left;
    }
    // The new layout must describe exactly as many elements as the old one.
    size_t new_total = 1;
    for (uint32_t i = 0; i < length; i++) {
        new_total *= new_layout.shapes[i];
    }
    LITE_ASSERT(new_total == total_length, "The reshape inputs are invalid.");
    m_layout = new_layout;
    m_tensor_impl->reshape(m_layout);
}
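// Usage sketch (illustrative): reshape accepts at most one -1 entry, whose
// extent is inferred so the element count is preserved. Continuing the
// hypothetical 1x3x224x224 float tensor from the sketch above:
//
//     tensor.reshape({1, 3, -1});  // -1 is inferred as 224 * 224 = 50176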
size_t Tensor::get_tensor_total_size_in_byte() const {
    LITE_ERROR_HANDLER_BEGIN
    size_t elemsize = m_layout.get_elem_size();
    size_t total = m_layout.ndim == 0 ? 0 : 1;
    for (size_t i = 0; i < m_layout.ndim; i++) {
        total *= m_layout.shapes[i];
    }
    return total * elemsize;
    LITE_ERROR_HANDLER_END
}
void* Tensor::get_memory_ptr() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            m_layout.ndim != 0, "Tensor layout is not valid when getting the memory ptr.");
    return m_tensor_impl->get_memory_ptr();
    LITE_ERROR_HANDLER_END
}

void* Tensor::get_memory_ptr(const std::vector<size_t>& idx) const {
    LITE_ERROR_HANDLER_BEGIN
    return m_tensor_impl->get_memory_ptr(idx);
    LITE_ERROR_HANDLER_END
}
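// Usage sketch (illustrative): filling the tensor through its raw pointer.
// The element type the caller casts to must match layout.data_type
// (LITE_FLOAT in the running example):
//
//     float* data = static_cast<float*>(tensor.get_memory_ptr());
//     size_t nr_elem = tensor.get_tensor_total_size_in_byte() / sizeof(float);
//     for (size_t i = 0; i < nr_elem; ++i)
//         data[i] = 1.0f;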
std::shared_ptr<Tensor> Tensor::slice(
        const std::vector<size_t>& start, const std::vector<size_t>& end,
        const std::vector<size_t>& step) {
    LITE_ERROR_HANDLER_BEGIN
    auto ret = m_tensor_impl->slice(start, end, step);
    ret->update_from_implement();
    return ret;
    LITE_ERROR_HANDLER_END
}
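// Usage sketch (illustrative): taking the first two rows of a hypothetical
// 4x8 tensor (start is inclusive, end exclusive, as the concat code below
// relies on). The slice views the source storage and may be non-contiguous;
// is_continue_memory() reports which:
//
//     auto sub = tensor.slice({0, 0}, {2, 8}, {1, 1});
//     bool contiguous = sub->is_continue_memory();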
void Tensor::fill_zero() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            m_layout.ndim > 0,
            "fill_zero can't be applied to a tensor with an empty layout.");
    m_tensor_impl->fill_zero();
    LITE_ERROR_HANDLER_END
}

void Tensor::share_memory_with(const Tensor& src_tensor) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            src_tensor.m_layout.ndim > 0, "The tensor to be shared has an empty layout.");
    m_tensor_impl->share_memory_with(src_tensor.m_tensor_impl.get());
    update_from_implement();
    LITE_ERROR_HANDLER_END
}
void Tensor::set_layout(const Layout& layout) {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = layout;
    m_tensor_impl->set_layout(layout);
    LITE_ERROR_HANDLER_END
}
void Tensor::reset(void* prepared_data, size_t data_length_in_byte) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_layout.ndim, "Tensor layout is empty, please reset with a layout.");
    LITE_ASSERT(
            data_length_in_byte >= get_tensor_total_size_in_byte(),
            "the memory passed to reset() is too small for the tensor.");
    m_tensor_impl->reset(prepared_data);
    LITE_ERROR_HANDLER_END
}

void Tensor::reset(void* prepared_data, const Layout& layout) {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = layout;
    m_tensor_impl->reset(prepared_data, layout);
    LITE_ERROR_HANDLER_END
}
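// Usage sketch (illustrative): wrapping caller-owned memory without a copy.
// The buffer must outlive every use of the tensor, since reset() does not
// take ownership:
//
//     std::vector<float> buffer(1 * 3 * 224 * 224);
//     tensor.reset(buffer.data(), buffer.size() * sizeof(float));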
bool Tensor::is_continue_memory() const {
    LITE_ERROR_HANDLER_BEGIN
    return m_tensor_impl->is_continue_memory();
    LITE_ERROR_HANDLER_END
}

void Tensor::copy_from(const Tensor& src) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            src.get_layout().ndim != 0,
            "the src tensor layout is empty when copying.");
    m_tensor_impl->copy_from(src.m_tensor_impl.get());
    update_from_implement();
    LITE_ERROR_HANDLER_END
}
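// Usage sketch (illustrative): copy_from also moves data across devices,
// e.g. from the host tensor above into a device tensor. Device id 0 and
// LiteDeviceType::LITE_CUDA are assumptions for the example:
//
//     Tensor device_tensor(0, LiteDeviceType::LITE_CUDA, tensor.get_layout(),
//                          /*is_pinned_host=*/false);
//     device_tensor.copy_from(tensor);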
void Tensor::update_from_implement() {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = m_tensor_impl->get_layout();
    m_device_type = m_tensor_impl->get_device_type();
    m_device_id = m_tensor_impl->get_device_id();
    m_is_pinned_host = m_tensor_impl->is_pinned_host();
    LITE_ERROR_HANDLER_END
}
void LiteAny::type_missmatch(size_t expect, size_t get) const {
    LITE_THROW(ssprintf(
            "The type stored in LiteAny does not match the visited type: "
            "storage type enum is %zu, visited type enum is %zu.",
            expect, get));
}
namespace lite {
#define GET_TYPE(ctype, ENUM)                            \
    template <>                                          \
    LiteAny::Type LiteAny::get_type<ctype>() const {     \
        return ENUM;                                     \
    }

GET_TYPE(std::string, STRING)
GET_TYPE(int32_t, INT32)
GET_TYPE(uint32_t, UINT32)
GET_TYPE(int8_t, INT8)
GET_TYPE(uint8_t, UINT8)
GET_TYPE(int64_t, INT64)
GET_TYPE(uint64_t, UINT64)
GET_TYPE(float, FLOAT)
GET_TYPE(bool, BOOL)
GET_TYPE(void*, VOID_PTR)
}  // namespace lite
std::shared_ptr<Tensor> TensorUtils::concat(
        const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device,
        int dst_device_id) {
    if (tensors.size() <= 0) {
        return std::make_shared<Tensor>();
    }
    // The default destination device and device id come from the first tensor.
    if (dst_device == LiteDeviceType::LITE_DEVICE_DEFAULT) {
        dst_device = tensors.front().get_device_type();
    }
    if (dst_device_id == -1) {
        dst_device_id = tensors.front().get_device_id();
    }
    bool is_pinned_host = tensors.front().is_pinned_host();
    auto layout = tensors.front().get_layout();
    LITE_ASSERT(
            static_cast<int>(layout.ndim) > dim, "the dim argument of concat is invalid.");
    // All tensors must agree on ndim, dtype and every extent except `dim`;
    // the destination extent along `dim` is the sum of the inputs.
    size_t sum_in_dim = layout.shapes[dim];
    for (size_t i = 1; i < tensors.size(); ++i) {
        auto other_layout = tensors[i].get_layout();
        LITE_ASSERT(
                other_layout.ndim == layout.ndim,
                "the ndim of the tensors is not the same!");
        LITE_ASSERT(
                other_layout.data_type == layout.data_type,
                "the dtype of the tensors is not the same!");
        for (size_t j = 0; j < other_layout.ndim; ++j) {
            if (dim == static_cast<int>(j)) {
                sum_in_dim += other_layout.shapes[j];
                continue;
            }
            LITE_ASSERT(
                    other_layout.shapes[j] == layout.shapes[j],
                    "the shape of the tensors is not the same!");
        }
    }
    layout.shapes[dim] = sum_in_dim;
    auto result =
            std::make_shared<Tensor>(dst_device_id, dst_device, layout, is_pinned_host);
    // Copy each input into the matching slice of the result along `dim`.
    size_t index = 0;
    std::vector<size_t> start(dim + 1, 0);
    std::vector<size_t> end(dim + 1, 0);
    for (int i = 0; i < dim; i++) {
        end[i] = layout.shapes[i];
    }
    for (size_t i = 0; i < tensors.size(); ++i) {
        auto&& tensor = tensors[i];
        auto layout = tensor.get_layout();
        if (layout.shapes[dim] == 0)
            continue;
        start[dim] = index;
        end[dim] = index + layout.shapes[dim];
        auto&& sub_dst = result->slice(start, end);
        sub_dst->copy_from(tensor);
        index += layout.shapes[dim];
    }
    return result;
}
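// Usage sketch (illustrative): concatenating two 2x4 float tensors along
// dim 0 yields a 4x4 result; the LITE_DEVICE_DEFAULT / -1 arguments request
// the device and device id of the first input, as handled above
// (tensor_a and tensor_b are hypothetical inputs):
//
//     std::vector<Tensor> parts = {tensor_a, tensor_b};
//     auto merged = TensorUtils::concat(
//             parts, /*dim=*/0, LiteDeviceType::LITE_DEVICE_DEFAULT,
//             /*dst_device_id=*/-1);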
  320. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}