You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

pack.c 41 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <stdlib.h>
  17. #include <string.h>
  18. #include "nnacl/int8/conv_int8.h"
  19. #include "nnacl/pack.h"
  20. void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
  21. return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
  22. }
  23. void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
  24. for (int i = 0; i < height; ++i) {
  25. for (int j = 0; j < width; ++j) {
  26. memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float));
  27. }
  28. }
  29. }
// Pack int8 convolution weights (original layout: OHWI) into the blocked
// layout used by the int8 conv kernel: oc4 blocks of four output channels,
// each holding plane_c4 * ic4 4x4 tiles. While packing, accumulate the
// per-output-channel sum of (weight - filter_zero_point) into weight_sum,
// which the kernel later uses for input zero-point compensation.
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) {
  // original weight format : ohwi
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int in_channel = conv_param->input_channel_;
  int out_channel = conv_param->output_channel_;
  int oc4 = UP_DIV(out_channel, C4NUM);
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
  int plane_c4 = UP_DIV(kernel_plane, C4NUM);
  int pack_weight_size = oc4 * C4NUM * ic4 * C4NUM * plane_c4 * C4NUM;
  int block_size = pack_weight_size / oc4;  // packed elements per oc4 block
  QuantArg *filter_args = conv_param->conv_quant_arg_.filter_quant_args_;
  for (int m = 0; m < kernel_plane; m++) {
    int kernel_plane_stride = m * in_channel;
    int plane_block = m / C4NUM;
    int plane_res = m % C4NUM;
    int packed_kernel_plane_stride = plane_block * C4NUM * C4NUM * ic4 * C4NUM + plane_res * C4NUM;
    for (int i = 0; i < ic4; i++) {
      int channel_block_stride = kernel_plane_stride + i * C4NUM;
      int packed_channel_block_size = packed_kernel_plane_stride + i * C4NUM * C4NUM * C4NUM;
      int ic_remainder = in_channel - i * C4NUM;
      int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM;  // last ic block may be partial
      for (int h = 0; h < real_ic_num; h++) {
        int block_stride = channel_block_stride + h;
        int packed_block_stride = packed_channel_block_size + h;
        for (int j = 0; j < oc4; j++) {
          int kernel_block_stride = block_stride + j * C4NUM * kernel_plane * in_channel;
          int packed_kernel_block_size = packed_block_stride + j * block_size;
          int oc_remainder = out_channel - j * C4NUM;
          int real_oc_num = oc_remainder < C4NUM ? oc_remainder : C4NUM;  // last oc block may be partial
          for (int k = 0; k < real_oc_num; k++) {
            int8_t *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
            int8_t *packed_data_ptr = packed_weight + packed_kernel_block_size + k * C4NUM * C4NUM;
            *packed_data_ptr = origin_data_ptr[0];
            // Per-channel quantization: one zero point per output channel;
            // otherwise the single layer-wide zero point at index 0 applies.
            int32_t f_zp;
            if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
              f_zp = filter_args[j * C4NUM + k].zp_;
            } else {
              f_zp = filter_args[0].zp_;
            }
            weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - f_zp);
          }
        }  // kernel block loop
      }    // inchannel block loop
    }      // channel block loop
  }        // kernel plane loop
}
  78. void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
  79. /* support nhwc */
  80. char *src = (char *)src_ptr;
  81. char *dst = (char *)dst_ptr;
  82. for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
  83. int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
  84. if (src_h < 0 || src_h >= conv_param->input_h_) {
  85. continue;
  86. }
  87. const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size;
  88. char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size;
  89. for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
  90. int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
  91. if (src_w < 0 || src_w >= conv_param->input_w_) {
  92. continue;
  93. }
  94. memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size,
  95. src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size);
  96. }
  97. }
  98. return;
  99. }
  100. void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param) {
  101. int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
  102. for (int ic = 0; ic < conv_param->input_channel_; ic++) {
  103. for (int oc = 0; oc < conv_param->output_channel_; oc++) {
  104. int oc4mod = oc % 4;
  105. int oc4div = oc / 4;
  106. int dst_index = oc4div * c4 * C4NUM + ic * C4NUM + oc4mod;
  107. int src_index = oc * conv_param->input_channel_ + ic;
  108. packed_weight[dst_index] = weight_data[src_index];
  109. }
  110. }
  111. return;
  112. }
  113. void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
  114. /* normal matmul : 4x16 * 16x4 -> 4x4 */
  115. #ifdef ENABLE_ARM
  116. PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp);
  117. #else
  118. for (int r = 0; r < row4; r++) {
  119. int32_t tmp_value = 0;
  120. for (int c = 0; c < col16; c++) {
  121. int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
  122. int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
  123. tmp_value += src[src_index];
  124. }
  125. dst[r] = tmp_value * filter_zp;
  126. }
  127. #endif
  128. return;
  129. }
// Per-output-channel input-sum for the 16x4-tiled int8 matmul:
// for each packed row ri and output channel ci,
//   input_sum[oc4-blocked index] = filter_zp[ci] * (sum of row ri's depth values).
// Rows are padded to a multiple of 4 (hw4) and depth to 16 (ic16).
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
                                size_t plane_size, size_t input_channel, size_t output_channel) {
  size_t hw4 = UP_ROUND(plane_size, C4NUM);
  size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM64
  size_t oc_div4 = output_channel / C4NUM * C4NUM;
  size_t oc_res4 = output_channel - oc_div4;
  // NOTE(review): "inputsun" is presumably a typo for "inputsum"; the trailing
  // `* 4` looks like sizeof(int32_t) — confirm against PreSum4x16Int8Peroc.
  size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
  PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride);
#else
  for (int ri = 0; ri < plane_size; ri++) {
    int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
    for (int ci = 0; ci < output_channel; ci++) {
      int32_t tmp_sum_value = 0;
      int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
      int32_t filter_zp = filter_zp_ptr[ci];
      // Walk row ri of the 16x4-tiled input and accumulate all depth values.
      for (int di = 0; di < input_channel; di++) {
        size_t di16div = di / C16NUM, di16mod = di % C16NUM;
        int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
        tmp_sum_value += input_value[src_index];
      }
      // Destination is oc4-blocked: block ci4div, row ri, lane ci4mod.
      int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
      input_sum[dst_index] = tmp_sum_value * filter_zp;
    }
  }
#endif
  return;
}
// ARM32 flavor of the per-channel input-sum: same accumulation as
// PackInputSum16x4PerChannel, but the destination is oc2-blocked (pairs of
// output channels) to match the ARM32 matmul kernel's output layout.
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
                                     size_t plane_size, size_t input_channel, size_t output_channel) {
  size_t hw4 = UP_ROUND(plane_size, C4NUM);
  size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM32
  size_t oc_div2 = output_channel / C2NUM * C2NUM;
  size_t oc_res2 = output_channel - oc_div2;
  // NOTE(review): "inputsun" is presumably a typo for "inputsum"; the trailing
  // `* 4` looks like sizeof(int32_t) — confirm against PreSum4x16Int8Peroc.
  size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
  PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride);
#else
  for (int ri = 0; ri < plane_size; ri++) {
    int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
    for (int ci = 0; ci < output_channel; ci++) {
      int32_t tmp_sum_value = 0;
      int ci2div = ci / C2NUM, ci2mod = ci % C2NUM;
      int32_t filter_zp = filter_zp_ptr[ci];
      // Accumulate all depth values of row ri in the 16x4-tiled input.
      for (int di = 0; di < input_channel; di++) {
        size_t di16div = di / C16NUM, di16mod = di % C16NUM;
        int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
        tmp_sum_value += input_value[src_index];
      }
      // Destination is oc2-blocked: block ci2div, row ri, lane ci2mod.
      int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod;
      input_sum[dst_index] = tmp_sum_value * filter_zp;
    }
  }
#endif
  return;
}
  186. void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) {
  187. size_t hw4 = UP_ROUND(conv_param->input_h_ * conv_param->input_w_, C4NUM);
  188. size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
  189. if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
  190. PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
  191. } else {
  192. #ifdef ENABLE_ARM32
  193. PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, conv_param->input_h_ * conv_param->input_w_,
  194. conv_param->input_channel_, conv_param->output_channel_);
  195. #else
  196. PackInputSum16x4PerChannel(input, input_sum, filter_zp, conv_param->input_h_ * conv_param->input_w_,
  197. conv_param->input_channel_, conv_param->output_channel_);
  198. #endif
  199. }
  200. return;
  201. }
// im2col for fp32 convolution: for `real_cal_num` output pixels starting at
// flat output index `block_index`, copy each pixel's receptive field
// (kernel_h x kernel_w x in_channel, NHWC input) into packed_input.
// kh_s/kh_e/kw_s/kw_e clip kernel taps that fall outside the input
// (i.e. the padding region); clipped destinations are left untouched.
void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int stride_h = conv_param->stride_h_;
  int stride_w = conv_param->stride_w_;
  int pad_h = conv_param->pad_u_;
  int pad_w = conv_param->pad_l_;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int in_channel = conv_param->input_channel_;
  int in_h = conv_param->input_h_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    // Top-left input coordinate of this output pixel's receptive field.
    int input_h = block_start / out_w * stride_h - pad_h;
    int input_w = block_start % out_w * stride_w - pad_w;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    // Valid kernel tap range after clipping against the input borders.
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      // Dense kernel: one memcpy covers a whole kernel row.
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      // Dilated kernel: one in_channel-sized copy per tap.
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}
// im2col + repack for the int8 conv kernels, in two phases.
// Phase 1: gather the receptive fields of `real_cal_num` output pixels
// (flat output index starts at `block_index`) into matmul_input, clipping
// kernel taps that fall into the padding region.
// Phase 2: pack matmul_input for the matmul kernel and compute the matching
// input-sum:
//   - is_optimize: Conv1x1PreOptPeroc / Conv1x1PreOptPert (8x8 path);
//   - otherwise:   RowMajor2Row16x4MajorInt8 followed by
//                  PackInputSum16x4PerChannel[Arm32] or ...PerLayer.
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num,
                           int block_index, int32_t *filter_zp, int32_t *input_sum, ConvParameter *conv_param,
                           bool per_channel, bool is_optimize) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int stride_h = conv_param->stride_h_;
  int stride_w = conv_param->stride_w_;
  int pad_h = conv_param->pad_u_;
  int pad_w = conv_param->pad_l_;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int in_channel = conv_param->input_channel_;
  int in_h = conv_param->input_h_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;
  int kernel_plane = kernel_h * kernel_w;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    // Top-left input coordinate of this output pixel's receptive field.
    int input_h = block_start / out_w * stride_h - pad_h;
    int input_w = block_start % out_w * stride_w - pad_w;
    int input_stride = input_h * in_w * in_channel + input_w * in_channel;
    // Valid kernel tap range after clipping against the input borders.
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      // Dense kernel: one memcpy covers a whole kernel row (int8, 1 byte/elem).
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel);
      }  // kernel_h loop
    } else {
      // Dilated kernel: one in_channel-sized copy per tap.
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel);
        }
      }  // kernel_h loop
    }
  }  // tile num loop
  // Phase 2: repack for the matmul kernel and compute the input-sum.
  int deep = kernel_plane * in_channel;
  if (is_optimize) {
    if (per_channel) {
      Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num,
                         filter_zp, C8NUM * C8NUM);
    } else {
      Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param);
    }
  } else {
    RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep);
    if (per_channel) {
#ifdef ENABLE_ARM32
      PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep,
                                      conv_param->output_channel_);
#else
      PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_);
#endif
    } else {
      size_t hw4 = UP_ROUND(real_cal_num, C4NUM);
      size_t ic16 = UP_ROUND(deep, C16NUM);
      PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4,
                               ic16);
    }
  }
}
  316. void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param) {
  317. int in_batch = conv_param->input_batch_;
  318. int in_channel = conv_param->input_channel_;
  319. int in_h = conv_param->input_h_;
  320. int in_w = conv_param->input_w_;
  321. int ic8 = UP_DIV(in_channel, C8NUM);
  322. int ic8_minus = ic8 - 1;
  323. for (int b = 0; b < in_batch; b++) {
  324. int src_batch_offset = b * in_channel * in_h * in_w;
  325. int dst_batch_offset = b * ic8 * C8NUM * in_h * in_w;
  326. for (int k = 0; k < in_w * in_h; k++) {
  327. int src_plane_offset = src_batch_offset + k * in_channel;
  328. int dst_plane_offset = dst_batch_offset + k * C8NUM;
  329. for (int i = 0; i < ic8_minus; ++i) {
  330. int src_c_offset = src_plane_offset + i * C8NUM;
  331. int dst_c_offset = dst_plane_offset + i * C8NUM * in_h * in_w;
  332. #ifdef ENABLE_ARM
  333. vst1q_s16(packed_input + dst_c_offset, vmovl_s8(vld1_s8(input_data + src_c_offset)));
  334. #else
  335. for (int j = 0; j < C8NUM; ++j) {
  336. (packed_input + dst_c_offset)[j] = (int16_t)(input_data + src_c_offset)[j];
  337. }
  338. #endif
  339. } // ic8_minus loop
  340. int tmp_ic = ic8_minus * C8NUM;
  341. int res_c = in_channel - tmp_ic;
  342. int tmp_ic_offset = tmp_ic * in_h * in_w;
  343. for (int l = 0; l < res_c; ++l) {
  344. int src_c_offset = src_plane_offset + tmp_ic + l;
  345. int dst_c_offset = dst_plane_offset + tmp_ic_offset + l;
  346. (packed_input + dst_c_offset)[l] = (int16_t)(input_data + src_c_offset)[l];
  347. } // res ic loop
  348. for (int l = res_c; l < C8NUM; ++l) {
  349. int dst_c_offset = dst_plane_offset + tmp_ic_offset + l;
  350. (packed_input + dst_c_offset)[l] = 0;
  351. } // res ic loop
  352. } // kh * kw loop
  353. }
  354. }
  355. void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, ConvParameter *conv_param) {
  356. // origin weight format : ohwi
  357. int input_channel = conv_param->input_channel_;
  358. int ic8 = UP_DIV(input_channel, C8NUM);
  359. int output_channel = conv_param->output_channel_;
  360. QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_;
  361. int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
  362. for (int k = 0; k < kernel_plane; k++) {
  363. int src_kernel_offset = k * input_channel;
  364. int dst_kernel_offset = k * C8NUM;
  365. for (int o = 0; o < output_channel; o++) {
  366. int32_t zp;
  367. if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
  368. zp = filter_zp[0].zp_;
  369. } else {
  370. zp = filter_zp[o].zp_;
  371. }
  372. int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel;
  373. int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM;
  374. int i = 0;
  375. for (; i < (ic8 - 1); i += C8NUM) {
  376. int src_ic_offset = src_oc_offset + i;
  377. int dst_ic_offset = dst_oc_offset + i * kernel_plane;
  378. #ifdef ENABLE_ARM64
  379. int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset);
  380. int16x8_t src_s16 = vmovl_s8(src_s8);
  381. int16x4_t src1_s16 = vget_low_s16(src_s16);
  382. int16x4_t src2_s16 = vget_high_s16(src_s16);
  383. int32x4_t src1_s32 = vmovl_s16(src1_s16);
  384. int32x4_t src2_s32 = vmovl_s16(src2_s16);
  385. int32x4_t zp_s32 = vdupq_n_s32(zp);
  386. int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32);
  387. int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32);
  388. int16x4_t dst1_s16 = vqmovn_s32(dst1_s32);
  389. int16x4_t dst2_s16 = vqmovn_s32(dst2_s32);
  390. vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16);
  391. vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16);
  392. #else
  393. for (int ci = 0; ci < C8NUM; ++ci) {
  394. (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset + ci)[0] - zp);
  395. }
  396. #endif
  397. }
  398. dst_oc_offset += (ic8 - 1) * kernel_plane * C8NUM;
  399. for (; i < input_channel; i++) {
  400. int c8_block_rem = i % C8NUM;
  401. int src_ic_offset = src_oc_offset + i;
  402. int dst_ic_offset = dst_oc_offset + c8_block_rem;
  403. (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp);
  404. }
  405. }
  406. }
  407. }
// Repack NHWC fp32 data into NC4HW4: channels grouped in blocks of four,
// each block stored plane-major with the four channel lanes contiguous per
// pixel. Full blocks use a 4-float copy (NEON when available); the last,
// possibly partial, block is copied scalar-wise. Padding lanes are NOT
// zeroed here.
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int c4_minus = c4 - 1;  // number of guaranteed-full channel blocks
  for (int b = 0; b < batch; b++) {
    int src_oc_offset = b * plane * channel;
    int dst_oc_offset = b * plane * c4 * C4NUM;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_oc_offset + k * channel;
      int dst_kernel_offset = dst_oc_offset + k * C4NUM;
      for (int j = 0; j < c4_minus; ++j) {
        int src_ic_offset = src_kernel_offset + j * C4NUM;
        int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
#ifdef ENABLE_ARM
        vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
#else
        for (int i = 0; i < C4NUM; ++i) {
          ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
        }
#endif
      }
      // Trailing (possibly partial) channel block, one value at a time.
      int tmp_c = c4_minus * C4NUM;
      int tmp_c_offset = tmp_c * plane;
      int res_c = channel - tmp_c;
      for (int l = 0; l < res_c; ++l) {
        int src_ic_offset = src_kernel_offset + tmp_c + l;
        int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
        ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
      }
    }
  }
}
  439. void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  440. int c4 = UP_DIV(channel, C4NUM);
  441. for (int b = 0; b < batch; b++) {
  442. int src_offset = b * plane * channel;
  443. int dst_offset = b * plane * c4 * C4NUM;
  444. for (int c = 0; c < channel; c++) {
  445. int c4_block_num = c / C4NUM;
  446. int c4_block_rem = c % C4NUM;
  447. int src_c_offset = src_offset + c * plane;
  448. int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM;
  449. for (int k = 0; k < plane; k++) {
  450. int src_kernel_offset = src_c_offset + k;
  451. int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem;
  452. ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
  453. }
  454. }
  455. }
  456. }
  457. void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  458. int c4 = UP_DIV(channel, C4NUM);
  459. int c4_channel = c4 * C4NUM;
  460. int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
  461. int ic_remainder_ = channel % C4NUM;
  462. if (ic_remainder_ != 0) {
  463. int nhwc4_batch_offset = 0;
  464. for (int b = 0; b < batch; b++) {
  465. int batch_offset = b * channel * plane;
  466. for (int i = 0; i < plane; i++) {
  467. float *dst_per_plane = (float *)dst + nhwc4_batch_offset + i * c4_channel;
  468. memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
  469. for (int j = channel; j < c4_channel; ++j) {
  470. dst_per_plane[j] = 0;
  471. }
  472. }
  473. nhwc4_batch_offset += nhwc4_batch_unit_offset;
  474. }
  475. } else {
  476. size_t ori_input_size = batch * plane * channel * sizeof(float);
  477. memcpy((float *)dst, (float *)src, ori_input_size);
  478. }
  479. }
  480. void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  481. int c4 = UP_DIV(channel, C4NUM);
  482. int ic_remainder_ = channel % C4NUM;
  483. if (ic_remainder_ != 0) {
  484. int nhwc_batch_unit_offset = channel * plane;
  485. for (int b = 0; b < batch; b++) {
  486. int batch_offset = b * c4 * C4NUM * plane;
  487. for (int i = 0; i < plane; i++) {
  488. memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, (float *)src + batch_offset + i * c4 * C4NUM,
  489. channel * sizeof(float));
  490. }
  491. }
  492. } else {
  493. size_t ori_input_size = batch * plane * channel * sizeof(float);
  494. memcpy((float *)dst, (float *)src, ori_input_size);
  495. }
  496. }
  497. void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  498. int c4 = UP_DIV(channel, C4NUM);
  499. for (int b = 0; b < batch; b++) {
  500. int src_offset = b * plane * c4 * C4NUM;
  501. int dst_offset = b * plane * channel;
  502. for (int c = 0; c < channel; c++) {
  503. int c4_block_num = c / C4NUM;
  504. int c4_block_res = c % C4NUM;
  505. int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
  506. int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
  507. for (int k = 0; k < plane; k++) {
  508. int src_kernel_offset = src_c_offset + k * C4NUM;
  509. int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
  510. ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
  511. }
  512. }
  513. }
  514. }
// Unpack NC4HW4 fp32 data back to dense NHWC. Full channel blocks use a
// 4-float copy (NEON when available); the trailing, possibly partial, block
// is copied one value at a time.
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C4NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      // All-but-last channel blocks: guaranteed full, copy 4 floats each.
      for (int c = 0; c < c4 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
        int dst_c_offset = dst_kernel_offset + c * C4NUM;
#ifdef ENABLE_NEON
        vst1q_f32((float *)dst + dst_c_offset, vld1q_f32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
#endif
      }
      // res part
      int res_c = channel - (c4 - 1) * C4NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}
  545. void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  546. for (int n = 0; n < batch; n++) {
  547. for (int hw = 0; hw < plane; hw++) {
  548. for (int c = 0; c < channel; c++) {
  549. int c8div = c / C8NUM;
  550. int c8mod = c % C8NUM;
  551. int src_index = n * plane * channel + hw * channel + c;
  552. int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
  553. ((float *)dst)[dst_index] = ((float *)src)[src_index];
  554. }
  555. }
  556. }
  557. return;
  558. }
  559. void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) {
  560. int c4 = UP_DIV(channel, C4NUM);
  561. int c4_channel = c4 * C4NUM;
  562. int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
  563. int ic_remainder_ = channel % C4NUM;
  564. if (ic_remainder_ != 0) {
  565. int nhwc4_batch_offset = 0;
  566. for (int b = 0; b < batch; b++) {
  567. int batch_offset = b * channel * plane;
  568. for (int i = 0; i < plane; i++) {
  569. int8_t *dst_per_plane = (int8_t *)dst + nhwc4_batch_offset + i * c4_channel;
  570. memcpy(dst_per_plane, (int8_t *)src + batch_offset + i * channel, channel);
  571. for (int j = channel; j < c4_channel; ++j) {
  572. dst_per_plane[j] = 0;
  573. }
  574. }
  575. nhwc4_batch_offset += nhwc4_batch_unit_offset;
  576. }
  577. } else {
  578. size_t ori_input_size = batch * plane * channel;
  579. memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  580. }
  581. }
  582. void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
  583. int c4 = UP_DIV(channel, C4NUM);
  584. int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
  585. int ic_remainder_ = channel % C4NUM;
  586. if (ic_remainder_ != 0) {
  587. for (int b = 0; b < batch; b++) {
  588. int batch_offset = b * channel * plane;
  589. int nhwc4_batch_offset = b * nhwc4_batch_unit_offset;
  590. for (int i = 0; i < plane; i++) {
  591. memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc4_batch_offset + i * c4 * C4NUM,
  592. channel);
  593. }
  594. }
  595. } else {
  596. size_t ori_input_size = batch * plane * channel;
  597. memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  598. }
  599. }
  600. void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) {
  601. int c8 = UP_DIV(channel, C8NUM);
  602. int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  603. int ic_remainder_ = channel % C8NUM;
  604. if (ic_remainder_ != 0) {
  605. int nhwc8_batch_offset = 0;
  606. for (int b = 0; b < batch; b++) {
  607. int batch_offset = b * channel * plane;
  608. for (int i = 0; i < plane; i++) {
  609. memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel,
  610. channel);
  611. }
  612. nhwc8_batch_offset += nhwc8_batch_unit_offset;
  613. }
  614. } else {
  615. size_t ori_input_size = batch * plane * channel;
  616. memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  617. }
  618. }
  619. void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
  620. int c8 = UP_DIV(channel, C8NUM);
  621. int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  622. int ic_remainder_ = channel % C8NUM;
  623. if (ic_remainder_ != 0) {
  624. for (int b = 0; b < batch; b++) {
  625. int batch_offset = b * channel * plane;
  626. int nhwc8_batch_offset = b * nhwc8_batch_unit_offset;
  627. for (int i = 0; i < plane; i++) {
  628. memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM,
  629. channel);
  630. }
  631. }
  632. } else {
  633. size_t ori_input_size = batch * plane * channel;
  634. memcpy((int8_t *)dst, (int8_t *)src, ori_input_size);
  635. }
  636. }
  637. void PackNCHWToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) {
  638. int c8 = UP_DIV(channel, C8NUM);
  639. for (int b = 0; b < batch; b++) {
  640. int src_offset = b * plane * channel;
  641. int dst_offset = b * plane * c8 * C8NUM;
  642. for (int c = 0; c < channel; c++) {
  643. int c8_block_num = c / C8NUM;
  644. int c8_block_rem = c % C8NUM;
  645. int src_c_offset = src_offset + c * plane;
  646. int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM;
  647. for (int k = 0; k < plane; k++) {
  648. int src_kernel_offset = src_c_offset + k;
  649. int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem;
  650. ((int8_t *)dst + dst_kernel_offset)[0] = ((int8_t *)src + src_kernel_offset)[0];
  651. }
  652. }
  653. }
  654. }
  655. void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
  656. int c4 = UP_DIV(channel, C4NUM);
  657. for (int b = 0; b < batch; b++) {
  658. int src_offset = b * plane * c4 * C4NUM;
  659. int dst_offset = b * plane * channel;
  660. for (int k = 0; k < plane; k++) {
  661. int src_kernel_offset = src_offset + k * C4NUM;
  662. int dst_kernel_offset = dst_offset + k * channel;
  663. for (int c = 0; c < c4 - 1; c++) {
  664. int src_c_offset = src_kernel_offset + c * plane * C4NUM;
  665. int dst_c_offset = dst_kernel_offset + c * C4NUM;
  666. ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0];
  667. ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1];
  668. ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2];
  669. ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3];
  670. }
  671. // res part
  672. int res_c = channel - (c4 - 1) * C4NUM;
  673. for (int i = 0; i < res_c; i++) {
  674. int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
  675. int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
  676. ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0];
  677. }
  678. }
  679. }
  680. }
  681. void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) {
  682. for (int n = 0; n < batch; n++) {
  683. for (int hw = 0; hw < plane; hw++) {
  684. for (int c = 0; c < channel; c++) {
  685. int c8div = c / C8NUM;
  686. int c8mod = c % C8NUM;
  687. int src_index = n * plane * channel + hw * channel + c;
  688. int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
  689. ((int8_t *)dst)[dst_index] = ((int8_t *)src)[src_index];
  690. }
  691. }
  692. }
  693. return;
  694. }
  695. void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) {
  696. for (int n = 0; n < batch; n++) {
  697. for (int c = 0; c < channel; c++) {
  698. for (int hw = 0; hw < plane; hw++) {
  699. int nhwc_index = n * channel * plane + hw * channel + c;
  700. int nchw_index = n * channel * plane + c * plane + hw;
  701. ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index];
  702. }
  703. }
  704. }
  705. return;
  706. }
// Transpose float32 data from NHWC to NCHW layout, batch by batch.
// The (plane, channel) matrix of each batch is processed in 8x8 tiles so the hot
// path can use the NEON transpose kernels below; the row/column remainders are
// handled by scalar loops after the tiled region.
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
  int hw8 = plane / C8NUM * C8NUM;   // plane rounded down to a multiple of 8
  int c8 = channel / C8NUM * C8NUM;  // channel rounded down to a multiple of 8
  int batch = plane * channel;       // element count of one batch
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = 0;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        // Top-left corner of the current 8x8 tile in source (row stride = channel)
        // and destination (row stride = plane).
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
        // AArch64 8x8 float transpose: eight rows are loaded with post-indexed
        // ld1 (advancing by srcStride), interleaved with zip1/zip2 (32-bit) and
        // trn1/trn2 (64-bit) stages that assemble the transposed rows in
        // v16..v31, which are stored with dstStride between rows.
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov x10, %[src_ptr]\n"
          "mov x11, %[dst_ptr]\n"
          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"
          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"
          "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
          "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
          "trn1 v16.2d, v8.2d, v10.2d\n"
          "trn2 v18.2d, v8.2d, v10.2d\n"
          "trn1 v20.2d, v9.2d, v11.2d\n"
          "trn2 v22.2d, v9.2d, v11.2d\n"
          "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
          "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
          "trn1 v24.2d, v12.2d, v14.2d\n"
          "trn2 v26.2d, v12.2d, v14.2d\n"
          "trn1 v28.2d, v13.2d, v15.2d\n"
          "trn2 v30.2d, v13.2d, v15.2d\n"
          "zip1 v8.4s, v0.4s, v2.4s\n"
          "zip2 v9.4s, v0.4s, v2.4s\n"
          "zip1 v12.4s, v1.4s, v3.4s\n"
          "zip2 v13.4s, v1.4s, v3.4s\n"
          "zip1 v10.4s, v4.4s, v6.4s\n"
          "zip2 v11.4s, v4.4s, v6.4s\n"
          "zip1 v14.4s, v5.4s, v7.4s\n"
          "zip2 v15.4s, v5.4s, v7.4s\n"
          "trn1 v17.2d, v8.2d, v10.2d\n"
          "trn2 v19.2d, v8.2d, v10.2d\n"
          "trn1 v21.2d, v9.2d, v11.2d\n"
          "trn2 v23.2d, v9.2d, v11.2d\n"
          "trn1 v25.2d, v12.2d, v14.2d\n"
          "trn2 v27.2d, v12.2d, v14.2d\n"
          "trn1 v29.2d, v13.2d, v15.2d\n"
          "trn2 v31.2d, v13.2d, v15.2d\n"
          "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
          "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
          "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
          "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
          "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
          "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
          "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
          "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
            "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
            "v30", "v31");
#elif ENABLE_ARM32
        // ARM32 NEON 8x8 float transpose using vtrn.32 within q registers plus
        // vswp of d halves across register pairs; the transposed tile is stored
        // as two 4-float halves per destination row (r12 and r10 = r12 + 16).
        size_t srcStride = channel * sizeof(float);
        size_t dstStride = plane * sizeof(float);
        asm volatile(
          "mov r10, %[src_ptr]\n"
          "mov r12, %[dst_ptr]\n"
          "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
          "vld1.32 {q2, q3}, [r10], %[srcStride]\n"
          "vtrn.32 d0, d4\n"
          "vtrn.32 d1, d5\n"
          "vtrn.32 d2, d6\n"
          "vtrn.32 d3, d7\n"
          "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
          "vld1.32 {q6, q7}, [r10], %[srcStride]\n"
          "vtrn.32 d8, d12\n"
          "vtrn.32 d9, d13\n"
          "vtrn.32 d10, d14\n"
          "vtrn.32 d11, d15\n"
          "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
          "vld1.32 {q10, q11}, [r10], %[srcStride]\n"
          "vswp d1, d8\n"
          "vswp d3, d10\n"
          "vswp d5, d12\n"
          "vswp d7, d14\n"
          "vtrn.32 d16, d20\n"
          "vtrn.32 d17, d21\n"
          "vtrn.32 d18, d22\n"
          "vtrn.32 d19, d23\n"
          "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
          "vld1.32 {q14, q15}, [r10], %[srcStride]\n"
          "vtrn.32 d24, d28\n"
          "vtrn.32 d25, d29\n"
          "vtrn.32 d26, d30\n"
          "vtrn.32 d27, d31\n"
          "vswp d17, d24\n"
          "vswp d19, d26\n"
          "vswp d21, d28\n"
          "vswp d23, d30\n"
          "add r10, r12, #16\n"
          "vst1.32 {q0}, [r12], %[dstStride]\n"
          "vst1.32 {q8}, [r10], %[dstStride]\n"
          "vst1.32 {q2}, [r12], %[dstStride]\n"
          "vst1.32 {q10}, [r10], %[dstStride]\n"
          "vst1.32 {q4}, [r12], %[dstStride]\n"
          "vst1.32 {q12}, [r10], %[dstStride]\n"
          "vst1.32 {q6}, [r12], %[dstStride]\n"
          "vst1.32 {q14}, [r10], %[dstStride]\n"
          "vst1.32 {q1}, [r12], %[dstStride]\n"
          "vst1.32 {q9}, [r10], %[dstStride]\n"
          "vst1.32 {q3}, [r12], %[dstStride]\n"
          "vst1.32 {q11}, [r10], %[dstStride]\n"
          "vst1.32 {q5}, [r12], %[dstStride]\n"
          "vst1.32 {q13}, [r10], %[dstStride]\n"
          "vst1.32 {q7}, [r12], %[dstStride]\n"
          "vst1.32 {q15}, [r10], %[dstStride]\n"
          :
          :
          [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
          : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
            "q15");
#else
        // Portable scalar fallback: plain 8x8 tile transpose.
        for (int tr = 0; tr < C8NUM; tr++) {
          for (int tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
#endif
      }
      // Channel remainder of the current 8-row strip: one column at a time.
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    // Plane remainder: remaining rows transposed scalar-wise.
    // NOTE(review): `i < channel` compares size_t against int; harmless as long as
    // channel is non-negative — confirm callers never pass a negative channel.
    for (; hw < plane; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
  return;
}
  866. void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) {
  867. for (int n = 0; n < batch; n++) {
  868. for (int c = 0; c < channel; c++) {
  869. for (int hw = 0; hw < plane; hw++) {
  870. int nhwc_index = n * channel * plane + hw * channel + c;
  871. int nchw_index = n * channel * plane + c * plane + hw;
  872. ((int8_t *)dst)[nchw_index] = ((int8_t *)src)[nhwc_index];
  873. }
  874. }
  875. }
  876. return;
  877. }
  878. void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  879. return PackNHWCToNCHWFp32(src, dst, batch, channel, plane);
  880. }
  881. void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) {
  882. int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
  883. int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
  884. int unit = conv_param->input_h_ * conv_param->input_w_;
  885. for (int b = 0; b < conv_param->input_batch_; b++) {
  886. const int8_t *src_b = src + b * unit * conv_param->input_channel_;
  887. int16_t *dst_b = dst + b * unit * ic4 * C4NUM;
  888. for (int k = 0; k < unit; k++) {
  889. const int8_t *src_k = src_b + k * conv_param->input_channel_;
  890. int16_t *dst_k = dst_b + k * ic4 * C4NUM;
  891. for (int c = 0; c < conv_param->input_channel_; c++) {
  892. dst_k[c] = (int16_t)(src_k[c] - input_zp);
  893. }
  894. }
  895. }
  896. }
  897. void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
  898. ConvQuantArg *quant_qrg) {
  899. int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
  900. for (int c = 0; c < channel; c++) {
  901. if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
  902. weight_zp = quant_qrg->filter_quant_args_[c].zp_;
  903. }
  904. int c8_block_num = c / C8NUM;
  905. int c8_block_rem = c % C8NUM;
  906. const int8_t *src_c = origin_weight + c * plane;
  907. int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM;
  908. for (int k = 0; k < plane; k++) {
  909. const int8_t *src_kernel = src_c + k;
  910. int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem;
  911. *dst_kernel = (int16_t)(src_kernel[0] - weight_zp);
  912. }
  913. }
  914. }
  915. void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
  916. ConvQuantArg *quant_qrg) {
  917. int weight_zp = quant_qrg->filter_quant_args_[0].zp_;
  918. for (int c = 0; c < channel; c++) {
  919. if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) {
  920. weight_zp = quant_qrg->filter_quant_args_[c].zp_;
  921. }
  922. int c4_block_num = c / C4NUM;
  923. int c4_block_rem = c % C4NUM;
  924. const int8_t *src_c = origin_weight + c * plane;
  925. int16_t *dst_c = packed_weight_ + c4_block_num * plane * C4NUM;
  926. for (int k = 0; k < plane; k++) {
  927. const int8_t *src_kernel = src_c + k;
  928. int16_t *dst_kernel = dst_c + C4NUM * k + c4_block_rem;
  929. *dst_kernel = (int16_t)(src_kernel[0] - weight_zp);
  930. }
  931. }
  932. }