You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

psroipooling.cpp 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "psroipooling.h"
  15. #include <math.h>
  16. #include <algorithm>
  17. namespace ncnn {
  18. DEFINE_LAYER_CREATOR(PSROIPooling)
  19. PSROIPooling::PSROIPooling()
  20. {
  21. one_blob_only = false;
  22. support_inplace = false;
  23. }
  24. int PSROIPooling::load_param(const ParamDict& pd)
  25. {
  26. pooled_width = pd.get(0, 7);
  27. pooled_height = pd.get(1, 7);
  28. spatial_scale = pd.get(2, 0.0625f);
  29. output_dim = pd.get(3, 0);
  30. return 0;
  31. }
  32. int PSROIPooling::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
  33. {
  34. const Mat& bottom_blob = bottom_blobs[0];
  35. int w = bottom_blob.w;
  36. int h = bottom_blob.h;
  37. size_t elemsize = bottom_blob.elemsize;
  38. int channels = bottom_blob.c;
  39. const Mat& roi_blob = bottom_blobs[1];
  40. if (channels != output_dim * pooled_width * pooled_height)
  41. {
  42. // input channel number does not match layer parameters
  43. return -1;
  44. }
  45. Mat& top_blob = top_blobs[0];
  46. top_blob.create(pooled_width, pooled_height, output_dim, elemsize, opt.blob_allocator);
  47. if (top_blob.empty())
  48. return -100;
  49. // For each ROI R = [x y w h]: avg pool over R
  50. const float* roi_ptr = roi_blob;
  51. float roi_x1 = round(roi_ptr[0]) * spatial_scale;
  52. float roi_y1 = round(roi_ptr[1]) * spatial_scale;
  53. float roi_x2 = round(roi_ptr[2] + 1.f) * spatial_scale;
  54. float roi_y2 = round(roi_ptr[3] + 1.f) * spatial_scale;
  55. float roi_w = std::max(roi_x2 - roi_x1, 0.1f);
  56. float roi_h = std::max(roi_y2 - roi_y1, 0.1f);
  57. float bin_size_w = roi_w / (float)pooled_width;
  58. float bin_size_h = roi_h / (float)pooled_height;
  59. #pragma omp parallel for num_threads(opt.num_threads)
  60. for (int q=0; q<output_dim; q++)
  61. {
  62. float* outptr = top_blob.channel(q);
  63. for (int ph = 0; ph < pooled_height; ph++)
  64. {
  65. for (int pw = 0; pw < pooled_width; pw++)
  66. {
  67. const float* ptr = bottom_blob.channel((q * pooled_height + ph) * pooled_width + pw);
  68. int hstart = floor(roi_y1 + (float)(ph) * bin_size_h);
  69. int wstart = floor(roi_x1 + (float)(pw) * bin_size_w);
  70. int hend = ceil(roi_y1 + (float)(ph + 1) * bin_size_h);
  71. int wend = ceil(roi_x1 + (float)(pw + 1) * bin_size_w);
  72. hstart = std::min(std::max(hstart, 0), h);
  73. wstart = std::min(std::max(wstart, 0), w);
  74. hend = std::min(std::max(hend, 0), h);
  75. wend = std::min(std::max(wend, 0), w);
  76. bool is_empty = (hend <= hstart) || (wend <= wstart);
  77. int area = (hend - hstart) * (wend - wstart);
  78. float sum = 0.f;
  79. for (int y = hstart; y < hend; y++)
  80. {
  81. for (int x = wstart; x < wend; x++)
  82. {
  83. int index = y * w + x;
  84. sum += ptr[index];
  85. }
  86. }
  87. outptr[pw] = is_empty ? 0.f : (sum / (float)area);
  88. }
  89. outptr += pooled_width;
  90. }
  91. }
  92. return 0;
  93. }
  94. } // namespace ncnn