You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

interp.cpp 3.0 kB

8 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "interp.h"
  15. #include <algorithm>
  16. namespace ncnn {
  17. DEFINE_LAYER_CREATOR(Interp);
  18. Interp::Interp()
  19. {
  20. one_blob_only = true;
  21. }
  22. int Interp::load_param(const ParamDict& pd)
  23. {
  24. resize_type = pd.get(0, 0);
  25. height_scale = pd.get(1, 1.f);
  26. width_scale = pd.get(2, 1.f);
  27. output_height = pd.get(3, 0);
  28. output_width = pd.get(4, 0);
  29. return 0;
  30. }
  31. int Interp::forward(const Mat &bottom_blob, Mat &top_blob, const Option& opt) const
  32. {
  33. int h = bottom_blob.h;
  34. int w = bottom_blob.w;
  35. int c = bottom_blob.c;
  36. size_t elemsize = bottom_blob.elemsize;
  37. int oh = output_height;
  38. int ow = output_width;
  39. if (bottom_blob.dims == 1)
  40. {
  41. h = 1;
  42. w = 1;
  43. c = bottom_blob.w;
  44. }
  45. if (oh == 0 || ow == 0)
  46. {
  47. oh = h * height_scale;
  48. ow = w * width_scale;
  49. }
  50. if (oh == h && ow == w)
  51. {
  52. top_blob = bottom_blob;
  53. return 0;
  54. }
  55. top_blob.create(ow, oh, c, elemsize, opt.blob_allocator);
  56. if (top_blob.empty())
  57. return -100;
  58. if (bottom_blob.dims == 1)
  59. {
  60. #pragma omp parallel for num_threads(opt.num_threads)
  61. for (int q = 0; q < c; ++q)
  62. {
  63. Mat top_blob_c = top_blob.channel(q);
  64. const float *ptr = ((const float*)bottom_blob.data + q);
  65. top_blob_c.fill(*ptr);
  66. }
  67. return 0;
  68. }
  69. if (resize_type == 1)//nearest
  70. {
  71. #pragma omp parallel for num_threads(opt.num_threads)
  72. for (int q = 0; q < c; ++q)
  73. {
  74. const float *ptr = bottom_blob.channel(q);
  75. float *output_ptr = top_blob.channel(q);
  76. for (int y = 0; y < oh; ++y)
  77. {
  78. const int in_y = std::min((int) (y / height_scale), (h - 1));
  79. for (int x = 0; x < ow; ++x)
  80. {
  81. const int in_x = std::min((int) (x / width_scale), (w - 1));
  82. output_ptr[ow * y + x] = ptr[in_y * w + in_x];
  83. }
  84. }
  85. }
  86. return 0;
  87. }
  88. else if (resize_type == 2)// bilinear
  89. {
  90. resize_bilinear(bottom_blob, top_blob, ow, oh);
  91. return 0;
  92. }
  93. else
  94. {
  95. fprintf(stderr, "unsupported resize type %d %d %d\n", resize_type, oh, ow);
  96. return -233;
  97. }
  98. }
  99. } // namespace ncnn