You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ROIAlignRotated_cpu.cpp 16 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. #include <ATen/TensorUtils.h>
  3. #include "ROIAlignRotated.h"
  4. // Note: this implementation originates from the Caffe2 ROIAlignRotated Op
  5. // and PyTorch ROIAlign (non-rotated) Op implementations.
  6. // The key difference between this implementation and those ones is
  7. // we don't do "legacy offset" in this version, as there aren't many previous
  8. // works, if any, using the "legacy" ROIAlignRotated Op.
  9. // This would make the interface a bit cleaner.
  10. namespace detectron2 {
  11. namespace {
  // Precomputed bilinear-interpolation record for one sampling point.
  // pos1..pos4 are flattened (y * width + x) offsets of the four neighbouring
  // pixels inside a single H*W feature plane; w1..w4 are the matching weights:
  //   pos1/w1 -> (y_low,  x_low)     pos2/w2 -> (y_low,  x_high)
  //   pos3/w3 -> (y_high, x_low)     pos4/w4 -> (y_high, x_high)
  // All fields default to zero, i.e. a sample that contributes nothing.
  template <typename T>
  struct PreCalc {
    int pos1{};
    int pos2{};
    int pos3{};
    int pos4{};
    T w1{};
    T w2{};
    T w3{};
    T w4{};
  };
  23. template <typename T>
  24. void pre_calc_for_bilinear_interpolate(
  25. const int height,
  26. const int width,
  27. const int pooled_height,
  28. const int pooled_width,
  29. const int iy_upper,
  30. const int ix_upper,
  31. T roi_start_h,
  32. T roi_start_w,
  33. T bin_size_h,
  34. T bin_size_w,
  35. int roi_bin_grid_h,
  36. int roi_bin_grid_w,
  37. T roi_center_h,
  38. T roi_center_w,
  39. T cos_theta,
  40. T sin_theta,
  41. std::vector<PreCalc<T>>& pre_calc) {
  42. int pre_calc_index = 0;
  43. for (int ph = 0; ph < pooled_height; ph++) {
  44. for (int pw = 0; pw < pooled_width; pw++) {
  45. for (int iy = 0; iy < iy_upper; iy++) {
  46. const T yy = roi_start_h + ph * bin_size_h +
  47. static_cast<T>(iy + .5f) * bin_size_h /
  48. static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
  49. for (int ix = 0; ix < ix_upper; ix++) {
  50. const T xx = roi_start_w + pw * bin_size_w +
  51. static_cast<T>(ix + .5f) * bin_size_w /
  52. static_cast<T>(roi_bin_grid_w);
  53. // Rotate by theta around the center and translate
  54. // In image space, (y, x) is the order for Right Handed System,
  55. // and this is essentially multiplying the point by a rotation matrix
  56. // to rotate it counterclockwise through angle theta.
  57. T y = yy * cos_theta - xx * sin_theta + roi_center_h;
  58. T x = yy * sin_theta + xx * cos_theta + roi_center_w;
  59. // deal with: inverse elements are out of feature map boundary
  60. if (y < -1.0 || y > height || x < -1.0 || x > width) {
  61. // empty
  62. PreCalc<T> pc;
  63. pc.pos1 = 0;
  64. pc.pos2 = 0;
  65. pc.pos3 = 0;
  66. pc.pos4 = 0;
  67. pc.w1 = 0;
  68. pc.w2 = 0;
  69. pc.w3 = 0;
  70. pc.w4 = 0;
  71. pre_calc[pre_calc_index] = pc;
  72. pre_calc_index += 1;
  73. continue;
  74. }
  75. if (y < 0) {
  76. y = 0;
  77. }
  78. if (x < 0) {
  79. x = 0;
  80. }
  81. int y_low = (int)y;
  82. int x_low = (int)x;
  83. int y_high;
  84. int x_high;
  85. if (y_low >= height - 1) {
  86. y_high = y_low = height - 1;
  87. y = (T)y_low;
  88. } else {
  89. y_high = y_low + 1;
  90. }
  91. if (x_low >= width - 1) {
  92. x_high = x_low = width - 1;
  93. x = (T)x_low;
  94. } else {
  95. x_high = x_low + 1;
  96. }
  97. T ly = y - y_low;
  98. T lx = x - x_low;
  99. T hy = 1. - ly, hx = 1. - lx;
  100. T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
  101. // save weights and indices
  102. PreCalc<T> pc;
  103. pc.pos1 = y_low * width + x_low;
  104. pc.pos2 = y_low * width + x_high;
  105. pc.pos3 = y_high * width + x_low;
  106. pc.pos4 = y_high * width + x_high;
  107. pc.w1 = w1;
  108. pc.w2 = w2;
  109. pc.w3 = w3;
  110. pc.w4 = w4;
  111. pre_calc[pre_calc_index] = pc;
  112. pre_calc_index += 1;
  113. }
  114. }
  115. }
  116. }
  117. }
  // Computes bilinear weights and neighbour indices for the point (y, x) on a
  // height x width map. The backward pass uses these to scatter gradients.
  //   * If y < -1, y > height, x < -1 or x > width, all weights are 0 and all
  //     four indices are -1, which tells the caller to skip the point.
  //   * Otherwise negative coordinates are clamped to 0. A coordinate at or
  //     past the last pixel collapses both neighbours onto that pixel.
  // w1..w4 pair with (y_low, x_low), (y_low, x_high), (y_high, x_low) and
  // (y_high, x_high), matching the forward-pass interpolation
  //   val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4.
  template <typename T>
  void bilinear_interpolate_gradient(
      const int height,
      const int width,
      T y,
      T x,
      T& w1,
      T& w2,
      T& w3,
      T& w4,
      int& x_low,
      int& x_high,
      int& y_low,
      int& y_high) {
    const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
    if (outside) {
      w1 = w2 = w3 = w4 = 0.;
      x_low = x_high = y_low = y_high = -1;
      return;
    }

    // Resolve one axis: clamp the coordinate into range and pick its two
    // integer neighbours. At the far edge both neighbours are the last pixel.
    auto resolve_axis = [](T& coord, const int extent, int& lo, int& hi) {
      if (coord < 0) {
        coord = 0;
      }
      lo = static_cast<int>(coord);
      if (lo < extent - 1) {
        hi = lo + 1;
      } else {
        hi = lo = extent - 1;
        coord = static_cast<T>(lo);
      }
    };
    resolve_axis(y, height, y_low, y_high);
    resolve_axis(x, width, x_low, x_high);

    const T ly = y - y_low;
    const T lx = x - x_low;
    const T hy = 1. - ly;
    const T hx = 1. - lx;

    w1 = hy * hx;
    w2 = hy * lx;
    w3 = ly * hx;
    w4 = ly * lx;
  }
  // Accumulates `val` into the element pointed to by `address`.
  // This is a plain, non-atomic read-modify-write. It is only safe when no
  // other thread updates the same element concurrently; the caller visible
  // here, ROIAlignRotatedBackward, runs a serial loop.
  template <class T>
  inline void add(T* address, const T& val) {
    T& target = *address;
    target = target + val;
  }
  175. } // namespace
  176. template <typename T>
  177. void ROIAlignRotatedForward(
  178. const int nthreads,
  179. const T* input,
  180. const T& spatial_scale,
  181. const int channels,
  182. const int height,
  183. const int width,
  184. const int pooled_height,
  185. const int pooled_width,
  186. const int sampling_ratio,
  187. const T* rois,
  188. T* output) {
  189. int n_rois = nthreads / channels / pooled_width / pooled_height;
  190. // (n, c, ph, pw) is an element in the pooled output
  191. // can be parallelized using omp
  192. // #pragma omp parallel for num_threads(32)
  193. for (int n = 0; n < n_rois; n++) {
  194. int index_n = n * channels * pooled_width * pooled_height;
  195. const T* current_roi = rois + n * 6;
  196. int roi_batch_ind = current_roi[0];
  197. // Do not use rounding; this implementation detail is critical
  198. // ROIAlignRotated supports align == true, i.e., continuous coordinate
  199. // by default, thus the 0.5 offset
  200. T offset = (T)0.5;
  201. T roi_center_w = current_roi[1] * spatial_scale - offset;
  202. T roi_center_h = current_roi[2] * spatial_scale - offset;
  203. T roi_width = current_roi[3] * spatial_scale;
  204. T roi_height = current_roi[4] * spatial_scale;
  205. T theta = current_roi[5] * M_PI / 180.0;
  206. T cos_theta = cos(theta);
  207. T sin_theta = sin(theta);
  208. AT_ASSERTM(
  209. roi_width >= 0 && roi_height >= 0,
  210. "ROIs in ROIAlignRotated do not have non-negative size!");
  211. T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
  212. T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
  213. // We use roi_bin_grid to sample the grid and mimic integral
  214. int roi_bin_grid_h = (sampling_ratio > 0)
  215. ? sampling_ratio
  216. : ceil(roi_height / pooled_height); // e.g., = 2
  217. int roi_bin_grid_w =
  218. (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
  219. // We do average (integral) pooling inside a bin
  220. const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
  221. // we want to precalculate indices and weights shared by all channels,
  222. // this is the key point of optimization
  223. std::vector<PreCalc<T>> pre_calc(
  224. roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
  225. // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
  226. // Appropriate translation needs to be applied after.
  227. T roi_start_h = -roi_height / 2.0;
  228. T roi_start_w = -roi_width / 2.0;
  229. pre_calc_for_bilinear_interpolate(
  230. height,
  231. width,
  232. pooled_height,
  233. pooled_width,
  234. roi_bin_grid_h,
  235. roi_bin_grid_w,
  236. roi_start_h,
  237. roi_start_w,
  238. bin_size_h,
  239. bin_size_w,
  240. roi_bin_grid_h,
  241. roi_bin_grid_w,
  242. roi_center_h,
  243. roi_center_w,
  244. cos_theta,
  245. sin_theta,
  246. pre_calc);
  247. for (int c = 0; c < channels; c++) {
  248. int index_n_c = index_n + c * pooled_width * pooled_height;
  249. const T* offset_input =
  250. input + (roi_batch_ind * channels + c) * height * width;
  251. int pre_calc_index = 0;
  252. for (int ph = 0; ph < pooled_height; ph++) {
  253. for (int pw = 0; pw < pooled_width; pw++) {
  254. int index = index_n_c + ph * pooled_width + pw;
  255. T output_val = 0.;
  256. for (int iy = 0; iy < roi_bin_grid_h; iy++) {
  257. for (int ix = 0; ix < roi_bin_grid_w; ix++) {
  258. PreCalc<T> pc = pre_calc[pre_calc_index];
  259. output_val += pc.w1 * offset_input[pc.pos1] +
  260. pc.w2 * offset_input[pc.pos2] +
  261. pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
  262. pre_calc_index += 1;
  263. }
  264. }
  265. output_val /= count;
  266. output[index] = output_val;
  267. } // for pw
  268. } // for ph
  269. } // for c
  270. } // for n
  271. }
  // CPU backward pass of rotated RoI Align.
  //
  // Iterates over all nthreads pooled-output elements. The flat index is
  // decoded as (n, c, ph, pw) with pw fastest. grad_output is read through
  // the given element strides (n_stride, c_stride, h_stride, w_stride), so it
  // need not be contiguous. Each bin's gradient is split evenly over its
  // roi_bin_grid_h * roi_bin_grid_w samples and scattered, bilinearly
  // weighted, into `grad_input`. grad_input is treated as a dense
  // (N, C, H, W) buffer that the caller is expected to have zeroed. The RoI
  // layout and coordinate conventions are the same as in the forward pass:
  // 6 values per RoI (batch index, center x/y, width, height, angle in
  // degrees).
  template <typename T>
  void ROIAlignRotatedBackward(
      const int nthreads,
      const T* grad_output,
      const T& spatial_scale,
      const int channels,
      const int height,
      const int width,
      const int pooled_height,
      const int pooled_width,
      const int sampling_ratio,
      T* grad_input,
      const T* rois,
      const int n_stride,
      const int c_stride,
      const int h_stride,
      const int w_stride) {
    for (int index = 0; index < nthreads; index++) {
      const int pw = index % pooled_width;
      const int ph = (index / pooled_width) % pooled_height;
      const int c = (index / pooled_width / pooled_height) % channels;
      const int n = index / pooled_width / pooled_height / channels;

      const T* roi = rois + n * 6;
      const int roi_batch_ind = roi[0];

      // Aligned (continuous) coordinates: half-pixel shift, no rounding.
      const T half_pixel = static_cast<T>(0.5);
      const T roi_center_w = roi[1] * spatial_scale - half_pixel;
      const T roi_center_h = roi[2] * spatial_scale - half_pixel;
      const T roi_width = roi[3] * spatial_scale;
      const T roi_height = roi[4] * spatial_scale;
      const T theta = roi[5] * M_PI / 180.0;
      const T cos_theta = cos(theta);
      const T sin_theta = sin(theta);
      AT_ASSERTM(
          roi_width >= 0 && roi_height >= 0,
          "ROIs in ROIAlignRotated do not have non-negative size!");

      const T bin_size_h =
          static_cast<T>(roi_height) / static_cast<T>(pooled_height);
      const T bin_size_w =
          static_cast<T>(roi_width) / static_cast<T>(pooled_width);

      T* grad_plane =
          grad_input + ((roi_batch_ind * channels + c) * height * width);
      const T grad_output_this_bin = grad_output
          [n * n_stride + c * c_stride + ph * h_stride + pw * w_stride];

      // Samples per bin along each axis: fixed, or ceil of bin extent.
      const int roi_bin_grid_h = (sampling_ratio > 0)
          ? sampling_ratio
          : ceil(roi_height / pooled_height);
      const int roi_bin_grid_w = (sampling_ratio > 0)
          ? sampling_ratio
          : ceil(roi_width / pooled_width);

      // Bin origin relative to the RoI center (rotated/translated below).
      const T roi_start_h = -roi_height / 2.0;
      const T roi_start_w = -roi_width / 2.0;

      // Number of samples averaged in the forward pass. If it is 0, the
      // sampling loops below are empty and no division happens.
      const T count = roi_bin_grid_h * roi_bin_grid_w;

      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T yy = roi_start_h + ph * bin_size_h +
            static_cast<T>(iy + .5f) * bin_size_h /
                static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T xx = roi_start_w + pw * bin_size_w +
              static_cast<T>(ix + .5f) * bin_size_w /
                  static_cast<T>(roi_bin_grid_w);

          // Counterclockwise rotation by theta, then translation to center.
          T y = yy * cos_theta - xx * sin_theta + roi_center_h;
          T x = yy * sin_theta + xx * cos_theta + roi_center_w;

          T w1, w2, w3, w4;
          int x_low, x_high, y_low, y_high;
          bilinear_interpolate_gradient(
              height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);

          // Negative indices mark samples outside the map: nothing to scatter.
          if (x_low < 0 || x_high < 0 || y_low < 0 || y_high < 0) {
            continue;
          }

          const T g1 = grad_output_this_bin * w1 / count;
          const T g2 = grad_output_this_bin * w2 / count;
          const T g3 = grad_output_this_bin * w3 / count;
          const T g4 = grad_output_this_bin * w4 / count;

          // Serial loop, so plain (non-atomic) accumulation is sufficient.
          add(grad_plane + y_low * width + x_low, g1);
          add(grad_plane + y_low * width + x_high, g2);
          add(grad_plane + y_high * width + x_low, g3);
          add(grad_plane + y_high * width + x_high, g4);
        }
      }
    }
  } // ROIAlignRotatedBackward
  361. at::Tensor ROIAlignRotated_forward_cpu(
  362. const at::Tensor& input,
  363. const at::Tensor& rois,
  364. const float spatial_scale,
  365. const int pooled_height,
  366. const int pooled_width,
  367. const int sampling_ratio) {
  368. AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
  369. AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
  370. at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
  371. at::CheckedFrom c = "ROIAlign_forward_cpu";
  372. at::checkAllSameType(c, {input_t, rois_t});
  373. auto num_rois = rois.size(0);
  374. auto channels = input.size(1);
  375. auto height = input.size(2);
  376. auto width = input.size(3);
  377. at::Tensor output = at::zeros(
  378. {num_rois, channels, pooled_height, pooled_width}, input.options());
  379. auto output_size = num_rois * pooled_height * pooled_width * channels;
  380. if (output.numel() == 0) {
  381. return output;
  382. }
  383. AT_DISPATCH_FLOATING_TYPES_AND_HALF(
  384. input.type(), "ROIAlignRotated_forward", [&] {
  385. ROIAlignRotatedForward<scalar_t>(
  386. output_size,
  387. input.contiguous().data_ptr<scalar_t>(),
  388. spatial_scale,
  389. channels,
  390. height,
  391. width,
  392. pooled_height,
  393. pooled_width,
  394. sampling_ratio,
  395. rois.contiguous().data_ptr<scalar_t>(),
  396. output.data_ptr<scalar_t>());
  397. });
  398. return output;
  399. }
  400. at::Tensor ROIAlignRotated_backward_cpu(
  401. const at::Tensor& grad,
  402. const at::Tensor& rois,
  403. const float spatial_scale,
  404. const int pooled_height,
  405. const int pooled_width,
  406. const int batch_size,
  407. const int channels,
  408. const int height,
  409. const int width,
  410. const int sampling_ratio) {
  411. AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
  412. AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
  413. at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
  414. at::CheckedFrom c = "ROIAlignRotated_backward_cpu";
  415. at::checkAllSameType(c, {grad_t, rois_t});
  416. at::Tensor grad_input =
  417. at::zeros({batch_size, channels, height, width}, grad.options());
  418. // handle possibly empty gradients
  419. if (grad.numel() == 0) {
  420. return grad_input;
  421. }
  422. // get stride values to ensure indexing into gradients is correct.
  423. int n_stride = grad.stride(0);
  424. int c_stride = grad.stride(1);
  425. int h_stride = grad.stride(2);
  426. int w_stride = grad.stride(3);
  427. AT_DISPATCH_FLOATING_TYPES_AND_HALF(
  428. grad.type(), "ROIAlignRotated_forward", [&] {
  429. ROIAlignRotatedBackward<scalar_t>(
  430. grad.numel(),
  431. grad.contiguous().data_ptr<scalar_t>(),
  432. spatial_scale,
  433. channels,
  434. height,
  435. width,
  436. pooled_height,
  437. pooled_width,
  438. sampling_ratio,
  439. grad_input.data_ptr<scalar_t>(),
  440. rois.contiguous().data_ptr<scalar_t>(),
  441. n_stride,
  442. c_stride,
  443. h_stride,
  444. w_stride);
  445. });
  446. return grad_input;
  447. }
  448. } // namespace detectron2

No Description