You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pack_ext.c 7.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string.h>
  17. #include "nnacl/fp32_grad/pack_ext.h"
  /*
   * Range test done with a single unsigned comparison: both arguments are
   * converted to unsigned, so a negative `a` wraps to a very large value and
   * fails the test. For a non-negative `b` the result is 1 exactly when
   * 0 <= a < b, otherwise 0.
   */
  static int is_a_ge_zero_and_a_lt_b(int a, int b) {
    const unsigned int value = (unsigned int)a;
    const unsigned int limit = (unsigned int)b;
    return value < limit;
  }
  19. void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) {
  20. const int pad_left = conv_param->pad_l_;
  21. const int pad_up = conv_param->pad_u_;
  22. const int stride_h = conv_param->stride_h_;
  23. const int stride_w = conv_param->stride_w_;
  24. const int dilation_h = conv_param->dilation_h_;
  25. const int dilation_w = conv_param->dilation_w_;
  26. const int kernel_h = conv_param->kernel_h_;
  27. const int kernel_w = conv_param->kernel_w_;
  28. const int in_height = conv_param->input_h_;
  29. const int in_width = conv_param->input_w_;
  30. const int output_h = conv_param->output_h_;
  31. const int output_w = conv_param->output_w_;
  32. const int channels = conv_param->input_channel_ / conv_param->group_;
  33. const int tot_channels = conv_param->input_channel_;
  34. int kernel_row, kernel_col, output_rows, output_col;
  35. int row_stride_offset = 0;
  36. for (output_rows = output_h; output_rows; output_rows--) {
  37. int col_stride_offset = 0;
  38. for (output_col = output_w; output_col; output_col--) {
  39. for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
  40. int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset;
  41. for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
  42. int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset;
  43. if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
  44. const int offset = (input_row * in_width + input_col) * tot_channels;
  45. memcpy(data_col, in_data + offset, sizeof(float) * channels);
  46. data_col += channels;
  47. } else {
  48. memset(data_col, 0, sizeof(float) * channels);
  49. data_col += channels;
  50. }
  51. }
  52. }
  53. col_stride_offset += stride_w;
  54. }
  55. row_stride_offset += stride_h;
  56. }
  57. }
  58. // output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
  59. void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose) {
  60. const int pad_left = conv_param->pad_l_;
  61. const int pad_up = conv_param->pad_u_;
  62. const int stride_h = conv_param->stride_h_;
  63. const int stride_w = conv_param->stride_w_;
  64. const int dilation_h = conv_param->dilation_h_;
  65. const int dilation_w = conv_param->dilation_w_;
  66. const int kernel_h = conv_param->kernel_h_;
  67. const int kernel_w = conv_param->kernel_w_;
  68. const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_;
  69. const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_;
  70. const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_;
  71. const int output_w = (transpose) ? conv_param->input_w_ : conv_param->output_w_;
  72. const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_;
  73. const int channels = tot_channels / conv_param->group_;
  74. int channel, kernel_row, kernel_col, output_rows, output_col;
  75. if (transpose) {
  76. for (channel = 0; channel < channels; channel++) {
  77. for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
  78. for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
  79. int input_row = -pad_up + kernel_row * dilation_h;
  80. for (output_rows = output_h; output_rows; output_rows--) {
  81. if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
  82. for (output_col = output_w; output_col; output_col--) {
  83. *(data_row++) = 0;
  84. }
  85. } else {
  86. int input_col = -pad_left + kernel_col * dilation_w;
  87. for (output_col = output_w; output_col; output_col--) {
  88. if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
  89. const int offset = (input_row * in_width + input_col) * tot_channels + channel;
  90. *(data_row++) = in_data[offset];
  91. } else {
  92. *(data_row++) = 0;
  93. }
  94. input_col += stride_w;
  95. }
  96. }
  97. input_row += stride_h;
  98. }
  99. }
  100. }
  101. }
  102. } else {
  103. for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
  104. for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
  105. for (channel = 0; channel < channels; channel++) {
  106. int input_row = -pad_up + kernel_row * dilation_h;
  107. for (output_rows = output_h; output_rows; output_rows--) {
  108. if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
  109. for (output_col = output_w; output_col; output_col--) {
  110. *(data_row++) = 0;
  111. }
  112. } else {
  113. int input_col = -pad_left + kernel_col * dilation_w;
  114. for (output_col = output_w; output_col; output_col--) {
  115. if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
  116. const int offset = (input_row * in_width + input_col) * tot_channels + channel;
  117. *(data_row++) = in_data[offset];
  118. } else {
  119. *(data_row++) = 0;
  120. }
  121. input_col += stride_w;
  122. }
  123. }
  124. input_row += stride_h;
  125. }
  126. }
  127. }
  128. }
  129. }
  130. }
  131. void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) {
  132. const int pad_left = conv_param->pad_l_;
  133. const int pad_up = conv_param->pad_u_;
  134. const int stride_h = conv_param->stride_h_;
  135. const int stride_w = conv_param->stride_w_;
  136. const int dilation_h = conv_param->dilation_h_;
  137. const int dilation_w = conv_param->dilation_w_;
  138. const int kernel_h = conv_param->kernel_h_;
  139. const int kernel_w = conv_param->kernel_w_;
  140. const int in_height = conv_param->input_h_;
  141. const int in_width = conv_param->input_w_;
  142. const int output_h = conv_param->output_h_;
  143. const int output_w = conv_param->output_w_;
  144. const int channels = conv_param->input_channel_ / conv_param->group_;
  145. const int tot_channels = conv_param->input_channel_;
  146. int kernel_row, kernel_col, output_rows, output_col;
  147. int row_stride_offset = 0;
  148. for (output_rows = output_h; output_rows; output_rows--) {
  149. int col_stride_offset = 0;
  150. for (output_col = output_w; output_col; output_col--) {
  151. for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
  152. int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset;
  153. for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
  154. int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset;
  155. if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
  156. int offset = (input_row * in_width + input_col) * tot_channels;
  157. float *data_im_ptr = &data_im[offset];
  158. for (int i = 0; i < channels; i++) {
  159. data_im_ptr[i] += data_col[i];
  160. }
  161. }
  162. data_col += channels;
  163. }
  164. }
  165. col_stride_offset += stride_w;
  166. }
  167. row_stride_offset += stride_h;
  168. }
  169. }