You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. /**
  2. * \file dnn/src/common/utils.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/common/utils.h"
  12. #include "megdnn/handle.h"
  13. #include "megdnn/oprs/utils.h"
  14. #include <cstdarg>
  15. #include <cstring>
  16. #include <mutex>
  17. #include <numeric>
  18. using namespace megdnn;
  19. namespace {
  20. std::string svsprintf(const char* fmt, va_list ap_orig) {
  21. int size = 100; /* Guess we need no more than 100 bytes */
  22. char* p;
  23. if ((p = (char*)malloc(size)) == nullptr)
  24. return "svsprintf: malloc failed";
  25. for (;;) {
  26. va_list ap;
  27. va_copy(ap, ap_orig);
  28. int n = vsnprintf(p, size, fmt, ap);
  29. va_end(ap);
  30. if (n < 0)
  31. return "svsprintf: vsnprintf failed";
  32. if (n < size) {
  33. std::string rst(p);
  34. free(p);
  35. return rst;
  36. }
  37. size = n + 1;
  38. char* np = (char*)realloc(p, size);
  39. if (!np) {
  40. free(p);
  41. return "svsprintf: realloc failed";
  42. } else
  43. p = np;
  44. }
  45. }
  46. } // anonymous namespace
  47. std::string megdnn::ssprintf(const char* fmt, ...) {
  48. va_list ap;
  49. va_start(ap, fmt);
  50. auto rst = svsprintf(fmt, ap);
  51. va_end(ap);
  52. return rst;
  53. }
  54. void megdnn::__assert_fail__(
  55. const char* file, int line, const char* func, const char* expr,
  56. const char* msg_fmt, ...) {
  57. std::string msg;
  58. if (msg_fmt) {
  59. va_list ap;
  60. va_start(ap, msg_fmt);
  61. msg = "\nextra message: ";
  62. msg.append(svsprintf(msg_fmt, ap));
  63. va_end(ap);
  64. }
  65. msg = ssprintf(
  66. "assertion `%s' failed at %s:%d: %s%s", expr, file, line, func,
  67. msg.c_str());
  68. megdnn_throw(msg.c_str());
  69. }
/*!
 * \brief advance a multi-dimensional index like an odometer.
 *
 * \p idx is incremented by \p stride in the last dimension; overflowing digits
 * carry into the preceding dimension. Returns false when the whole iteration
 * space has been exhausted (idx wraps back to all zeros), true otherwise.
 * Asserts that each idx[i] is in range before advancing; n must be >= 1.
 */
bool megdnn::get_next_addr(size_t* idx, const size_t* shp, size_t n, size_t stride) {
    // lazily-built human-readable dump of the arguments, only evaluated when
    // an assertion below actually fires
    auto errmsg = [&]() {
        std::string res;
        res.append("idx={");
        for (size_t i = 0; i < n; ++i) {
            res.append(std::to_string(idx[i]));
            if (i + 1 < n)
                res.append(",");
        }
        res.append("}, shp={");
        for (size_t i = 0; i < n; ++i) {
            res.append(std::to_string(shp[i]));
            if (i + 1 < n)
                res.append(",");
        }
        res.append("}, n=");
        res.append(std::to_string(n));
        res.append(", stride=");
        res.append(std::to_string(stride));
        return res;
    };
    // silence unused-variable warning in builds where asserts compile out
    MEGDNN_MARK_USED_VAR(errmsg);
    for (size_t i = 0; i < n; ++i) {
        megdnn_assert(idx[i] < shp[i], "%s", errmsg().c_str());
    }
    // step the innermost dimension; it may reach (but not exceed) its extent
    idx[n - 1] += stride;
    megdnn_assert(idx[n - 1] <= shp[n - 1], "%s", errmsg().c_str());
    // propagate carries from the innermost dimension outwards
    size_t i;
    for (i = n; i > 1; --i)
        if (idx[i - 1] == shp[i - 1]) {
            idx[i - 1] = 0;
            ++idx[i - 2];
        } else {
            break;
        }
    // carry fell off the outermost dimension: iteration finished
    if (i == 1 && idx[0] == shp[0]) {
        idx[0] = 0;
        return false;
    }
    return true;
}
  111. int megdnn::get_linear_addr_noncont(size_t* index, const TensorLayout& layout) {
  112. int ans = 0;
  113. rep(i, layout.ndim) { ans += index[i] * layout.stride[i]; }
  114. return ans;
  115. }
  116. size_t megdnn::get_linear_addr(size_t* index, const size_t* shape, size_t n) {
  117. size_t base = 1;
  118. size_t ans = 0;
  119. for (size_t i = n; i > 0; --i) {
  120. ans += index[i - 1] * base;
  121. base *= shape[i - 1];
  122. }
  123. return ans;
  124. }
  125. size_t megdnn::infer_conv_shape(
  126. size_t inp, size_t flt, size_t stride, size_t pad, bool is_floor) {
  127. megdnn_assert(
  128. inp + 2 * pad >= flt, "input=%zu padding=%zu filter=%zu", inp, pad, flt);
  129. if (is_floor) {
  130. return (inp + 2 * pad - flt) / stride + 1;
  131. }
  132. return (inp + 2 * pad - flt + stride - 1) / stride + 1;
  133. }
//! 2-D convolution output shape: applies the 1-D rule independently to the
//! height (ih/fh/sh/ph -> oh) and width (iw/fw/sw/pw -> ow) axes
void megdnn::infer_conv_shape2d(
        size_t ih, size_t iw, size_t fh, size_t fw, size_t sh, size_t sw, size_t ph,
        size_t pw, size_t& oh, size_t& ow, bool is_floor) {
    oh = infer_conv_shape(ih, fh, sh, ph, is_floor);
    ow = infer_conv_shape(iw, fw, sw, pw, is_floor);
}
  140. WorkspaceBundle::WorkspaceBundle(
  141. void* ptr, SmallVector<size_t> sizes_in_bytes, size_t align_in_bytes)
  142. : m_ptr(ptr),
  143. m_sizes(std::move(sizes_in_bytes)),
  144. m_align_in_bytes(align_in_bytes) {
  145. m_aligned_sizes.reserve(m_sizes.size());
  146. for (auto size : m_sizes) {
  147. auto aligned_size = size;
  148. if (size % m_align_in_bytes != 0) {
  149. aligned_size += m_align_in_bytes - size % m_align_in_bytes;
  150. }
  151. m_aligned_sizes.push_back(aligned_size);
  152. }
  153. }
//! raw (possibly unaligned) base pointer backing the bundle
void* WorkspaceBundle::ptr() const {
    return m_ptr;
}
  157. void* WorkspaceBundle::get(size_t i) const {
  158. auto addr = reinterpret_cast<uintptr_t>(m_ptr);
  159. if (addr % m_align_in_bytes != 0)
  160. addr += m_align_in_bytes - addr % m_align_in_bytes;
  161. for (size_t j = 0; j < i; ++j) {
  162. addr += m_aligned_sizes[j];
  163. }
  164. return reinterpret_cast<void*>(addr);
  165. }
//! number of sub-workspaces in the bundle
size_t WorkspaceBundle::nr_workspace() const {
    return m_sizes.size();
}
//! originally requested (unaligned) size in bytes of the i-th sub-workspace
size_t WorkspaceBundle::get_size(size_t i) const {
    return m_sizes[i];
}
//! rebind the bundle to a new raw buffer (sizes stay unchanged)
void WorkspaceBundle::set(void* ptr) {
    m_ptr = ptr;
}
  175. size_t WorkspaceBundle::total_size_in_bytes() const {
  176. //! return 0 if the WorkspaceBundle is empty
  177. size_t size = std::accumulate(
  178. m_aligned_sizes.begin(), m_aligned_sizes.end(), static_cast<size_t>(0));
  179. return size ? size + m_align_in_bytes : size;
  180. }
  181. size_t megdnn::count_not_ones_in_shape(const TensorShape& shape) {
  182. size_t res = 0u;
  183. for (size_t i = 0; i < shape.ndim; ++i)
  184. res += (shape[i] != 1u);
  185. return res;
  186. }
  187. bool megdnn::is_nhwc_contig_wc(const TensorLayout& layout) {
  188. return layout.ndim == 4 && (layout.stride[3] == 1 || layout.shape[3] == 1) &&
  189. (layout.stride[2] == static_cast<ptrdiff_t>(layout.shape[3]) ||
  190. layout.shape[2] == 1);
  191. }
  192. megcoreDeviceHandle_t megdnn::get_device_handle(Handle* handle) {
  193. megcoreStatus_t status;
  194. megcoreDeviceHandle_t dev_handle;
  195. megcoreComputingHandle_t comp_handle = handle->megcore_computing_handle();
  196. status = megcoreGetDeviceHandle(comp_handle, &dev_handle);
  197. megdnn_throw_if(status != megcoreSuccess, megdnn_error, "get device handle error!");
  198. return dev_handle;
  199. }
  200. // clang-format off
//! product of the quantization scales of two dtypes; supports a small
//! whitelist of mixed pairs, otherwise both operands must share a dtype enum
float megdnn::mul_scale(DType lhs, DType rhs) {
// explicitly supported mixed-dtype pairs are handled first
#define cb_binary(dt1, dt2)                        \
    if ((lhs.enumv() == DTypeTrait<dt1>::enumv) && \
        (rhs.enumv() == DTypeTrait<dt2>::enumv))   \
        return lhs.param<dt1>().scale * rhs.param<dt2>().scale;
    cb_binary(::megdnn::dtype::QuantizedS8, ::megdnn::dtype::QuantizedS16)
    cb_binary(::megdnn::dtype::Quantized4Asymm, ::megdnn::dtype::QuantizedS4)
#undef cb_binary

    // any other combination requires identical dtypes on both sides
    megdnn_assert(lhs.enumv() == rhs.enumv());
#define cb(dt)                                    \
    if (lhs.enumv() == DTypeTrait<dt>::enumv)     \
        return lhs.param<dt>().scale * rhs.param<dt>().scale;
    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
    MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
    cb(::megdnn::dtype::QuantizedS1)
#undef cb
    // unreachable for any supported quantized dtype
    megdnn_assert_internal(0);
}
  219. // clang-format on
//! quantization scale of a quantized dtype; asserts on non-quantized input
float megdnn::get_scale(DType dt) {
    megdnn_assert(dt.category() == DTypeCategory::QUANTIZED);
// dispatch over every quantized dtype known to megdnn
#define cb(_dt)                                \
    if (dt.enumv() == DTypeTrait<_dt>::enumv)  \
        return dt.param<_dt>().scale;
    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
    MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
    cb(::megdnn::dtype::QuantizedS1)
#undef cb
    // unreachable for any supported quantized dtype
    megdnn_assert_internal(0);
}
//! whether two dtypes are interchangeable: same enum, and for quantized
//! dtypes also (almost) equal scale -- plus equal zero_point for asymmetric
//! quantized dtypes
bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
    if (lhs.enumv() != rhs.enumv())
        return false;
    // non-quantized dtypes carry no parameters, so the enum match suffices
    if (lhs.category() != DTypeCategory::QUANTIZED)
        return true;
// symmetric quantized dtypes: only the scale matters
#define cb(dt)                                 \
    if (lhs.enumv() == DTypeTrait<dt>::enumv)  \
        return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale);
    MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb)
#undef cb
// asymmetric quantized dtypes: scale and zero_point must both match
#define cb(dt)                                                             \
    if (lhs.enumv() == DTypeTrait<dt>::enumv)                              \
        return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale) && \
               lhs.param<dt>().zero_point == rhs.param<dt>().zero_point;
    MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb)
#undef cb
    // unreachable: every quantized dtype is either symmetric or asymmetric
    megdnn_assert_internal(false);
}
  249. template <>
  250. uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst, size_t offset) {
  251. uint8_t _src = std::min(src.as_uint8(), DTypeTrait<dtype::Quantized4Asymm>::max());
  252. if (offset == 0) {
  253. _src &= 0xF;
  254. dst &= 0xF0;
  255. dst |= _src;
  256. } else {
  257. _src <<= 4;
  258. dst &= 0xF;
  259. dst |= _src;
  260. }
  261. return dst;
  262. }
  263. template <>
  264. dt_quint4 megdnn::convert<uint8_t, dt_quint4>(
  265. uint8_t src, dt_quint4 dst, size_t offset) {
  266. src >>= (offset << 2);
  267. src &= 0xF;
  268. dst = dt_quint4(src);
  269. return dst;
  270. }
  271. template <>
  272. int8_t megdnn::convert<dt_qint4, int8_t>(dt_qint4 src, int8_t dst, size_t offset) {
  273. int8_t _src = std::max(
  274. std::min(src.as_int8(), DTypeTrait<dtype::QuantizedS4>::max()),
  275. DTypeTrait<dtype::QuantizedS4>::min());
  276. if (offset == 0) {
  277. _src &= 0xF;
  278. dst &= 0xF0;
  279. dst |= _src;
  280. } else {
  281. _src <<= 4;
  282. dst &= 0xF;
  283. dst |= _src;
  284. }
  285. return dst;
  286. }
//! extract the nibble selected by \p offset (0 = low, 1 = high) from a packed
//! byte as a sign-extended 4-bit signed value: shift the wanted nibble into
//! the top of the low byte, then arithmetic-shift right to sign-extend
template <>
dt_qint4 megdnn::convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset) {
    src <<= (4 - (offset << 2));
    src >>= 4;
    dst = dt_qint4(src);
    return dst;
}
  294. /* ======================== CpuNDRange ======================== */
  295. std::string CpuNDRange::to_string() const {
  296. std::string ret;
  297. for (size_t i = 0; i < m_dimension; i++) {
  298. ret += megdnn::ssprintf(" %zu", m_dim[i]);
  299. }
  300. return ret;
  301. }
//! bounds-checked mutable access to the idx-th dimension extent
size_t& CpuNDRange::operator[](size_t idx) {
    megdnn_assert(
            idx < m_dimension, "invalid index: %zu expected < %zu", idx, m_dimension);
    return m_dim[idx];
}
  307. bool megdnn::check_bias_share_in_channel(
  308. const TensorLayout& bias, const param::ConvBias::Format format) {
  309. bool share_in_channel = false;
  310. if (format == param::ConvBias::Format::NCHW ||
  311. format == param::ConvBias::Format::NCHW4_NCHW) {
  312. share_in_channel =
  313. (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && bias[3] == 1);
  314. } else if (
  315. format == param::ConvBias::Format::NHWC ||
  316. format == param::ConvBias::Format::NCHW4_NHWC) {
  317. share_in_channel =
  318. (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 && bias[2] == 1);
  319. } else if (
  320. format == param::ConvBias::Format::NCHW4 ||
  321. format == param::ConvBias::Format::NCHW8 ||
  322. format == param::ConvBias::Format::NCHW32 ||
  323. format == param::ConvBias::Format::NCHW64 ||
  324. format == param::ConvBias::Format::NCHW4_NCHW32 ||
  325. format == param::ConvBias::Format::NCHW32_NCHW4) {
  326. share_in_channel =
  327. (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && bias[3] == 1);
  328. } else if (format == param::ConvBias::Format::NHWCD4) {
  329. share_in_channel =
  330. (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 && bias[3] == 1);
  331. } else {
  332. megdnn_assert(format == param::ConvBias::Format::CHWN4);
  333. share_in_channel =
  334. (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 && bias[3] == 1);
  335. }
  336. return share_in_channel;
  337. }
  338. // vim: syntax=cpp.doxygen