
tensor_format.cpp 23 kB

feat(bazel/windows/xp/sp2/inference): implement inference on Windows XP (OS version >= SP2), built with Bazel

* Bazel build support (define __DEPLOY_ON_XP_SP2__ when deploying on XP SP2):
  (dbg) ./bazel build //brain/megbrain:load_and_run --cpu='x86_windows_xp' --compiler='clang_cl' -c dbg --copt "-D__DEPLOY_ON_XP_SP2__=1"
  (opt) ./bazel build //brain/megbrain:load_and_run --cpu='x86_windows_xp' --compiler='clang_cl' -c opt --copt "-D__DEPLOY_ON_XP_SP2__=1"
* Internal behavior: MGB_HAVE_THREAD is defined to 0 when __DEPLOY_ON_XP_SP2__ is enabled.
* See https://docs.microsoft.com/en-us/cpp/build/configuring-programs-for-windows-xp?view=msvc-160: XP SP2 (x86) does not fully support the VC runtime, because KERNEL32.dll does not implement some base APIs needed by C++ standard-library facilities such as std::mutex/std::thread/std::condition_variable. As a workaround, some MegEngine features (e.g. multi-threading) are disabled in the XP SP2 environment.
* About DNN_MUTEX/MGB_MUTEX: if your code is built into the inference path (even for CPU backends), please replace std::mutex with DNN_MUTEX/MGB_MUTEX.
* About multi-threading: if your code needs multi-thread support, enable it only when MGB_HAVE_THREAD=1.
* Tested build environments:
  1. Visual Studio 2019 (MSVC version <= 14.26.28801) ---- pass
  2. Visual Studio 2019 (MSVC version > 14.26.28801) ---- failed: this newer version makes the VC runtime depend on the Windows 7 KERNEL32.DLL; this may be fixed in a later Visual Studio 2019 release, but it was not tested at the time this MR was merged
  3. Visual Studio 2017 ---- pass
  4. Visual Studio 2014 ---- pass

GitOrigin-RevId: 65ac48b95e99f2c510fe5db449cc8182d682e113
4 years ago
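The locking and threading rules above amount to two patterns in code that is built into the inference path. A minimal C++ sketch (not part of this commit; the header path is an assumption, while MGB_MUTEX and MGB_HAVE_THREAD are the macros named above):

    // Hypothetical usage sketch for the rules in this commit message.
    #include "megbrain/common.h"  // assumed location of MGB_MUTEX / MGB_HAVE_THREAD

    namespace {
    // Use MGB_MUTEX instead of std::mutex so the code still builds when
    // __DEPLOY_ON_XP_SP2__ forces MGB_HAVE_THREAD=0.
    MGB_MUTEX g_cache_mutex;

    #if MGB_HAVE_THREAD
    // Code that genuinely needs std::thread / std::condition_variable is only
    // compiled when threads are available.
    void start_worker_threads();
    #endif
    }  // anonymous namespace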
/**
 * \file dnn/src/common/tensor_format.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/tensor_format.h"
#include "megdnn/basic_types.h"
#include "src/common/utils.h"

#include <unordered_map>

using namespace megdnn;
using namespace megdnn::detail;

namespace {
DefaultTensorFormat* default_tensor_format_obj;
}

/* ===================== TensorFormat ===================== */

TensorFormat TensorFormat::deserialize(const std::string& bin, const Handle* handle) {
    using Type = TensorFormat::Type;
    auto type = reinterpret_cast<const Type*>(bin.data());
    switch (*type) {
        case Type::DEFAULT:
            return DefaultTensorFormat::deserialize(
                    handle, type + 1, bin.size() - sizeof(Type));
        case Type::IMAGE2D_PACK4:
            return Image2DPack4TensorFormat::deserialize(
                    handle, type + 1, bin.size() - sizeof(Type));
        case Type::LOWBITS_ALIGNED_TO_BYTE:
            return LowbitsAlignedToBytesTensorFormat::deserialize(
                    handle, type + 1, bin.size() - sizeof(Type));
        default:
            megdnn_throw("invalid tensor format type in deserialize");
    }
}

TensorFormat::Format() : m_impl{DefaultTensorFormat::make().m_impl} {}

TensorFormat::Format(DType dtype) {
    if (dtype.valid() && dtype.is_quantized_lowbit()) {
        // quantized lowbit, by default aligned to bytes
        size_t size_nbits = dtype.low_bit();
        megdnn_assert(
                size_nbits == 1 || size_nbits == 2 || size_nbits == 4,
                "unsupported lowbits data type(%s, size in bits: %zu)", dtype.name(),
                size_nbits);
        m_impl = LowbitsAlignedToBytesTensorFormat::make(size_nbits).m_impl;
    } else {
        // non parameterized lowbit, default format
        m_impl = DefaultTensorFormat::make().m_impl;
    }
}

std::string TensorFormat::to_string() const {
    return m_impl->to_string();
}
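// Serialization layout: a Type tag occupying sizeof(Type) bytes, followed by
// the impl-specific payload appended by serialize_append(); deserialize()
// above reads the tag and dispatches on it.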
std::string TensorFormat::serialize() const {
    std::string ret;
    ret.reserve(32);
    ret.assign(sizeof(Type), '\0');
    *reinterpret_cast<Type*>(&ret[0]) = type();
    m_impl->serialize_append(ret);
    return ret;
}

void TensorFormat::on_bad_cvt(Type dst_type) const {
    MEGDNN_MARK_USED_VAR(dst_type);
    megdnn_throw(ssprintf(
            "can not convert tensor format %s to %d", impl()->to_string().c_str(),
            static_cast<int>(dst_type)));
}

bool TensorFormat::is_default() const {
    return m_impl == default_tensor_format_obj;
}

bool TensorFormat::is_lowbit_aligned() const {
    return type() == TensorFormat::Type::LOWBITS_ALIGNED_TO_BYTE;
}

/* ===================== DefaultFormat ===================== */

void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
    megdnn_assert(
            !layout.dtype.valid() || !layout.dtype.is_quantized_lowbit(),
            "DefaultTensorFormat does not support quantized lowbit tensor(dtype:%s)",
            layout.dtype.name());
}

size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
    assert_valid(layout);
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
    size_t accum = 1;
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        layout.stride[i - 1] = accum;
        accum = mul(accum, layout.shape[i - 1]);
    }
    return accum;
}

bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const {
    assert_valid(layout);
    return layout.is_physical_contiguous();
}

TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    megdnn_assert(layout.ndim);
    TensorLayout res{layout};

    // remove all dims with shape 1
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 2; --i) {
        if (!res.shape[i]) {
            // empty tensor
            res.ndim = 1;
            res.shape[0] = 0;
            res.stride[0] = 1;
            return res;
        }
        if (res.shape[i] == 1)
            res.remove_axis_inplace(i);
    }

    if (res.ndim == 1) {
        if (res.shape[0] <= 1) {
            // make it the "most canonical" contiguous layout for scalars or
            // empty tensors
            res.stride[0] = 1;
        }
        return res;
    }

    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (res.stride[i] ==
            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    return res;
}

TensorLayout::Span DefaultTensorFormat::span_spec(const TensorLayout& layout) const {
    assert_valid(layout);
    if (layout.ndim == 0)
        return {0, 0, 0, 0};

    ptrdiff_t low_elem = 0;
    size_t high_elem = 0;
    for (size_t i = 0; i < layout.ndim; ++i) {
        auto shape_val = layout.shape[i];
        if (!shape_val) {
            return {0, 0, 0, 0};
        }
        auto stride_val = layout.stride[i];
        if (stride_val > 0) {
            high_elem += (shape_val - 1) * stride_val;
        } else {
            low_elem += (shape_val - 1) * stride_val;
        }
    }
    ++high_elem;

    ptrdiff_t low_byte;
    if (low_elem < 0) {
        low_byte = low_elem * layout.dtype.size();
    } else {
        low_byte = 0;
    }
    size_t high_byte = layout.dtype.size(high_elem);
    return TensorLayout::Span(low_elem, low_byte, high_elem, high_byte);
}

std::string DefaultTensorFormat::to_string() const {
    return "default{}";
}

void DefaultTensorFormat::serialize_append(std::string&) const {}

TensorFormat DefaultTensorFormat::deserialize(
        const Handle* handle, const void* buf, size_t size) {
    MEGDNN_MARK_USED_VAR(handle);
    MEGDNN_MARK_USED_VAR(buf);
    megdnn_assert(!size);
    return make();
}

TensorFormat DefaultTensorFormat::make() {
    // use static storage so the object is accessible in global destructing
    // phase
    static std::aligned_storage_t<
            sizeof(DefaultTensorFormat), alignof(DefaultTensorFormat)>
            storage;
    static DefaultTensorFormat* obj = default_tensor_format_obj =
            new (&storage) DefaultTensorFormat{};
    return impl_to_tensor_format(obj);
}

/* ===================== Image2DTensorFormatBase ===================== */

Image2DTensorFormatBase::Image2DTensorFormatBase(
        Type type, size_t align_axis, size_t align_size_in_elements)
        : ImplBase(type), m_align_axis(align_axis) {
    megdnn_assert(align_size_in_elements && align_axis);
    m_align_size_in_elements_log2 = __builtin_ctz(align_size_in_elements);
    megdnn_assert(
            (1u << m_align_size_in_elements_log2) == align_size_in_elements,
            "align size not power of 2: %zu", align_size_in_elements);
}

void Image2DTensorFormatBase::serialize_append(std::string& result) const {
    SerializePack pack;
    pack.align_axis = m_align_axis;
    megdnn_assert(pack.align_axis == m_align_axis);  // detect overflow
    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}

size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const {
    size_t accum = 1;
    for (int i = m_align_axis - 1; i >= 0; --i) {
        if (layout.stride[i] == 0) {
            // this dimension is broadcasted
        } else {
            accum *= layout.shape[i];
        }
    }
    return accum;
}

size_t Image2DTensorFormatBase::image_width_elems(const TensorLayout& layout) const {
    size_t high_elem = 0;
    for (size_t i = m_align_axis; i < layout.ndim; ++i) {
        high_elem += (layout.shape[i] - 1) * layout.stride[i];
    }
    return high_elem + 1;
}

std::string Image2DTensorFormatBase::to_string() const {
    return ssprintf("I2D{%zu,%d}", m_align_axis, 1 << m_align_size_in_elements_log2);
}

/* ===================== Image2DPackedTensorFormatBase ===================== */
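// Dimensions before the alignment axis map to image rows (height); dimensions
// from the alignment axis onward map to elements within a row (width), with
// PIXEL_SIZE elements packed into each pixel (see image_height() /
// image_width_elems() above and image_width() below).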
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width(
        const TensorLayout& layout) const {
    auto ret = image_width_elems(layout);
    megdnn_assert(ret % PIXEL_SIZE == 0);
    return ret / PIXEL_SIZE;
}

template <size_t PIXEL_SIZE>
void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
        const TensorLayout& layout) const {
    auto m_align_axis = align_axis();
    megdnn_assert(
            !(layout.shape[layout.ndim - 1] % PIXEL_SIZE), "bad shape: %zu",
            layout.shape[layout.ndim - 1]);
    megdnn_assert(
            layout.dtype.valid() && !layout.dtype.is_quantized_lowbit() &&
            layout.ndim > m_align_axis);
    ptrdiff_t first_non_zero_stride = 0;
    for (int i = layout.ndim - 1; i >= 0; --i) {
        megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
        if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
            first_non_zero_stride = layout.stride[i];
        }
    }
    size_t mask = image_pitch_alignment_in_bytes(
                          align_size_in_elements(layout.dtype.size_log()), layout) -
                  1;
    megdnn_assert(
            !(first_non_zero_stride & mask), "first stride is %d, but alignment is %zu",
            static_cast<int>(first_non_zero_stride), mask + 1);
}

template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_row_pitch(
        const TensorLayout& layout) const {
    for (int i = align_axis() - 1; i >= 0; --i) {
        // find a non-broadcast axis
        if (auto s = layout.stride[i]) {
            return layout.dtype.size(s);
        }
    }
    // use width for all broadcasted case
    size_t alignment_in_bytes_log2 = align_size_in_elements_log2();
    if (m_vendor_type == Handle::HandleVendorType::MALI) {
        alignment_in_bytes_log2 += __builtin_ctz(layout.dtype.size() * PIXEL_SIZE);
    }
    return get_aligned_power2<size_t>(
            layout.dtype.size(image_width_elems(layout)), 1 << alignment_in_bytes_log2);
}

template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_pitch_alignment_in_bytes(
        size_t align_size_in_elements, const TensorLayout& layout) const {
    return m_vendor_type == Handle::HandleVendorType::MALI
                 ? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE)
                 : align_size_in_elements;
}

template <size_t PIXEL_SIZE>
TensorLayout::Span Image2DPackedTensorFormatBase<PIXEL_SIZE>::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    size_t size = image_height(layout) * image_row_pitch(layout);
    auto mask = (1 << layout.dtype.size_log()) - 1;
    megdnn_assert(!(size & mask), "unaligned size: %zu", size);
    return {0, 0, size >> layout.dtype.size_log(), size};
}
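// Strides are filled innermost-first; when the loop reaches the alignment
// axis, the accumulated element count is rounded up to the image pitch
// alignment so that each image row starts at an aligned offset.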
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::init_contiguous_stride(
        TensorLayout& layout) const {
    auto m_align_axis = align_axis();
    if (!layout.ndim)
        return 0;
    megdnn_assert(
            layout.dtype.valid() && layout.ndim > m_align_axis,
            "dtype=%s ndim=%zu align=%zu", layout.dtype.name(), layout.ndim,
            m_align_axis);
    size_t align_size = image_pitch_alignment_in_bytes(
            align_size_in_elements(layout.dtype.size_log()), layout);
    size_t accum = 1;
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        if (i == m_align_axis) {
            accum = get_aligned_power2<size_t>(accum, align_size);
        }
        layout.stride[i - 1] = accum;
        accum = mul(accum, layout.shape[i - 1]);
    }
    assert_valid(layout);
    return accum;
}

template <size_t PIXEL_SIZE>
bool Image2DPackedTensorFormatBase<PIXEL_SIZE>::is_contiguous_spec(
        const TensorLayout& layout) const {
    megdnn_assert(layout.dtype.valid());
    size_t align_size = image_pitch_alignment_in_bytes(
            align_size_in_elements(layout.dtype.size_log()), layout);
    ptrdiff_t expected = 1;
    int height_axis = static_cast<int>(align_axis() - 1);
    for (int i = layout.ndim - 1; i >= 0; --i) {
        if (i == height_axis) {
            expected = megdnn::get_aligned_power2<size_t>(expected, align_size);
        }
        if (layout.shape[i] != 1 && layout.stride[i] != expected) {
            if (i == height_axis) {
                // allow row pitch to be larger than minimal required
                auto s = layout.stride[i];
                if (!s) {
                    // broadcast is not contiguous
                    return false;
                }
                size_t mask = image_pitch_alignment_in_bytes(
                                      align_size_in_elements(layout.dtype.size_log()),
                                      layout) -
                              1;
                megdnn_assert(
                        s > expected && !(s & mask),
                        "invalid row pitch: %d; layout: %s", static_cast<int>(s),
                        layout.to_string().c_str());
                expected = s;
            } else {
                return false;
            }
        }
        expected *= layout.shape[i];
    }
    // empty tensors are not contiguous
    return expected != 0;
}

template <size_t PIXEL_SIZE>
TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    TensorLayout res{layout};
    int new_axis = align_axis();
    // remove all dims with shape 1
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 3; --i) {
        if (i == new_axis && static_cast<int>(res.ndim) == new_axis + 1) {
            // i is the only width dim
            continue;
        }
        if (i == new_axis - 1 && !i) {
            // new_axis == 1 && i == 0, i is the only height dim
            continue;
        }
        if (res.shape[i] == 1) {
            res.remove_axis_inplace(i);
            if (i < new_axis)
                new_axis -= 1;
        }
    }
    megdnn_assert(res.ndim >= 2);
    auto contig_with_next = [&](size_t i) {
        return res.stride[i] ==
               res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1]);
    };

    for (int i = static_cast<int>(res.ndim) - 2; i >= new_axis; --i) {
        megdnn_assert(res.shape[i]);
        if (contig_with_next(i)) {
            // remove next axis
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    for (int i = new_axis - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (contig_with_next(i)) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
            if (i <= new_axis - 2)
                new_axis -= 1;
        }
    }
    res.format = change_axis(new_axis);
    return res;
}

namespace megdnn {
namespace detail {
template class Image2DPackedTensorFormatBase<4>;
}  // namespace detail
}  // namespace megdnn
/* =============== LowbitsAlignedTensorFormatBase ============== */

LowbitsAlignedTensorFormatBase::LowbitsAlignedTensorFormatBase(
        Type type, size_t size_nbits, size_t align_size_in_bits)
        : ImplBase(type),
          m_size_nbits(size_nbits),
          m_align_size_in_bits(align_size_in_bits) {
    megdnn_assert(
            !(m_align_size_in_bits % m_size_nbits),
            "align size(%zu) must be a multiple of element size(%zu)",
            m_align_size_in_bits, m_size_nbits);
    m_align_size_in_elements = m_align_size_in_bits / m_size_nbits;
}

std::string LowbitsAlignedTensorFormatBase::to_string() const {
    return ssprintf("LOWBITS{%zu,%zu}", m_size_nbits, m_align_size_in_bits);
}

void LowbitsAlignedTensorFormatBase::assert_valid(const TensorLayout& layout) const {
    megdnn_assert(
            layout.dtype.valid() && layout.dtype.is_low_bit() &&
            layout.dtype.low_bit() == m_size_nbits);
    bool has_dim_unity_stride = false;
    bool has_dim_aligned_stride = false;
    for (int i = layout.ndim - 1; i >= 0; --i) {
        if (!has_dim_unity_stride && layout.stride[i] == 1)
            has_dim_unity_stride = true;
        megdnn_assert(
                layout.stride[i] >= 0 &&
                        (layout.stride[i] % m_align_size_in_elements == 0 ||
                         layout.stride[i] == 1),
                "bad stride:%s, %ld", layout.to_string().c_str(),
                static_cast<long>(layout.stride[i]));
        if (!has_dim_aligned_stride &&
            static_cast<size_t>(layout.stride[i]) == m_align_size_in_elements)
            has_dim_aligned_stride = true;
    }
    megdnn_assert(
            layout.ndim == 0 || has_dim_unity_stride || has_dim_aligned_stride,
            "innermost dim not contiguous");
}

void LowbitsAlignedTensorFormatBase::serialize_append(std::string& result) const {
    SerializePack pack;
    pack.size_nbits = m_size_nbits;
    pack.align_size_in_bits = m_align_size_in_bits;
    megdnn_assert(pack.align_size_in_bits == m_align_size_in_bits);  // detect overflow
    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}

TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    if (layout.ndim == 0)
        return {0, 0, 0, 0};

    size_t high_elem = 0;
    for (size_t i = 0; i < layout.ndim; ++i) {
        auto shape_val = layout.shape[i];
        if (!shape_val) {
            return {0, 0, 0, 0};
        }
        auto stride_val = layout.stride[i];
        megdnn_assert(
                stride_val >= 0, "lowbit tensors shouldn't have negative strides");
        high_elem += (shape_val - 1) * stride_val;
    }
    ++high_elem;
    size_t high_byte = layout.dtype.size(high_elem);
    return TensorLayout::Span(0, 0, high_elem, high_byte);
}
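// Contiguous strides for low-bit tensors: the innermost dimension is rounded
// up to m_align_size_in_elements, so every outer stride stays a multiple of
// the alignment (e.g. 4-bit elements padded out to whole bytes).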
size_t LowbitsAlignedTensorFormatBase::init_contiguous_stride(
        TensorLayout& layout) const {
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
    size_t accum = 1;
    SafeMultiplies<size_t> mul;
    for (size_t i = layout.ndim; i; --i) {
        layout.stride[i - 1] = accum;
        auto multiplier = layout.shape[i - 1];
        if (i == layout.ndim)
            multiplier = round_up(multiplier, m_align_size_in_elements);
        accum = mul(accum, multiplier);
    }
    assert_valid(layout);
    return accum;
}

bool LowbitsAlignedTensorFormatBase::is_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    ptrdiff_t expected = 1;
    for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
        bool is_valid_stride =
                (layout.stride[i] == expected) ||
                (expected == 1 &&
                 (int)layout.stride[i] == round_up(1, (int)m_align_size_in_elements));
        if (layout.shape[i] != 1 && !is_valid_stride)
            return false;
        auto multiplier = layout.shape[i];
        if (i == static_cast<int>(layout.ndim) - 1)
            multiplier = round_up(multiplier, m_align_size_in_elements);
        expected *= multiplier;
    }
    return expected != 0;
}

TensorLayout LowbitsAlignedTensorFormatBase::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    TensorLayout res{layout};
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) {
        if (!res.shape[i]) {
            // empty tensor
            res.ndim = 1;
            res.shape[0] = 0;
            res.stride[0] = 1;
            return res;
        }
        if (res.shape[i] == 1) {
            res.remove_axis_inplace(i);
        }
    }

    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (res.stride[i] ==
            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    return res;
}

/* ===================== Image2DPack4TensorFormat ===================== */
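// Image2DPack4TensorFormat instances are interned: they are cached in a
// process-wide map keyed by (align_axis << 32) | align_size_in_elements and
// guarded by a mutex, so repeated make_raw() calls return the same impl.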
TensorFormat Image2DPack4TensorFormat::make_raw(
        size_t align_axis, size_t align_size_in_elements,
        Handle::HandleVendorType vendor_type) {
    static DNN_MUTEX mtx;
    static std::unordered_map<uint64_t, std::unique_ptr<Image2DPack4TensorFormat>>
            cache;
    megdnn_assert(
            std::max(align_axis, align_size_in_elements) <=
            std::numeric_limits<uint32_t>::max());
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr =
            cache[(static_cast<uint64_t>(align_axis) << 32) | align_size_in_elements];
    if (!ptr) {
        ptr.reset(new Image2DPack4TensorFormat{
                align_axis, align_size_in_elements, vendor_type});
    }
    return impl_to_tensor_format(ptr.get());
}

TensorFormat Image2DPack4TensorFormat::make(size_t align_axis, const Handle* handle) {
    return make_raw(
            align_axis, handle->image2d_pitch_alignment(), handle->vendor_type());
}

TensorFormat Image2DPack4TensorFormat::deserialize(
        const Handle* handle, const void* buf, size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    auto pack = *static_cast<const SerializePack*>(buf);
    return make(pack.align_axis, handle);
}

TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
    return make_raw(axis, align_size_in_elements(), vendor());
}

/* ===================== LowbitsAlignedToBytesTensorFormat ===================== */

TensorFormat LowbitsAlignedToBytesTensorFormat::make(size_t size_nbits) {
    static DNN_MUTEX mtx;
    static std::unordered_map<
            uint64_t, std::unique_ptr<LowbitsAlignedToBytesTensorFormat>>
            cache;
    megdnn_assert(!(8 % size_nbits));
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr = cache[static_cast<uint32_t>(size_nbits)];
    if (!ptr) {
        ptr.reset(new LowbitsAlignedToBytesTensorFormat{size_nbits});
    }
    return impl_to_tensor_format(ptr.get());
}

TensorFormat LowbitsAlignedToBytesTensorFormat::deserialize(
        const Handle*, const void* buf, size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    auto pack = *static_cast<const SerializePack*>(buf);
    return make(pack.size_nbits);
}

// vim: syntax=cpp.doxygen