
tensor.cpp 12 kB

/**
 * \file src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "lite/tensor.h"
#include "function_base.h"
#include "tensor_impl_base.h"
#if LITE_BUILD_WITH_MGE
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "mge/function_dft.h"
#include "mge/tensor_impl.h"
#endif

#include <memory>

using namespace lite;

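// Size in bytes of a single element of the layout's data type.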
size_t Layout::get_elem_size() const {
    size_t elesize = 1;
    switch (data_type) {
        case LiteDataType::LITE_INT64:
            elesize = 8;
            break;
        case LiteDataType::LITE_FLOAT:
        case LiteDataType::LITE_INT:
        case LiteDataType::LITE_UINT:
            elesize = 4;
            break;
        case LiteDataType::LITE_HALF:
        case LiteDataType::LITE_INT16:
        case LiteDataType::LITE_UINT16:
            elesize = 2;
            break;
        case LiteDataType::LITE_INT8:
        case LiteDataType::LITE_UINT8:
            elesize = 1;
            break;
        default:
            LITE_THROW("not support data type.");
    }
    return elesize;
}

bool Layout::operator==(const Layout& other) const {
    bool equal = true;
    equal &= (ndim == other.ndim);
    equal &= (data_type == other.data_type);
    for (size_t i = 0; i < ndim; i++) {
        equal &= (shapes[i] == other.shapes[i]);
    }
    return equal;
}

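// Every constructor creates the backend tensor implementation through
// call_func; with the default MegEngine backend this builds a TensorImplDft.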
Tensor::~Tensor() = default;

Tensor::Tensor() {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl = call_func<TensorImplDft,
                              std::shared_ptr<lite::Tensor::TensorImplBase>>(
            "create_tensor");
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteDeviceType device_type, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host), m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl = call_func<TensorImplDft,
                              std::shared_ptr<lite::Tensor::TensorImplBase>>(
            "create_tensor", device_type, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteDeviceType device_type, const Layout& layout,
               bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_layout(layout),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl = call_func<TensorImplDft,
                              std::shared_ptr<lite::Tensor::TensorImplBase>>(
            "create_tensor", device_type, layout, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(int device_id, LiteDeviceType device_type, const Layout& layout,
               bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_device_id(device_id),
          m_layout(layout),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl = call_func<TensorImplDft,
                              std::shared_ptr<lite::Tensor::TensorImplBase>>(
            "create_tensor", device_id, device_type, layout, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(int device_id, int stream_id, LiteDeviceType device_type,
               bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_device_id(device_id),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl = call_func<TensorImplDft,
                              std::shared_ptr<lite::Tensor::TensorImplBase>>(
            "create_tensor", device_id, stream_id, device_type, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteBackend backend, LiteDeviceType device_type, int device_id,
               const Layout& layout, bool is_pinned_host) {
    if (backend == LiteBackend::LITE_DEFAULT) {
        m_tensor_impl =
                call_func<TensorImplDft,
                          std::shared_ptr<lite::Tensor::TensorImplBase>>(
                        "create_tensor", device_id, device_type, layout,
                        is_pinned_host);
    } else {
        LITE_MARK_USED_VAR(device_type);
        LITE_MARK_USED_VAR(is_pinned_host);
        LITE_MARK_USED_VAR(layout);
        LITE_MARK_USED_VAR(device_id);
        LITE_THROW("unknow backend, enum id is : %d.");
    }
}

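// Reshape to the given shape. At most one dimension may be -1; it is inferred
// so that the total number of elements stays unchanged.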
void Tensor::reshape(const std::vector<int>& shape) {
    LITE_ASSERT(m_layout.ndim > 0, "The tensor to be reshape is empty.");
    uint32_t length = shape.size();
    LITE_ASSERT(length < Layout::MAXDIM,
                "The ndim of reshape input is too large.");
    Layout new_layout = m_layout;
    new_layout.ndim = length;
    size_t total_length =
            get_tensor_total_size_in_byte() / m_layout.get_elem_size();
    uint32_t unfixed_number = 0;
    uint32_t unfixed_index = 0;
    for (uint32_t i = 0; i < length; i++) {
        if (shape[i] == -1) {
            unfixed_number += 1;
            unfixed_index = i;
        } else {
            LITE_ASSERT(shape[i] > 0, "The reshape inputs invalid.");
            new_layout.shapes[i] = shape[i];
        }
    }
    LITE_ASSERT(unfixed_number <= 1, "The reshape inputs invalid.");
    if (unfixed_number) {
        size_t left = total_length;
        for (uint32_t i = 0; i < length; i++) {
            if (i == unfixed_index) {
                continue;
            } else {
                LITE_ASSERT(left > 0 && (left % new_layout.shapes[i] == 0),
                            "The reshape inputs invalid.");
                left = left / new_layout.shapes[i];
            }
        }
        LITE_ASSERT(left > 0, "The reshape inputs invalid.");
        new_layout.shapes[unfixed_index] = left;
    }
    size_t new_total = 1;
    for (uint32_t i = 0; i < length; i++) {
        new_total *= new_layout.shapes[i];
    }
    LITE_ASSERT(new_total == total_length, "The reshape inputs invalid.");
    m_layout = new_layout;
    m_tensor_impl->reshape(m_layout);
}

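// Total size in bytes: the product of all dimensions times the element size,
// or 0 for an empty layout.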
size_t Tensor::get_tensor_total_size_in_byte() const {
    LITE_ERROR_HANDLER_BEGIN
    size_t elemsize = m_layout.get_elem_size();
    size_t total = m_layout.ndim == 0 ? 0 : 1;
    for (size_t i = 0; i < m_layout.ndim; i++) {
        total *= m_layout.shapes[i];
    }
    return total * elemsize;
    LITE_ERROR_HANDLER_END
}

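// Raw pointer to the tensor storage; the layout must be valid. The indexed
// overload returns the address of the element at the given coordinates.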
void* Tensor::get_memory_ptr() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_layout.ndim != 0,
                "Tensor layout is not valid when get memory ptr.");
    return m_tensor_impl->get_memory_ptr();
    LITE_ERROR_HANDLER_END
}

void* Tensor::get_memory_ptr(const std::vector<size_t>& idx) const {
    LITE_ERROR_HANDLER_BEGIN
    return m_tensor_impl->get_memory_ptr(idx);
    LITE_ERROR_HANDLER_END
}

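// Return a tensor that shares memory with this one, restricted to the given
// start/end/step ranges in each dimension.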
std::shared_ptr<Tensor> Tensor::slice(const std::vector<size_t>& start,
                                      const std::vector<size_t>& end,
                                      const std::vector<size_t>& step) {
    LITE_ERROR_HANDLER_BEGIN
    auto ret = m_tensor_impl->slice(start, end, step);
    ret->update_from_implement();
    return ret;
    LITE_ERROR_HANDLER_END
}

void Tensor::fill_zero() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_layout.ndim > 0,
                "fill_zero can't apply on a tensor with empty layout.");
    m_tensor_impl->fill_zero();
    LITE_ERROR_HANDLER_END
}

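// Share the underlying storage of src_tensor, then refresh the cached layout
// and device information from the implementation.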
void Tensor::share_memory_with(const Tensor& src_tensor) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(src_tensor.m_layout.ndim > 0,
                "To be shared tensor with empty layout.");
    m_tensor_impl->share_memory_with(src_tensor.m_tensor_impl.get());
    update_from_implement();
    LITE_ERROR_HANDLER_END
}

void Tensor::set_layout(const Layout& layout) {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = layout;
    m_tensor_impl->set_layout(layout);
    LITE_ERROR_HANDLER_END
}

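// Adopt externally prepared memory as the tensor storage. The first overload
// keeps the current layout and requires the buffer to be large enough; the
// second also replaces the layout.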
void Tensor::reset(void* prepared_data, size_t data_length_in_byte) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_layout.ndim,
                "Tensor layout is empty, please reset with layout");
    LITE_ASSERT(data_length_in_byte >= get_tensor_total_size_in_byte(),
                "the memory reset to the tensor is too small.");
    m_tensor_impl->reset(prepared_data);
    LITE_ERROR_HANDLER_END
}

void Tensor::reset(void* prepared_data, const Layout& layout) {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = layout;
    m_tensor_impl->reset(prepared_data, layout);
    LITE_ERROR_HANDLER_END
}

bool Tensor::is_continue_memory() const {
    LITE_ERROR_HANDLER_BEGIN
    return m_tensor_impl->is_continue_memory();
    LITE_ERROR_HANDLER_END
}

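// Copy the contents of src into this tensor through the backend
// implementation, then refresh the cached metadata.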
void Tensor::copy_from(const Tensor& src) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(src.get_layout().ndim != 0,
                "when tensor copy, the src tensor layout is empty.");
    m_tensor_impl->copy_from(src.m_tensor_impl.get());
    update_from_implement();
    LITE_ERROR_HANDLER_END
}

void Tensor::update_from_implement() {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = m_tensor_impl->get_layout();
    m_device_type = m_tensor_impl->get_device_type();
    m_device_id = m_tensor_impl->get_device_id();
    m_is_pinned_host = m_tensor_impl->is_pinned_host();
    LITE_ERROR_HANDLER_END
}

void LiteAny::type_missmatch(size_t expect, size_t get) const {
    LITE_THROW(ssprintf(
            "The type store in LiteAny is not match the visit type, type of "
            "storage length is %zu, type of visit length is %zu.",
            expect, get));
}

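// Concatenate tensors along dimension `dim`. All inputs must share ndim, data
// type and the sizes of every other dimension; the destination device and
// device id default to those of the first tensor. Each input is copied into
// the matching slice of the result.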
std::shared_ptr<Tensor> TensorUtils::concat(const std::vector<Tensor>& tensors,
                                            int dim, LiteDeviceType dst_device,
                                            int dst_device_id) {
    if (tensors.size() <= 0) {
        return std::make_shared<Tensor>();
    }
    if (dst_device == LiteDeviceType::LITE_DEVICE_DEFAULT) {
        dst_device = tensors.front().get_device_type();
    }
    if (dst_device_id == -1) {
        dst_device_id = tensors.front().get_device_id();
    }
    bool is_pinned_host = tensors.front().is_pinned_host();
    auto layout = tensors.front().get_layout();
    LITE_ASSERT(static_cast<int>(layout.ndim) > dim,
                "the dim in concat is error.");
    size_t sum_in_dim = layout.shapes[dim];
    for (size_t i = 1; i < tensors.size(); ++i) {
        auto other_layout = tensors[i].get_layout();
        LITE_ASSERT(other_layout.ndim == layout.ndim,
                    "the dim size of tensors is not same!");
        LITE_ASSERT(other_layout.data_type == layout.data_type,
                    "the dtype of tensors is not same!");
        for (size_t j = 0; j < other_layout.ndim; ++j) {
            if (dim == static_cast<int>(j)) {
                sum_in_dim += other_layout.shapes[j];
                continue;
            }
            LITE_ASSERT(other_layout.shapes[j] == layout.shapes[j],
                        "the shape of tensors is not same!");
        }
    }
    layout.shapes[dim] = sum_in_dim;
    auto result = std::make_shared<Tensor>(dst_device_id, dst_device, layout,
                                           is_pinned_host);
    size_t index = 0;
    std::vector<size_t> start(dim + 1, 0);
    std::vector<size_t> end(dim + 1, 0);
    for (int i = 0; i < dim; i++) {
        end[i] = layout.shapes[i];
    }
    for (size_t i = 0; i < tensors.size(); ++i) {
        auto&& tensor = tensors[i];
        auto layout = tensor.get_layout();
        if (layout.shapes[dim] == 0)
            continue;
        start[dim] = index;
        end[dim] = index + layout.shapes[dim];
        auto&& sub_dst = result->slice(start, end);
        sub_dst->copy_from(tensor);
        index += layout.shapes[dim];
    }
    return result;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
