
named_tensor.cpp 9.6 kB

/**
 * \file dnn/src/common/named_tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/named_tensor.h"
#include "src/common/utils.h"

using namespace megdnn;

/* ===================== Dimension ============================ */
const Dimension::Name Dimension::NAME_ALL[] = {
        Dimension::Name::N, Dimension::Name::C, Dimension::Name::H, Dimension::Name::W,
        Dimension::Name::G, Dimension::Name::K, Dimension::Name::R, Dimension::Name::S,
        Dimension::Name::P, Dimension::Name::Q,
};
const int Dimension::NR_NAMES = sizeof(Dimension::NAME_ALL) / sizeof(Dimension::NAME_ALL[0]);
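
// Illustrative note (added for clarity, not in the original source): a
// dimension expression has the form NAME ['//' STRIDE] ['%' EXTENT], e.g.
//     "C"      -> name C, stride 1, extent undetermined
//     "C//4"   -> name C, stride 4, extent undetermined
//     "C//4%4" -> name C, stride 4, extent 4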
Dimension::Dimension(const std::string& expr) {
    auto errmsg = [&]() { return ssprintf("Invalid dimension(%s)", expr.c_str()); };
    const char* data = expr.data();
    bool has_stride = false;
    bool has_extent = false;
    bool init_name = false;
    while (*data) {
        if (data[0] >= 'A' && data[0] <= 'Z') {
            megdnn_throw_if(init_name, megdnn_error, errmsg().c_str());
            for (auto e : NAME_ALL) {
                if (data[0] == static_cast<char>(e)) {
                    init_name = true;
                    m_name = e;
                    break;
                }
            }
            megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
            ++data;
        } else if (data[0] == '/' && data[1] == '/') {
            megdnn_throw_if(
                    !init_name || has_stride || has_extent, megdnn_error,
                    errmsg().c_str());
            has_stride = true;
            data += 2;
        } else if (data[0] == '%') {
            megdnn_throw_if(!init_name || has_extent, megdnn_error, errmsg().c_str());
            has_extent = true;
            ++data;
        } else if (data[0] >= '0' && data[0] <= '9') {
            megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
            uint32_t num = 0;
            while (data[0] >= '0' && data[0] <= '9') {
                num = num * 10 + (data[0] - '0');
                ++data;
            }
            if (has_extent)
                m_extent = num;
            else if (has_stride)
                m_stride = num;
        } else {
            megdnn_throw(errmsg().c_str());
        }
    }
    megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
    if (!has_extent) {
        m_extent = UNDETERMINED_EXTENT;
    }
    if (!has_stride) {
        m_stride = 1;
    }
}
Dimension& Dimension::operator=(const Dimension& rhs) {
    m_name = rhs.m_name;
    m_stride = rhs.m_stride;
    m_extent = rhs.m_extent;
    return *this;
}
bool Dimension::operator==(const Dimension& rhs) const {
    return m_name == rhs.m_name && m_stride == rhs.m_stride && m_extent == rhs.m_extent;
}
bool Dimension::operator<(const Dimension& rhs) const {
    if (m_name != rhs.m_name) {
        return static_cast<char>(m_name) < static_cast<char>(rhs.m_name);
    }
    if (m_stride == rhs.m_stride) {
        return m_extent > rhs.m_extent;
    }
    return m_stride > rhs.m_stride;
}
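
// Illustrative note (added for clarity, not in the original source): the
// multiplication below merges an inner block back into its enclosing axis,
// e.g. Dimension("C//4") * Dimension("C%4") == Dimension("C") and
// Dimension("C//4%4") * Dimension("C%4") == Dimension("C%16").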
Dimension Dimension::operator*(const Dimension& rhs) const {
    megdnn_assert(
            m_name == rhs.m_name,
            "Multiply operation cannot be applied on dimensions with "
            "different name(lhs:%c, rhs:%c)",
            static_cast<char>(m_name), static_cast<char>(rhs.m_name));
    megdnn_assert(
            m_stride == rhs.m_stride * rhs.m_extent,
            "Multiply operation cannot be applied on operands(lhs:%s, rhs:%s)",
            to_string().c_str(), rhs.to_string().c_str());
    if (m_extent == UNDETERMINED_EXTENT)
        return Dimension(m_name, rhs.m_stride);
    return Dimension(m_name, rhs.m_stride, m_extent * rhs.m_extent);
}
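
// Illustrative note (added for clarity, not in the original source): the
// division below splits an axis into the factor complementary to the divisor,
// e.g. Dimension("C") / Dimension("C%4") == Dimension("C//4") and
// Dimension("C") / Dimension("C//4") == Dimension("C%4").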
Dimension Dimension::operator/(const Dimension& rhs) const {
    megdnn_assert(
            m_name == rhs.m_name,
            "Divide operation cannot be applied on dimensions with "
            "different name(lhs:%c, rhs:%c)",
            static_cast<char>(m_name), static_cast<char>(rhs.m_name));
    if (operator==(rhs))
        return Dimension(m_name, 1, 1);
    if (m_stride == rhs.m_stride) {
        if (m_extent == UNDETERMINED_EXTENT) {
            megdnn_assert(
                    rhs.m_extent != UNDETERMINED_EXTENT,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, rhs.m_extent * m_stride);
        } else {
            megdnn_assert(
                    m_extent % rhs.m_extent == 0,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, rhs.m_extent * m_stride, m_extent / rhs.m_extent);
        }
    } else {
        if (m_extent == UNDETERMINED_EXTENT) {
            megdnn_assert(
                    rhs.m_extent == UNDETERMINED_EXTENT && rhs.m_stride % m_stride == 0,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, m_stride, rhs.m_stride / m_stride);
        } else {
            megdnn_assert(
                    m_extent * m_stride == rhs.m_extent * rhs.m_stride &&
                            rhs.m_stride % m_stride == 0,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, m_stride, m_extent / rhs.m_extent);
        }
    }
}
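
// Illustrative note (added for clarity, not in the original source):
// to_string() round-trips with the parsing constructor, e.g.
// Dimension("C//4%4").to_string() yields "C//4%4".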
std::string Dimension::to_string() const {
    if (m_extent == UNDETERMINED_EXTENT) {
        if (m_stride == 1)
            return ssprintf("%c", static_cast<char>(m_name));
        else
            return ssprintf("%c//%u", static_cast<char>(m_name), m_stride);
    } else {
        if (m_stride == 1)
            return ssprintf("%c%%%u", static_cast<char>(m_name), m_extent);
        else
            return ssprintf(
                    "%c//%u%%%u", static_cast<char>(m_name), m_stride, m_extent);
    }
}
/* ===================== NamedTensorShape ===================== */
NamedTensorShape::NamedTensorShape(const SmallVector<Dimension>& init_shape) {
    megdnn_assert(
            init_shape.size() <= MAX_NDIM,
            "Illegal to construct a NamedTensorShape with "
            "more than MAX_NDIM(%zu) axes; got(%zu)",
            MAX_NDIM, init_shape.size());
    ndim = init_shape.size();
    memcpy(this->dims.data(), init_shape.data(), sizeof(Dimension) * ndim);
}
NamedTensorShape::NamedTensorShape(std::initializer_list<Dimension> init_shape)
        : NamedTensorShape(SmallVector<Dimension>{init_shape}) {}
bool NamedTensorShape::eq_shape(const NamedTensorShape& rhs) const {
    MEGDNN_STATIC_ASSERT(MAX_NDIM == 7, "please update the code");
    if (ndim == rhs.ndim) {
        size_t eq = 0;
        switch (ndim) {
            case 7:
                eq += dims[6] == rhs.dims[6];
                MEGDNN_FALLTHRU
            case 6:
                eq += dims[5] == rhs.dims[5];
                MEGDNN_FALLTHRU
            case 5:
                eq += dims[4] == rhs.dims[4];
                MEGDNN_FALLTHRU
            case 4:
                eq += dims[3] == rhs.dims[3];
                MEGDNN_FALLTHRU
            case 3:
                eq += dims[2] == rhs.dims[2];
                MEGDNN_FALLTHRU
            case 2:
                eq += dims[1] == rhs.dims[1];
                MEGDNN_FALLTHRU
            case 1:
                eq += dims[0] == rhs.dims[0];
        }
        return eq == ndim;
    }
    return false;
}
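
// Illustrative note (added for clarity, not in the original source): e.g. the
// NCHW4 shape produced by make_named_tensor_shape below prints as
// "{N,C//4,H,W,C%4}".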
std::string NamedTensorShape::to_string() const {
    std::string rst("{");
    for (size_t i = 0; i < ndim; i++) {
        if (i)
            rst.append(",");
        rst.append(dims[i].to_string());
    }
    rst.append("}");
    return rst;
}
NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
    switch (format) {
        case Format::NCHW:
            return {{"N"}, {"C"}, {"H"}, {"W"}};
        case Format::NHWC:
            return {{"N"}, {"H"}, {"W"}, {"C"}};
        case Format::NCHW4:
            return {{"N"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}};
        case Format::NCHW8:
            return {{"N"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}};
        case Format::NCHW32:
            return {{"N"}, {"C//32"}, {"H"}, {"W"}, {"C%32"}};
        case Format::NCHW64:
            return {{"N"}, {"C//64"}, {"H"}, {"W"}, {"C%64"}};
        case Format::NCHW44:
            return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}, {"N%4"}};
        case Format::NCHW88:
            return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}};
        case Format::NCHW44_DOT:
            return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}};
        case Format::NHWCD4:
            return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
        default:
            megdnn_throw(ssprintf("unimplemented Format(%d)", static_cast<int>(format))
                                 .c_str());
    }
}
// vim: syntax=cpp.doxygen
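
A minimal usage sketch (not part of the file above). It relies only on the interfaces visible in this listing: Dimension's string constructor, operator* and operator/, to_string(), and NamedTensorShape::make_named_tensor_shape. The main() harness, the <iostream> include, and the qualification NamedTensorShape::Format are illustrative assumptions about how megdnn/named_tensor.h exposes these names; adjust them to match the header.

// usage_sketch.cpp -- illustrative only, not part of dnn/src/common/named_tensor.cpp
#include <iostream>

#include "megdnn/named_tensor.h"

using namespace megdnn;

int main() {
    // Parse dimension expressions of the form NAME ['//' STRIDE] ['%' EXTENT].
    Dimension c("C"), c_outer("C//4"), c_inner("C%4");

    // Splitting and merging an axis are complementary operations.
    std::cout << (c / c_inner).to_string() << std::endl;        // prints "C//4"
    std::cout << (c_outer * c_inner).to_string() << std::endl;  // prints "C"

    // Build the named shape of a blocked layout and print it.
    auto nchw4 = NamedTensorShape::make_named_tensor_shape(
            NamedTensorShape::Format::NCHW4);
    std::cout << nchw4.to_string() << std::endl;  // prints "{N,C//4,H,W,C%4}"
    return 0;
}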