
tile_repeat.cpp 6.0 kB

/**
 * \file dnn/src/common/tile_repeat.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include <numeric>

namespace megdnn {
void TileRepeatBase::check_layout_fwd(
        const TensorLayout& src, const TensorLayout& dst) {
    auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
                  "times=" + param().times.to_string();
    auto errmsg_c = errmsg.c_str();
    MEGDNN_MARK_USED_VAR(errmsg_c);
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(dst);
    auto expected_ndim = param().times.ndim;
    megdnn_assert(expected_ndim == src.ndim, "%s", errmsg_c);
    megdnn_assert(expected_ndim == dst.ndim, "%s", errmsg_c);
    rep(i, expected_ndim) {
        megdnn_assert(
                dst.shape[i] == param().times[i] * src.shape[i], "%s", errmsg_c);
    }
    megdnn_assert(src.dtype == dst.dtype);
}
void TileRepeatBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
    dst.ndim = src.ndim;
    rep(i, src.ndim) {
        dst.shape[i] = src.shape[i] * param().times[i];
    }
    dst.dtype = src.dtype;
    dst.init_contiguous_stride();
    check_layout_fwd(src, dst);
}
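/* Workspace policy: each pass handles one non-trivial axis (times[i] != 1).
 * A single pass can write straight into dst, two passes need one
 * intermediate buffer, and three or more passes ping-pong between two
 * buffers, so two workspaces always suffice. */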
size_t TileRepeatBase::get_workspace_in_bytes_fwd(
        const TensorShape& /* src */, const TensorShape& dst,
        const TensorShape& times, DType dtype) {
    size_t nr_workspace = 0;
    auto nr_reduces = count_not_ones_in_shape(times);
    if (nr_reduces == 0) {
        // case 1: no tile/repeat is needed, let alone workspace.
        nr_workspace = 0;
    } else if (nr_reduces == 1) {
        // case 2: only one tile/repeat is needed, so we don't need workspace.
        nr_workspace = 0;
    } else if (nr_reduces == 2) {
        // case 3: two tile/repeats are needed, so we need a single workspace.
        nr_workspace = 1;
    } else {
        // case 4: multiple tile/repeats are needed, so we need two workspaces
        // used in an alternate fashion.
        nr_workspace = 2;
    }
    if (nr_workspace == 0) {
        return 0;
    } else {
        WorkspaceBundle workspaces{
                nullptr, {nr_workspace, dst.total_nr_elems() * dtype.size()}};
        return workspaces.total_size_in_bytes();
    }
}
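/* Tile collapses every axis whose tile factor is 1 into the preceding axis:
 * tiling copies whole trailing blocks, so untiled inner axes can be folded
 * into their predecessor. E.g. src=(2, 3, 4) with times=(1, 2, 1) simplifies
 * to src2=(2, 12), times2=(1, 2). */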
void TileBase::simplify_shape(
        const TensorShape& src, const TensorShape& dst, const TensorShape& times,
        TensorShape& src2, TensorShape& dst2, TensorShape& times2) {
    size_t n = 0;
    for (size_t i = 0; i < src.ndim; ++i) {
        if (times.shape[i] == 1 && n > 0) {
            src2.shape[n - 1] *= src.shape[i];
            dst2.shape[n - 1] *= dst.shape[i];
        } else {
            src2.shape[n] = src.shape[i];
            dst2.shape[n] = dst.shape[i];
            times2.shape[n] = times.shape[i];
            ++n;
        }
    }
    src2.ndim = dst2.ndim = times2.ndim = n;
}
size_t TileBase::get_workspace_in_bytes_fwd(
        const TensorLayout& src_, const TensorLayout& dst_) {
    TensorShape src, dst, times;
    simplify_shape(src_, dst_, param().times, src, dst, times);
    return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype);
}

void TileForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    deduce_layout_fwd(src, dst);
}

void TileForward::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
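/* In the backward pass the roles are swapped: grad has the forward src layout
 * and diff has the forward dst layout, so the forward checker is reused with
 * (grad, diff). The same convention applies to RepeatBackward below. */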
void TileBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
    check_layout_fwd(grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
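/* Repeat groups each run of leading axes with times == 1 together with the
 * first following repeated axis: repeating duplicates elements contiguously,
 * so unrepeated outer axes can be folded into the repeated one. E.g.
 * src=(2, 3, 4) with times=(1, 2, 1) simplifies to src2=(6, 4),
 * times2=(2, 1). */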
void RepeatBase::simplify_shape(
        const TensorShape& src, const TensorShape& /* dst */, const TensorShape& times,
        TensorShape& src2, TensorShape& dst2, TensorShape& times2) {
    auto n = 0u;
    size_t i = 0;
    while (i < times.ndim) {
        size_t j = i;
        while (j < times.ndim && times.shape[j] == 1) ++j;
        // here j == times.ndim, or times.shape[j] != 1
        if (j < times.ndim) ++j;
        src2.shape[n] = std::accumulate(
                src.shape + i, src.shape + j, 1_z, SafeMultiplies<size_t>());
        times2.shape[n] = times.shape[j - 1];
        dst2.shape[n] = src2.shape[n] * times2.shape[n];
        ++n;
        i = j;
    }
    src2.ndim = dst2.ndim = times2.ndim = n;
}
size_t RepeatBase::get_workspace_in_bytes_fwd(
        const TensorLayout& src_, const TensorLayout& dst_) {
    TensorShape src, dst, times;
    simplify_shape(src_, dst_, param().times, src, dst, times);
    return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype);
}

void RepeatForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    deduce_layout_fwd(src, dst);
}

void RepeatForward::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void RepeatBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
    check_layout_fwd(grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen
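To make the axis-coalescing step concrete, below is a minimal standalone sketch (not part of the MegDNN sources) that mirrors the logic of TileBase::simplify_shape using plain std::vector, so it compiles without the MegDNN headers; the simplify_tile helper and the shapes are illustrative only.

// Standalone sketch (not MegDNN code): fold every axis with tile factor 1
// into the preceding simplified axis, as TileBase::simplify_shape does.
#include <cstdio>
#include <vector>

static void simplify_tile(const std::vector<size_t>& src,
                          const std::vector<size_t>& times,
                          std::vector<size_t>& src2,
                          std::vector<size_t>& times2) {
    for (size_t i = 0; i < src.size(); ++i) {
        if (times[i] == 1 && !src2.empty()) {
            // untiled axis: merge it into the previous simplified axis
            src2.back() *= src[i];
        } else {
            src2.push_back(src[i]);
            times2.push_back(times[i]);
        }
    }
}

int main() {
    std::vector<size_t> src{2, 3, 4}, times{1, 2, 1}, src2, times2;
    simplify_tile(src, times, src2, times2);
    // expected: axis 0: src=2 times=1, axis 1: src=12 times=2
    for (size_t i = 0; i < src2.size(); ++i)
        std::printf("axis %zu: src=%zu times=%zu\n", i, src2[i], times2[i]);
    return 0;
}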

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.