You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

batch_to_space.c 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "nnacl/batch_to_space.h"
  17. #include "nnacl/arithmetic_common.h"
  /**
   * Batch-to-space rearrangement for NHWC tensors, without cropping.
   *
   * Output position (n, h * block_h + bh, w * block_w + bw, :) is taken from input
   * batch (bh * block_w + bw) * out_n + n at spatial position (h, w). The output is
   * written sequentially, one channel vector (in_c elements) per copy.
   *
   * @param input     source tensor in NHWC order; must contain at least
   *                  block_h * block_w * out_n batches of in_h x in_w x in_c elements
   * @param output    destination; must hold out_n * (in_h * block_h) * (in_w * block_w) * in_c elements
   * @param in_shape  input dimensions {N, H, W, C}; only indices 1..3 are read
   * @param out_n     output batch count
   * @param block     block sizes {block_h, block_w}
   * @param data_size size in bytes of one tensor element
   */
  void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block,
                                 int data_size) {
    const int block_h = block[0];
    const int block_w = block[1];
    const int in_h = in_shape[1];
    const int in_w = in_shape[2];
    const int in_c = in_shape[3];
    /* Read through a const pointer: the source is never modified. */
    const int8_t *src = (const int8_t *)input;
    int8_t *dst = (int8_t *)output;
    /* Widen before multiplying so element/byte counts cannot overflow int on large tensors. */
    const size_t stride_h = (size_t)block_w * (size_t)out_n;   /* input batches per bh step */
    const size_t copy_size = (size_t)in_c * (size_t)data_size; /* bytes per channel vector */
    const size_t in_stride_h = (size_t)in_w * (size_t)in_c;    /* elements per input row */
    const size_t in_stride_n = in_stride_h * (size_t)in_h;     /* elements per input batch */
    size_t output_offset = 0;                                  /* running byte offset into output */
    for (int n = 0; n < out_n; ++n) {
      for (int h = 0; h < in_h; ++h) {
        const size_t h_offset = (size_t)h * in_stride_h;
        for (int bh = 0; bh < block_h; ++bh) {
          for (int w = 0; w < in_w; ++w) {
            const size_t w_offset = (size_t)w * (size_t)in_c;
            for (int bw = 0; bw < block_w; ++bw) {
              const size_t in_batch = (size_t)bh * stride_h + (size_t)bw * (size_t)out_n + (size_t)n;
              const size_t in_offset = in_stride_n * in_batch + h_offset + w_offset;
              memcpy(dst + output_offset, src + in_offset * (size_t)data_size, copy_size);
              output_offset += copy_size;
            }
          }
        }
      }
    }
  }
  /**
   * Batch-to-space rearrangement for NHWC tensors, with cropping.
   *
   * Conceptually builds the un-cropped output of height in_h * block_h and width
   * in_w * block_w. Output position (n, h * block_h + bh, w * block_w + bw, :) comes from
   * input batch (bh * block_w + bw) * out_n + n at (h, w). The function then keeps only:
   *   rows    crops[0] .. in_h * block_h - crops[1] - 1
   *   columns crops[2] .. in_w * block_w - crops[3] - 1
   * Kept pixels are written sequentially, one channel vector (in_c elements) per copy.
   *
   * @param input     source tensor in NHWC order; must contain at least
   *                  block_h * block_w * out_n batches of in_h x in_w x in_c elements
   * @param output    destination; must hold out_n * (in_h * block_h - crops[0] - crops[1]) *
   *                  (in_w * block_w - crops[2] - crops[3]) * in_c elements
   * @param in_shape  input dimensions {N, H, W, C}; only indices 1..3 are read
   * @param out_n     output batch count
   * @param block     block sizes {block_h, block_w}
   * @param crops     {top, bottom, left, right} amounts removed from the un-cropped output;
   *                  assumed non-negative with top + bottom <= in_h * block_h and
   *                  left + right <= in_w * block_w (TODO confirm callers validate this)
   * @param data_size size in bytes of one tensor element
   */
  void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block,
                           const int *crops, int data_size) {
    const int block_h = block[0];
    const int block_w = block[1];
    const int in_h = in_shape[1];
    const int in_w = in_shape[2];
    const int in_c = in_shape[3];

    /* Inclusive bounds of the kept region in un-cropped output coordinates. These and the
     * indices compared against them must stay signed. When a dimension is cropped away
     * entirely, *_valid_end is -1, and an unsigned comparison would turn it into SIZE_MAX.
     * The upper-bound check would then never fire, and data would be written past a
     * zero-sized output. */
    const int h_valid_begin = crops[0];
    const int h_valid_end = in_h * block_h - crops[1] - 1;
    const int w_valid_begin = crops[2];
    const int w_valid_end = in_w * block_w - crops[3] - 1;

    /* Input rows/columns that can contribute to the kept region, clamped to the input size.
     * Partial blocks at either edge are filtered by the per-pixel checks below. */
    const int h_start = crops[0] / block_h;
    const int h_limit = (in_h * block_h - crops[1]) / block_h + 1;
    const int h_end = h_limit < in_h ? h_limit : in_h;
    const int w_start = crops[2] / block_w;
    const int w_limit = (in_w * block_w - crops[3]) / block_w + 1;
    const int w_end = w_limit < in_w ? w_limit : in_w;

    /* Read through a const pointer: the source is never modified. */
    const int8_t *src = (const int8_t *)input;
    int8_t *dst = (int8_t *)output;
    /* Widen before multiplying so element/byte counts cannot overflow int on large tensors. */
    const size_t stride_h = (size_t)block_w * (size_t)out_n;   /* input batches per bh step */
    const size_t copy_size = (size_t)in_c * (size_t)data_size; /* bytes per channel vector */
    const size_t in_stride_h = (size_t)in_w * (size_t)in_c;    /* elements per input row */
    const size_t in_stride_n = in_stride_h * (size_t)in_h;     /* elements per input batch */
    size_t output_offset = 0;                                  /* running byte offset into output */

    for (int n = 0; n < out_n; ++n) {
      for (int h = h_start; h < h_end; ++h) {
        const size_t h_offset = (size_t)h * in_stride_h;
        for (int bh = 0; bh < block_h; ++bh) {
          const int h_index = h * block_h + bh;
          if (h_index < h_valid_begin || h_index > h_valid_end) {
            continue;
          }
          for (int w = w_start; w < w_end; ++w) {
            const size_t w_offset = (size_t)w * (size_t)in_c;
            for (int bw = 0; bw < block_w; ++bw) {
              const int w_index = w * block_w + bw;
              if (w_index < w_valid_begin || w_index > w_valid_end) {
                continue;
              }
              const size_t in_batch = (size_t)bh * stride_h + (size_t)bw * (size_t)out_n + (size_t)n;
              const size_t in_offset = in_stride_n * in_batch + h_offset + w_offset;
              memcpy(dst + output_offset, src + in_offset * (size_t)data_size, copy_size);
              output_offset += copy_size;
            }
          }
        }
      }
    }
  }