You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rnn.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "rnn.h"
  15. namespace ncnn {
  16. RNN::RNN()
  17. {
  18. one_blob_only = false;
  19. support_inplace = false;
  20. }
  21. int RNN::load_param(const ParamDict& pd)
  22. {
  23. num_output = pd.get(0, 0);
  24. weight_data_size = pd.get(1, 0);
  25. direction = pd.get(2, 0);
  26. int8_scale_term = pd.get(8, 0);
  27. if (int8_scale_term)
  28. {
  29. #if !NCNN_INT8
  30. NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference");
  31. return -1;
  32. #endif
  33. }
  34. return 0;
  35. }
  36. int RNN::load_model(const ModelBin& mb)
  37. {
  38. int num_directions = direction == 2 ? 2 : 1;
  39. int size = weight_data_size / num_directions / num_output;
  40. // raw weight data
  41. weight_xc_data = mb.load(size, num_output, num_directions, 0);
  42. if (weight_xc_data.empty())
  43. return -100;
  44. bias_c_data = mb.load(num_output, 1, num_directions, 0);
  45. if (bias_c_data.empty())
  46. return -100;
  47. weight_hc_data = mb.load(num_output, num_output, num_directions, 0);
  48. if (weight_hc_data.empty())
  49. return -100;
  50. #if NCNN_INT8
  51. if (int8_scale_term)
  52. {
  53. weight_xc_data_int8_scales = mb.load(num_output, num_directions, 1);
  54. weight_hc_data_int8_scales = mb.load(num_output, num_directions, 1);
  55. }
  56. #endif // NCNN_INT8
  57. return 0;
  58. }
  59. static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
  60. {
  61. int size = bottom_blob.w;
  62. int T = bottom_blob.h;
  63. int num_output = top_blob.w;
  64. // num_output
  65. Mat gates(num_output, 4u, opt.workspace_allocator);
  66. if (gates.empty())
  67. return -100;
  68. // unroll
  69. for (int t = 0; t < T; t++)
  70. {
  71. int ti = reverse ? T - 1 - t : t;
  72. const float* x = bottom_blob.row(ti);
  73. #pragma omp parallel for num_threads(opt.num_threads)
  74. for (int q = 0; q < num_output; q++)
  75. {
  76. const float* weight_xc_ptr = weight_xc.row(q);
  77. const float* weight_hc_ptr = weight_hc.row(q);
  78. float H = bias_c[q];
  79. for (int i = 0; i < size; i++)
  80. {
  81. H += weight_xc_ptr[i] * x[i];
  82. }
  83. for (int i = 0; i < num_output; i++)
  84. {
  85. H += weight_hc_ptr[i] * hidden_state[i];
  86. }
  87. H = tanhf(H);
  88. gates[q] = H;
  89. }
  90. float* output_data = top_blob.row(ti);
  91. #pragma omp parallel for num_threads(opt.num_threads)
  92. for (int q = 0; q < num_output; q++)
  93. {
  94. float H = gates[q];
  95. hidden_state[q] = H;
  96. output_data[q] = H;
  97. }
  98. }
  99. return 0;
  100. }
  #if NCNN_INT8
  // Int8 counterpart of rnn(): same recurrence and output layout, but both
  // matrix-vector products run on int8 data with int accumulation.
  //
  // The weights are already int8, with one scale per output row
  // (weight_xc_int8_scales / weight_hc_int8_scales). Activations are quantized
  // dynamically: each row of bottom_blob gets its own scale 127 / max|x|, and
  // the hidden state is re-quantized the same way at every step. Each int dot
  // product is mapped back to float by multiplying with the reciprocal of the
  // activation scale times the reciprocal of the weight scale.
  //
  // Returns 0, or -100 if a scratch buffer cannot be allocated.
  static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt)
  {
      int size = bottom_blob.w;
      int T = bottom_blob.h;
      int num_output = top_blob.w;

      // staging buffer for the new hidden state (num_output floats)
      Mat gates(num_output, 4u, opt.workspace_allocator);
      if (gates.empty())
          return -100;

      // The quantized copies never leave this function, so quantize_to_int8 is
      // given the workspace allocator as its blob allocator (presumably the
      // allocator it creates its output with). Built once, reused every step.
      Option opt_quant = opt;
      opt_quant.blob_allocator = opt.workspace_allocator;
      opt_quant.use_packing_layout = false;

      // dynamic quantize bottom_blob, one scale per time step
      Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator);
      if (bottom_blob_int8.empty())
          return -100;
      Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator);
      if (bottom_blob_int8_scales.empty())
          return -100;
      for (int t = 0; t < T; t++)
      {
          const float* x = bottom_blob.row(t);
          float absmax = 0.f;
          for (int i = 0; i < size; i++)
          {
              absmax = std::max(absmax, (float)fabs(x[i]));
          }
          // An all-zero row would get scale 127/0 = inf, and quantizing
          // 0 * inf yields NaN (assuming the quantizer multiplies by the scale).
          // Any finite scale maps zeros to zeros; use 1 like the hidden-state path.
          bottom_blob_int8_scales[t] = absmax == 0.f ? 1.f : 127.f / absmax;
      }
      quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant);

      Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator);
      if (hidden_state_int8.empty())
          return -100;
      Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator);
      if (hidden_state_int8_scales.empty())
          return -100;

      // unroll
      for (int t = 0; t < T; t++)
      {
          int ti = reverse ? T - 1 - t : t;

          // dynamic quantize hidden_state
          {
              float absmax = 0.f;
              for (int i = 0; i < num_output; i++)
              {
                  absmax = std::max(absmax, (float)fabs(hidden_state[i]));
              }

              if (absmax == 0.f)
              {
                  hidden_state_int8_scales[0] = 1.f;
                  hidden_state_int8.fill<signed char>(0);
              }
              else
              {
                  hidden_state_int8_scales[0] = 127.f / absmax;
                  quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant);
              }
          }

          const signed char* x = bottom_blob_int8.row<const signed char>(ti);
          const signed char* hs = hidden_state_int8;
          const float descale_x = 1.f / bottom_blob_int8_scales[ti];
          const float descale_h = 1.f / hidden_state_int8_scales[0];

          #pragma omp parallel for num_threads(opt.num_threads)
          for (int q = 0; q < num_output; q++)
          {
              const signed char* weight_xc_int8_ptr = weight_xc_int8.row<const signed char>(q);
              const signed char* weight_hc_int8_ptr = weight_hc_int8.row<const signed char>(q);

              const float descale_xc = 1.f / weight_xc_int8_scales[q];
              const float descale_hc = 1.f / weight_hc_int8_scales[q];

              int Hx = 0;
              for (int i = 0; i < size; i++)
              {
                  Hx += weight_xc_int8_ptr[i] * x[i];
              }

              int Hh = 0;
              for (int i = 0; i < num_output; i++)
              {
                  Hh += weight_hc_int8_ptr[i] * hs[i];
              }

              float H = bias_c[q] + Hx * (descale_x * descale_xc) + Hh * (descale_h * descale_hc);

              H = tanhf(H);

              gates[q] = H;
          }

          // publish only after all units are computed, so each unit above read
          // the previous step's quantized hidden state
          float* output_data = top_blob.row(ti);

          #pragma omp parallel for num_threads(opt.num_threads)
          for (int q = 0; q < num_output; q++)
          {
              float H = gates[q];

              hidden_state[q] = H;
              output_data[q] = H;
          }
      }

      return 0;
  }
  #endif // NCNN_INT8
  194. int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
  195. {
  196. int T = bottom_blob.h;
  197. int num_directions = direction == 2 ? 2 : 1;
  198. // initial hidden state
  199. Mat hidden(num_output, 4u, opt.workspace_allocator);
  200. if (hidden.empty())
  201. return -100;
  202. hidden.fill(0.f);
  203. top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
  204. if (top_blob.empty())
  205. return -100;
  206. // Uni directional
  207. if (direction == 0 || direction == 1)
  208. {
  209. #if NCNN_INT8
  210. if (int8_scale_term)
  211. {
  212. int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
  213. if (ret != 0)
  214. return ret;
  215. }
  216. else
  217. #endif
  218. {
  219. int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
  220. if (ret != 0)
  221. return ret;
  222. }
  223. }
  224. if (direction == 2)
  225. {
  226. Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
  227. if (top_blob_forward.empty())
  228. return -100;
  229. Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
  230. if (top_blob_reverse.empty())
  231. return -100;
  232. #if NCNN_INT8
  233. if (int8_scale_term)
  234. {
  235. int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
  236. if (ret != 0)
  237. return ret;
  238. }
  239. else
  240. #endif
  241. {
  242. int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
  243. if (ret != 0)
  244. return ret;
  245. }
  246. hidden.fill(0.0f);
  247. #if NCNN_INT8
  248. if (int8_scale_term)
  249. {
  250. int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt);
  251. if (ret != 0)
  252. return ret;
  253. }
  254. else
  255. #endif
  256. {
  257. int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt);
  258. if (ret != 0)
  259. return ret;
  260. }
  261. // concat w
  262. for (int i = 0; i < T; i++)
  263. {
  264. const float* pf = top_blob_forward.row(i);
  265. const float* pr = top_blob_reverse.row(i);
  266. float* ptr = top_blob.row(i);
  267. memcpy(ptr, pf, num_output * sizeof(float));
  268. memcpy(ptr + num_output, pr, num_output * sizeof(float));
  269. }
  270. }
  271. return 0;
  272. }
  273. int RNN::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
  274. {
  275. const Mat& bottom_blob = bottom_blobs[0];
  276. int T = bottom_blob.h;
  277. int num_directions = direction == 2 ? 2 : 1;
  278. Mat hidden;
  279. Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
  280. if (bottom_blobs.size() == 2)
  281. {
  282. hidden = bottom_blobs[1].clone(hidden_allocator);
  283. }
  284. else
  285. {
  286. hidden.create(num_output, num_directions, 4u, hidden_allocator);
  287. if (hidden.empty())
  288. return -100;
  289. hidden.fill(0.f);
  290. }
  291. Mat& top_blob = top_blobs[0];
  292. top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
  293. if (top_blob.empty())
  294. return -100;
  295. // Uni directional
  296. if (direction == 0 || direction == 1)
  297. {
  298. #if NCNN_INT8
  299. if (int8_scale_term)
  300. {
  301. int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt);
  302. if (ret != 0)
  303. return ret;
  304. }
  305. else
  306. #endif
  307. {
  308. int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
  309. if (ret != 0)
  310. return ret;
  311. }
  312. }
  313. if (direction == 2)
  314. {
  315. Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
  316. if (top_blob_forward.empty())
  317. return -100;
  318. Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
  319. if (top_blob_reverse.empty())
  320. return -100;
  321. Mat hidden0 = hidden.row_range(0, 1);
  322. #if NCNN_INT8
  323. if (int8_scale_term)
  324. {
  325. int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt);
  326. if (ret != 0)
  327. return ret;
  328. }
  329. else
  330. #endif
  331. {
  332. int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt);
  333. if (ret != 0)
  334. return ret;
  335. }
  336. Mat hidden1 = hidden.row_range(1, 1);
  337. #if NCNN_INT8
  338. if (int8_scale_term)
  339. {
  340. int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt);
  341. if (ret != 0)
  342. return ret;
  343. }
  344. else
  345. #endif
  346. {
  347. int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt);
  348. if (ret != 0)
  349. return ret;
  350. }
  351. // concat w
  352. for (int i = 0; i < T; i++)
  353. {
  354. const float* pf = top_blob_forward.row(i);
  355. const float* pr = top_blob_reverse.row(i);
  356. float* ptr = top_blob.row(i);
  357. memcpy(ptr, pf, num_output * sizeof(float));
  358. memcpy(ptr + num_output, pr, num_output * sizeof(float));
  359. }
  360. }
  361. if (top_blobs.size() == 2)
  362. {
  363. top_blobs[1] = hidden;
  364. }
  365. return 0;
  366. }
  367. } // namespace ncnn