You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fuse_static_batchnorm.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "fuse_static_batchnorm.h"
  15. #include "pass_level2.h"
  16. #include <math.h>
  17. #include <string.h>
  18. namespace pnnx {
  19. class fuse_static_Fbatchnorm_pass_1d : public GraphRewriterPass
  20. {
  21. public:
  22. const char* match_pattern_graph() const
  23. {
  24. return R"PNNXIR(7767517
  25. 5 4
  26. pnnx.Input input 0 1 input
  27. pnnx.Attribute op_mean 0 1 running_mean @qwq
  28. pnnx.Attribute op_var 0 1 running_var @qwq
  29. F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
  30. pnnx.Output output 1 0 out
  31. )PNNXIR";
  32. }
  33. const char* type_str() const
  34. {
  35. return "nn.BatchNorm1d";
  36. }
  37. const char* name_str() const
  38. {
  39. return "batchnorm";
  40. }
  41. bool match(const std::map<std::string, const Operator*>& matched_operators) const
  42. {
  43. int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
  44. return input_rank == 2 || input_rank == 3;
  45. }
  46. void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
  47. {
  48. Attribute running_mean;
  49. Attribute running_var;
  50. for (const auto& x : captured_attrs)
  51. {
  52. if (x.first.substr(0, 8) == "op_mean.")
  53. running_mean = x.second;
  54. if (x.first.substr(0, 7) == "op_var.")
  55. running_var = x.second;
  56. }
  57. op->params["num_features"] = running_mean.shape[0];
  58. op->params["eps"] = captured_params.at("eps");
  59. op->params["affine"] = false;
  60. op->attrs["running_mean"] = running_mean;
  61. op->attrs["running_var"] = running_var;
  62. }
  63. };
  64. class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass
  65. {
  66. public:
  67. const char* match_pattern_graph() const
  68. {
  69. return R"PNNXIR(7767517
  70. 7 6
  71. pnnx.Input input 0 1 input
  72. pnnx.Attribute op_mean 0 1 running_mean @qwq
  73. pnnx.Attribute op_var 0 1 running_var @qwq
  74. pnnx.Attribute op_weight 0 1 weight @qwq
  75. pnnx.Attribute op_bias 0 1 bias @qwq
  76. F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
  77. pnnx.Output output 1 0 out
  78. )PNNXIR";
  79. }
  80. const char* type_str() const
  81. {
  82. return "nn.BatchNorm1d";
  83. }
  84. const char* name_str() const
  85. {
  86. return "batchnorm";
  87. }
  88. bool match(const std::map<std::string, const Operator*>& matched_operators) const
  89. {
  90. int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
  91. return input_rank == 2 || input_rank == 3;
  92. }
  93. void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
  94. {
  95. Attribute running_mean;
  96. Attribute running_var;
  97. Attribute weight;
  98. Attribute bias;
  99. for (const auto& x : captured_attrs)
  100. {
  101. if (x.first.substr(0, 8) == "op_mean.")
  102. running_mean = x.second;
  103. if (x.first.substr(0, 7) == "op_var.")
  104. running_var = x.second;
  105. if (x.first.substr(0, 10) == "op_weight.")
  106. weight = x.second;
  107. if (x.first.substr(0, 8) == "op_bias.")
  108. bias = x.second;
  109. }
  110. op->params["num_features"] = running_mean.shape[0];
  111. op->params["eps"] = captured_params.at("eps");
  112. op->params["affine"] = true;
  113. op->attrs["running_mean"] = running_mean;
  114. op->attrs["running_var"] = running_var;
  115. op->attrs["weight"] = weight;
  116. op->attrs["bias"] = bias;
  117. }
  118. };
  119. class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass
  120. {
  121. public:
  122. const char* match_pattern_graph() const
  123. {
  124. return R"PNNXIR(7767517
  125. 5 4
  126. pnnx.Input input 0 1 input
  127. pnnx.Attribute op_mean 0 1 running_mean @qwq
  128. pnnx.Attribute op_var 0 1 running_var @qwq
  129. F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
  130. pnnx.Output output 1 0 out
  131. )PNNXIR";
  132. }
  133. const char* type_str() const
  134. {
  135. return "nn.BatchNorm2d";
  136. }
  137. const char* name_str() const
  138. {
  139. return "batchnorm";
  140. }
  141. bool match(const std::map<std::string, const Operator*>& matched_operators) const
  142. {
  143. int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
  144. return input_rank == 4;
  145. }
  146. void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
  147. {
  148. Attribute running_mean;
  149. Attribute running_var;
  150. for (const auto& x : captured_attrs)
  151. {
  152. if (x.first.substr(0, 8) == "op_mean.")
  153. running_mean = x.second;
  154. if (x.first.substr(0, 7) == "op_var.")
  155. running_var = x.second;
  156. }
  157. op->params["num_features"] = running_mean.shape[0];
  158. op->params["eps"] = captured_params.at("eps");
  159. op->params["affine"] = false;
  160. op->attrs["running_mean"] = running_mean;
  161. op->attrs["running_var"] = running_var;
  162. }
  163. };
  164. class fuse_static_Fbatchnorm_pass_2d_1 : public GraphRewriterPass
  165. {
  166. public:
  167. const char* match_pattern_graph() const
  168. {
  169. return R"PNNXIR(7767517
  170. 7 6
  171. pnnx.Input input 0 1 input
  172. pnnx.Attribute op_mean 0 1 running_mean @qwq
  173. pnnx.Attribute op_var 0 1 running_var @qwq
  174. pnnx.Attribute op_weight 0 1 weight @qwq
  175. pnnx.Attribute op_bias 0 1 bias @qwq
  176. F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
  177. pnnx.Output output 1 0 out
  178. )PNNXIR";
  179. }
  180. const char* type_str() const
  181. {
  182. return "nn.BatchNorm2d";
  183. }
  184. const char* name_str() const
  185. {
  186. return "batchnorm";
  187. }
  188. bool match(const std::map<std::string, const Operator*>& matched_operators) const
  189. {
  190. int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
  191. return input_rank == 4;
  192. }
  193. void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
  194. {
  195. Attribute running_mean;
  196. Attribute running_var;
  197. Attribute weight;
  198. Attribute bias;
  199. for (const auto& x : captured_attrs)
  200. {
  201. if (x.first.substr(0, 8) == "op_mean.")
  202. running_mean = x.second;
  203. if (x.first.substr(0, 7) == "op_var.")
  204. running_var = x.second;
  205. if (x.first.substr(0, 10) == "op_weight.")
  206. weight = x.second;
  207. if (x.first.substr(0, 8) == "op_bias.")
  208. bias = x.second;
  209. }
  210. op->params["num_features"] = running_mean.shape[0];
  211. op->params["eps"] = captured_params.at("eps");
  212. op->params["affine"] = true;
  213. op->attrs["running_mean"] = running_mean;
  214. op->attrs["running_var"] = running_var;
  215. op->attrs["weight"] = weight;
  216. op->attrs["bias"] = bias;
  217. }
  218. };
  219. class fuse_static_Fbatchnorm_pass_3d : public GraphRewriterPass
  220. {
  221. public:
  222. const char* match_pattern_graph() const
  223. {
  224. return R"PNNXIR(7767517
  225. 5 4
  226. pnnx.Input input 0 1 input
  227. pnnx.Attribute op_mean 0 1 running_mean @qwq
  228. pnnx.Attribute op_var 0 1 running_var @qwq
  229. F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps
  230. pnnx.Output output 1 0 out
  231. )PNNXIR";
  232. }
  233. const char* type_str() const
  234. {
  235. return "nn.BatchNorm3d";
  236. }
  237. const char* name_str() const
  238. {
  239. return "batchnorm";
  240. }
  241. bool match(const std::map<std::string, const Operator*>& matched_operators) const
  242. {
  243. int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
  244. return input_rank == 5;
  245. }
  246. void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
  247. {
  248. Attribute running_mean;
  249. Attribute running_var;
  250. for (const auto& x : captured_attrs)
  251. {
  252. if (x.first.substr(0, 8) == "op_mean.")
  253. running_mean = x.second;
  254. if (x.first.substr(0, 7) == "op_var.")
  255. running_var = x.second;
  256. }
  257. op->params["num_features"] = running_mean.shape[0];
  258. op->params["eps"] = captured_params.at("eps");
  259. op->params["affine"] = false;
  260. op->attrs["running_mean"] = running_mean;
  261. op->attrs["running_var"] = running_var;
  262. }
  263. };
  264. class fuse_static_Fbatchnorm_pass_3d_1 : public GraphRewriterPass
  265. {
  266. public:
  267. const char* match_pattern_graph() const
  268. {
  269. return R"PNNXIR(7767517
  270. 7 6
  271. pnnx.Input input 0 1 input
  272. pnnx.Attribute op_mean 0 1 running_mean @qwq
  273. pnnx.Attribute op_var 0 1 running_var @qwq
  274. pnnx.Attribute op_weight 0 1 weight @qwq
  275. pnnx.Attribute op_bias 0 1 bias @qwq
  276. F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps
  277. pnnx.Output output 1 0 out
  278. )PNNXIR";
  279. }
  280. const char* type_str() const
  281. {
  282. return "nn.BatchNorm3d";
  283. }
  284. const char* name_str() const
  285. {
  286. return "batchnorm";
  287. }
  288. bool match(const std::map<std::string, const Operator*>& matched_operators) const
  289. {
  290. int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size();
  291. return input_rank == 5;
  292. }
  293. void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
  294. {
  295. Attribute running_mean;
  296. Attribute running_var;
  297. Attribute weight;
  298. Attribute bias;
  299. for (const auto& x : captured_attrs)
  300. {
  301. if (x.first.substr(0, 8) == "op_mean.")
  302. running_mean = x.second;
  303. if (x.first.substr(0, 7) == "op_var.")
  304. running_var = x.second;
  305. if (x.first.substr(0, 10) == "op_weight.")
  306. weight = x.second;
  307. if (x.first.substr(0, 8) == "op_bias.")
  308. bias = x.second;
  309. }
  310. op->params["num_features"] = running_mean.shape[0];
  311. op->params["eps"] = captured_params.at("eps");
  312. op->params["affine"] = true;
  313. op->attrs["running_mean"] = running_mean;
  314. op->attrs["running_var"] = running_var;
  315. op->attrs["weight"] = weight;
  316. op->attrs["bias"] = bias;
  317. }
  318. };
  319. void fuse_static_batchnorm(Graph& graph)
  320. {
  321. fuse_static_Fbatchnorm_pass_1d a;
  322. fuse_static_Fbatchnorm_pass_2d b;
  323. fuse_static_Fbatchnorm_pass_3d c;
  324. fuse_static_Fbatchnorm_pass_1d_1 a1;
  325. fuse_static_Fbatchnorm_pass_2d_1 b1;
  326. fuse_static_Fbatchnorm_pass_3d_1 c1;
  327. int opindex = 0;
  328. pnnx_graph_rewrite(graph, &a, opindex);
  329. pnnx_graph_rewrite(graph, &b, opindex);
  330. pnnx_graph_rewrite(graph, &c, opindex);
  331. pnnx_graph_rewrite(graph, &a1, opindex);
  332. pnnx_graph_rewrite(graph, &b1, opindex);
  333. pnnx_graph_rewrite(graph, &c1, opindex);
  334. }
  335. } // namespace pnnx