You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ROIAlign_cpu.cpp 15 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <ATen/TensorUtils.h>
#include <algorithm>
#include <vector>
#include "ROIAlign.h"
  4. namespace {
  5. // implementation taken from Caffe2
  6. template <typename T>
  7. struct PreCalc {
  8. int pos1;
  9. int pos2;
  10. int pos3;
  11. int pos4;
  12. T w1;
  13. T w2;
  14. T w3;
  15. T w4;
  16. };
  17. template <typename T>
  18. void pre_calc_for_bilinear_interpolate(
  19. const int height,
  20. const int width,
  21. const int pooled_height,
  22. const int pooled_width,
  23. const int iy_upper,
  24. const int ix_upper,
  25. T roi_start_h,
  26. T roi_start_w,
  27. T bin_size_h,
  28. T bin_size_w,
  29. int roi_bin_grid_h,
  30. int roi_bin_grid_w,
  31. std::vector<PreCalc<T>>& pre_calc) {
  32. int pre_calc_index = 0;
  33. for (int ph = 0; ph < pooled_height; ph++) {
  34. for (int pw = 0; pw < pooled_width; pw++) {
  35. for (int iy = 0; iy < iy_upper; iy++) {
  36. const T yy = roi_start_h + ph * bin_size_h +
  37. static_cast<T>(iy + .5f) * bin_size_h /
  38. static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
  39. for (int ix = 0; ix < ix_upper; ix++) {
  40. const T xx = roi_start_w + pw * bin_size_w +
  41. static_cast<T>(ix + .5f) * bin_size_w /
  42. static_cast<T>(roi_bin_grid_w);
  43. T x = xx;
  44. T y = yy;
  45. // deal with: inverse elements are out of feature map boundary
  46. if (y < -1.0 || y > height || x < -1.0 || x > width) {
  47. // empty
  48. PreCalc<T> pc;
  49. pc.pos1 = 0;
  50. pc.pos2 = 0;
  51. pc.pos3 = 0;
  52. pc.pos4 = 0;
  53. pc.w1 = 0;
  54. pc.w2 = 0;
  55. pc.w3 = 0;
  56. pc.w4 = 0;
  57. pre_calc[pre_calc_index] = pc;
  58. pre_calc_index += 1;
  59. continue;
  60. }
  61. if (y <= 0) {
  62. y = 0;
  63. }
  64. if (x <= 0) {
  65. x = 0;
  66. }
  67. int y_low = (int)y;
  68. int x_low = (int)x;
  69. int y_high;
  70. int x_high;
  71. if (y_low >= height - 1) {
  72. y_high = y_low = height - 1;
  73. y = (T)y_low;
  74. } else {
  75. y_high = y_low + 1;
  76. }
  77. if (x_low >= width - 1) {
  78. x_high = x_low = width - 1;
  79. x = (T)x_low;
  80. } else {
  81. x_high = x_low + 1;
  82. }
  83. T ly = y - y_low;
  84. T lx = x - x_low;
  85. T hy = 1. - ly, hx = 1. - lx;
  86. T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  87. // save weights and indices
  88. PreCalc<T> pc;
  89. pc.pos1 = y_low * width + x_low;
  90. pc.pos2 = y_low * width + x_high;
  91. pc.pos3 = y_high * width + x_low;
  92. pc.pos4 = y_high * width + x_high;
  93. pc.w1 = w1;
  94. pc.w2 = w2;
  95. pc.w3 = w3;
  96. pc.w4 = w4;
  97. pre_calc[pre_calc_index] = pc;
  98. pre_calc_index += 1;
  99. }
  100. }
  101. }
  102. }
  103. }
  104. template <typename T>
  105. void ROIAlignForward(
  106. const int nthreads,
  107. const T* input,
  108. const T& spatial_scale,
  109. const int channels,
  110. const int height,
  111. const int width,
  112. const int pooled_height,
  113. const int pooled_width,
  114. const int sampling_ratio,
  115. const T* rois,
  116. T* output,
  117. bool aligned) {
  118. int n_rois = nthreads / channels / pooled_width / pooled_height;
  119. // (n, c, ph, pw) is an element in the pooled output
  120. // can be parallelized using omp
  121. // #pragma omp parallel for num_threads(32)
  122. for (int n = 0; n < n_rois; n++) {
  123. int index_n = n * channels * pooled_width * pooled_height;
  124. const T* offset_rois = rois + n * 5;
  125. int roi_batch_ind = offset_rois[0];
  126. // Do not use rounding; this implementation detail is critical
  127. T offset = aligned ? (T)0.5 : (T)0.0;
  128. T roi_start_w = offset_rois[1] * spatial_scale - offset;
  129. T roi_start_h = offset_rois[2] * spatial_scale - offset;
  130. T roi_end_w = offset_rois[3] * spatial_scale - offset;
  131. T roi_end_h = offset_rois[4] * spatial_scale - offset;
  132. T roi_width = roi_end_w - roi_start_w;
  133. T roi_height = roi_end_h - roi_start_h;
  134. if (aligned) {
  135. AT_ASSERTM(
  136. roi_width >= 0 && roi_height >= 0,
  137. "ROIs in ROIAlign cannot have non-negative size!");
  138. } else { // for backward-compatibility only
  139. roi_width = std::max(roi_width, (T)1.);
  140. roi_height = std::max(roi_height, (T)1.);
  141. }
  142. T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
  143. T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
  144. // We use roi_bin_grid to sample the grid and mimic integral
  145. int roi_bin_grid_h = (sampling_ratio > 0)
  146. ? sampling_ratio
  147. : ceil(roi_height / pooled_height); // e.g., = 2
  148. int roi_bin_grid_w =
  149. (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
  150. // We do average (integral) pooling inside a bin
  151. // When the grid is empty, output zeros.
  152. const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
  153. // we want to precalculate indices and weights shared by all channels,
  154. // this is the key point of optimization
  155. std::vector<PreCalc<T>> pre_calc(
  156. roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
  157. pre_calc_for_bilinear_interpolate(
  158. height,
  159. width,
  160. pooled_height,
  161. pooled_width,
  162. roi_bin_grid_h,
  163. roi_bin_grid_w,
  164. roi_start_h,
  165. roi_start_w,
  166. bin_size_h,
  167. bin_size_w,
  168. roi_bin_grid_h,
  169. roi_bin_grid_w,
  170. pre_calc);
  171. for (int c = 0; c < channels; c++) {
  172. int index_n_c = index_n + c * pooled_width * pooled_height;
  173. const T* offset_input =
  174. input + (roi_batch_ind * channels + c) * height * width;
  175. int pre_calc_index = 0;
  176. for (int ph = 0; ph < pooled_height; ph++) {
  177. for (int pw = 0; pw < pooled_width; pw++) {
  178. int index = index_n_c + ph * pooled_width + pw;
  179. T output_val = 0.;
  180. for (int iy = 0; iy < roi_bin_grid_h; iy++) {
  181. for (int ix = 0; ix < roi_bin_grid_w; ix++) {
  182. PreCalc<T> pc = pre_calc[pre_calc_index];
  183. output_val += pc.w1 * offset_input[pc.pos1] +
  184. pc.w2 * offset_input[pc.pos2] +
  185. pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
  186. pre_calc_index += 1;
  187. }
  188. }
  189. output_val /= count;
  190. output[index] = output_val;
  191. } // for pw
  192. } // for ph
  193. } // for c
  194. } // for n
  195. }
  196. template <typename T>
  197. void bilinear_interpolate_gradient(
  198. const int height,
  199. const int width,
  200. T y,
  201. T x,
  202. T& w1,
  203. T& w2,
  204. T& w3,
  205. T& w4,
  206. int& x_low,
  207. int& x_high,
  208. int& y_low,
  209. int& y_high,
  210. const int index /* index for debug only*/) {
  211. // deal with cases that inverse elements are out of feature map boundary
  212. if (y < -1.0 || y > height || x < -1.0 || x > width) {
  213. // empty
  214. w1 = w2 = w3 = w4 = 0.;
  215. x_low = x_high = y_low = y_high = -1;
  216. return;
  217. }
  218. if (y <= 0)
  219. y = 0;
  220. if (x <= 0)
  221. x = 0;
  222. y_low = (int)y;
  223. x_low = (int)x;
  224. if (y_low >= height - 1) {
  225. y_high = y_low = height - 1;
  226. y = (T)y_low;
  227. } else {
  228. y_high = y_low + 1;
  229. }
  230. if (x_low >= width - 1) {
  231. x_high = x_low = width - 1;
  232. x = (T)x_low;
  233. } else {
  234. x_high = x_low + 1;
  235. }
  236. T ly = y - y_low;
  237. T lx = x - x_low;
  238. T hy = 1. - ly, hx = 1. - lx;
  239. // reference in forward
  240. // T v1 = input[y_low * width + x_low];
  241. // T v2 = input[y_low * width + x_high];
  242. // T v3 = input[y_high * width + x_low];
  243. // T v4 = input[y_high * width + x_high];
  244. // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  245. w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  246. return;
  247. }
  248. template <class T>
  249. inline void add(T* address, const T& val) {
  250. *address += val;
  251. }
  252. template <typename T>
  253. void ROIAlignBackward(
  254. const int nthreads,
  255. const T* grad_output,
  256. const T& spatial_scale,
  257. const int channels,
  258. const int height,
  259. const int width,
  260. const int pooled_height,
  261. const int pooled_width,
  262. const int sampling_ratio,
  263. T* grad_input,
  264. const T* rois,
  265. const int n_stride,
  266. const int c_stride,
  267. const int h_stride,
  268. const int w_stride,
  269. bool aligned) {
  270. for (int index = 0; index < nthreads; index++) {
  271. // (n, c, ph, pw) is an element in the pooled output
  272. int pw = index % pooled_width;
  273. int ph = (index / pooled_width) % pooled_height;
  274. int c = (index / pooled_width / pooled_height) % channels;
  275. int n = index / pooled_width / pooled_height / channels;
  276. const T* offset_rois = rois + n * 5;
  277. int roi_batch_ind = offset_rois[0];
  278. // Do not use rounding; this implementation detail is critical
  279. T offset = aligned ? (T)0.5 : (T)0.0;
  280. T roi_start_w = offset_rois[1] * spatial_scale - offset;
  281. T roi_start_h = offset_rois[2] * spatial_scale - offset;
  282. T roi_end_w = offset_rois[3] * spatial_scale - offset;
  283. T roi_end_h = offset_rois[4] * spatial_scale - offset;
  284. T roi_width = roi_end_w - roi_start_w;
  285. T roi_height = roi_end_h - roi_start_h;
  286. if (aligned) {
  287. AT_ASSERTM(
  288. roi_width >= 0 && roi_height >= 0,
  289. "ROIs in ROIAlign do not have non-negative size!");
  290. } else { // for backward-compatibility only
  291. roi_width = std::max(roi_width, (T)1.);
  292. roi_height = std::max(roi_height, (T)1.);
  293. }
  294. T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
  295. T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
  296. T* offset_grad_input =
  297. grad_input + ((roi_batch_ind * channels + c) * height * width);
  298. int output_offset = n * n_stride + c * c_stride;
  299. const T* offset_grad_output = grad_output + output_offset;
  300. const T grad_output_this_bin =
  301. offset_grad_output[ph * h_stride + pw * w_stride];
  302. // We use roi_bin_grid to sample the grid and mimic integral
  303. int roi_bin_grid_h = (sampling_ratio > 0)
  304. ? sampling_ratio
  305. : ceil(roi_height / pooled_height); // e.g., = 2
  306. int roi_bin_grid_w =
  307. (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
  308. // We do average (integral) pooling inside a bin
  309. const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
  310. for (int iy = 0; iy < roi_bin_grid_h; iy++) {
  311. const T y = roi_start_h + ph * bin_size_h +
  312. static_cast<T>(iy + .5f) * bin_size_h /
  313. static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
  314. for (int ix = 0; ix < roi_bin_grid_w; ix++) {
  315. const T x = roi_start_w + pw * bin_size_w +
  316. static_cast<T>(ix + .5f) * bin_size_w /
  317. static_cast<T>(roi_bin_grid_w);
  318. T w1, w2, w3, w4;
  319. int x_low, x_high, y_low, y_high;
  320. bilinear_interpolate_gradient(
  321. height,
  322. width,
  323. y,
  324. x,
  325. w1,
  326. w2,
  327. w3,
  328. w4,
  329. x_low,
  330. x_high,
  331. y_low,
  332. y_high,
  333. index);
  334. T g1 = grad_output_this_bin * w1 / count;
  335. T g2 = grad_output_this_bin * w2 / count;
  336. T g3 = grad_output_this_bin * w3 / count;
  337. T g4 = grad_output_this_bin * w4 / count;
  338. if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
  339. // atomic add is not needed for now since it is single threaded
  340. add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
  341. add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
  342. add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
  343. add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
  344. } // if
  345. } // ix
  346. } // iy
  347. } // for
  348. } // ROIAlignBackward
  349. } // namespace
  350. namespace detectron2 {
  351. at::Tensor ROIAlign_forward_cpu(
  352. const at::Tensor& input,
  353. const at::Tensor& rois,
  354. const float spatial_scale,
  355. const int pooled_height,
  356. const int pooled_width,
  357. const int sampling_ratio,
  358. bool aligned) {
  359. AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
  360. AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
  361. at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
  362. at::CheckedFrom c = "ROIAlign_forward_cpu";
  363. at::checkAllSameType(c, {input_t, rois_t});
  364. auto num_rois = rois.size(0);
  365. auto channels = input.size(1);
  366. auto height = input.size(2);
  367. auto width = input.size(3);
  368. at::Tensor output = at::zeros(
  369. {num_rois, channels, pooled_height, pooled_width}, input.options());
  370. auto output_size = num_rois * pooled_height * pooled_width * channels;
  371. if (output.numel() == 0)
  372. return output;
  373. AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
  374. ROIAlignForward<scalar_t>(
  375. output_size,
  376. input.contiguous().data_ptr<scalar_t>(),
  377. spatial_scale,
  378. channels,
  379. height,
  380. width,
  381. pooled_height,
  382. pooled_width,
  383. sampling_ratio,
  384. rois.contiguous().data_ptr<scalar_t>(),
  385. output.data_ptr<scalar_t>(),
  386. aligned);
  387. });
  388. return output;
  389. }
  390. at::Tensor ROIAlign_backward_cpu(
  391. const at::Tensor& grad,
  392. const at::Tensor& rois,
  393. const float spatial_scale,
  394. const int pooled_height,
  395. const int pooled_width,
  396. const int batch_size,
  397. const int channels,
  398. const int height,
  399. const int width,
  400. const int sampling_ratio,
  401. bool aligned) {
  402. AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
  403. AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
  404. at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
  405. at::CheckedFrom c = "ROIAlign_backward_cpu";
  406. at::checkAllSameType(c, {grad_t, rois_t});
  407. at::Tensor grad_input =
  408. at::zeros({batch_size, channels, height, width}, grad.options());
  409. // handle possibly empty gradients
  410. if (grad.numel() == 0) {
  411. return grad_input;
  412. }
  413. // get stride values to ensure indexing into gradients is correct.
  414. int n_stride = grad.stride(0);
  415. int c_stride = grad.stride(1);
  416. int h_stride = grad.stride(2);
  417. int w_stride = grad.stride(3);
  418. AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
  419. ROIAlignBackward<scalar_t>(
  420. grad.numel(),
  421. grad.contiguous().data_ptr<scalar_t>(),
  422. spatial_scale,
  423. channels,
  424. height,
  425. width,
  426. pooled_height,
  427. pooled_width,
  428. sampling_ratio,
  429. grad_input.data_ptr<scalar_t>(),
  430. rois.contiguous().data_ptr<scalar_t>(),
  431. n_stride,
  432. c_stride,
  433. h_stride,
  434. w_stride,
  435. aligned);
  436. });
  437. return grad_input;
  438. }
  439. } // namespace detectron2

No Description