You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

squeeze.cpp 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "squeeze.h"
  15. namespace ncnn {
  16. Squeeze::Squeeze()
  17. {
  18. one_blob_only = true;
  19. support_inplace = false;
  20. }
  21. int Squeeze::load_param(const ParamDict& pd)
  22. {
  23. squeeze_w = pd.get(0, 0);
  24. squeeze_h = pd.get(1, 0);
  25. squeeze_d = pd.get(11, 0);
  26. squeeze_c = pd.get(2, 0);
  27. axes = pd.get(3, Mat());
  28. return 0;
  29. }
  30. int Squeeze::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
  31. {
  32. int w = bottom_blob.w;
  33. int h = bottom_blob.h;
  34. int d = bottom_blob.d;
  35. int channels = bottom_blob.c;
  36. int dims = bottom_blob.dims;
  37. bool _squeeze_w = false;
  38. bool _squeeze_h = false;
  39. bool _squeeze_d = false;
  40. bool _squeeze_c = false;
  41. if (axes.empty())
  42. {
  43. _squeeze_w = w == 1 && squeeze_w;
  44. _squeeze_h = h == 1 && squeeze_h;
  45. _squeeze_d = d == 1 && squeeze_d;
  46. _squeeze_c = channels == 1 && squeeze_c;
  47. }
  48. else
  49. {
  50. const int* axes_ptr = axes;
  51. for (int i = 0; i < axes.w; i++)
  52. {
  53. int axis = axes_ptr[i];
  54. if (axis < 0)
  55. axis = dims + axis;
  56. if (dims == 1 && axis == 0)
  57. {
  58. _squeeze_w = w == 1;
  59. }
  60. if (dims == 2 && axis == 0)
  61. {
  62. _squeeze_h = h == 1;
  63. }
  64. if (dims == 2 && axis == 1)
  65. {
  66. _squeeze_w = w == 1;
  67. }
  68. if (dims == 3 && axis == 0)
  69. {
  70. _squeeze_c = channels == 1;
  71. }
  72. if (dims == 3 && axis == 1)
  73. {
  74. _squeeze_h = h == 1;
  75. }
  76. if (dims == 3 && axis == 2)
  77. {
  78. _squeeze_w = w == 1;
  79. }
  80. if (dims == 4 && axis == 0)
  81. {
  82. _squeeze_c = channels == 1;
  83. }
  84. if (dims == 4 && axis == 1)
  85. {
  86. _squeeze_d = d == 1;
  87. }
  88. if (dims == 4 && axis == 2)
  89. {
  90. _squeeze_h = h == 1;
  91. }
  92. if (dims == 4 && axis == 3)
  93. {
  94. _squeeze_w = w == 1;
  95. }
  96. }
  97. }
  98. top_blob = bottom_blob;
  99. if (dims == 1)
  100. {
  101. if (_squeeze_w)
  102. {
  103. top_blob = bottom_blob.reshape(1, opt.blob_allocator);
  104. }
  105. }
  106. if (dims == 2)
  107. {
  108. if (_squeeze_w && _squeeze_h)
  109. {
  110. top_blob = bottom_blob.reshape(1, opt.blob_allocator);
  111. }
  112. else if (_squeeze_w)
  113. {
  114. top_blob = bottom_blob.reshape(h, opt.blob_allocator);
  115. }
  116. else if (_squeeze_h)
  117. {
  118. top_blob = bottom_blob.reshape(w, opt.blob_allocator);
  119. }
  120. }
  121. if (dims == 3)
  122. {
  123. if (_squeeze_w && _squeeze_h && _squeeze_c)
  124. {
  125. top_blob = bottom_blob.reshape(1, opt.blob_allocator);
  126. }
  127. else if (_squeeze_w && _squeeze_h)
  128. {
  129. top_blob = bottom_blob.reshape(channels, opt.blob_allocator);
  130. }
  131. else if (_squeeze_h && _squeeze_c)
  132. {
  133. top_blob = bottom_blob.reshape(w, opt.blob_allocator);
  134. }
  135. else if (_squeeze_w && _squeeze_c)
  136. {
  137. top_blob = bottom_blob.reshape(h, opt.blob_allocator);
  138. }
  139. else if (_squeeze_w)
  140. {
  141. top_blob = bottom_blob.reshape(h, channels, opt.blob_allocator);
  142. }
  143. else if (_squeeze_h)
  144. {
  145. top_blob = bottom_blob.reshape(w, channels, opt.blob_allocator);
  146. }
  147. else if (_squeeze_c)
  148. {
  149. top_blob = bottom_blob.reshape(w, h, opt.blob_allocator);
  150. }
  151. }
  152. if (dims == 4)
  153. {
  154. if (_squeeze_w && _squeeze_h && _squeeze_d && _squeeze_c)
  155. {
  156. top_blob = bottom_blob.reshape(1, opt.blob_allocator);
  157. }
  158. else if (_squeeze_w && _squeeze_h && _squeeze_d)
  159. {
  160. top_blob = bottom_blob.reshape(channels, opt.blob_allocator);
  161. }
  162. else if (_squeeze_h && _squeeze_d && _squeeze_c)
  163. {
  164. top_blob = bottom_blob.reshape(w, opt.blob_allocator);
  165. }
  166. else if (_squeeze_w && _squeeze_d && _squeeze_c)
  167. {
  168. top_blob = bottom_blob.reshape(h, opt.blob_allocator);
  169. }
  170. else if (_squeeze_w && _squeeze_h && _squeeze_c)
  171. {
  172. top_blob = bottom_blob.reshape(d, opt.blob_allocator);
  173. }
  174. else if (_squeeze_w && _squeeze_h)
  175. {
  176. top_blob = bottom_blob.reshape(d, channels, opt.blob_allocator);
  177. }
  178. else if (_squeeze_w && _squeeze_d)
  179. {
  180. top_blob = bottom_blob.reshape(h, channels, opt.blob_allocator);
  181. }
  182. else if (_squeeze_h && _squeeze_d)
  183. {
  184. top_blob = bottom_blob.reshape(w, channels, opt.blob_allocator);
  185. }
  186. else if (_squeeze_h && _squeeze_c)
  187. {
  188. top_blob = bottom_blob.reshape(w, d, opt.blob_allocator);
  189. }
  190. else if (_squeeze_w && _squeeze_c)
  191. {
  192. top_blob = bottom_blob.reshape(h, d, opt.blob_allocator);
  193. }
  194. else if (_squeeze_d && _squeeze_c)
  195. {
  196. top_blob = bottom_blob.reshape(w, h, opt.blob_allocator);
  197. }
  198. else if (_squeeze_w)
  199. {
  200. top_blob = bottom_blob.reshape(h, d, channels, opt.blob_allocator);
  201. }
  202. else if (_squeeze_h)
  203. {
  204. top_blob = bottom_blob.reshape(w, d, channels, opt.blob_allocator);
  205. }
  206. else if (_squeeze_d)
  207. {
  208. top_blob = bottom_blob.reshape(w, h, channels, opt.blob_allocator);
  209. }
  210. else if (_squeeze_c)
  211. {
  212. top_blob = bottom_blob.reshape(w, h, d, opt.blob_allocator);
  213. }
  214. }
  215. if (top_blob.empty())
  216. return -100;
  217. return 0;
  218. }
  219. } // namespace ncnn