// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. #include "fuse_static_batchnorm.h" #include "pass_level2.h" #include #include namespace pnnx { class fuse_static_Fbatchnorm_pass_1d : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input pnnx.Attribute op_mean 0 1 running_mean @qwq pnnx.Attribute op_var 0 1 running_var @qwq F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "nn.BatchNorm1d"; } const char* name_str() const { return "batchnorm"; } bool match(const std::map& matched_operators) const { int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 2 || input_rank == 3; } void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { Attribute running_mean; Attribute running_var; for (const auto& x : captured_attrs) { if (x.first.substr(0, 8) == "op_mean.") running_mean = x.second; if (x.first.substr(0, 7) == "op_var.") running_var = x.second; } op->params["num_features"] = running_mean.shape[0]; op->params["eps"] = captured_params.at("eps"); op->params["affine"] = false; op->attrs["running_mean"] = running_mean; op->attrs["running_var"] = running_var; } }; class fuse_static_Fbatchnorm_pass_1d_1 : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input pnnx.Attribute op_mean 0 1 running_mean @qwq pnnx.Attribute op_var 0 1 running_var @qwq pnnx.Attribute op_weight 0 1 weight @qwq pnnx.Attribute op_bias 0 1 bias @qwq F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "nn.BatchNorm1d"; } const char* name_str() const { return "batchnorm"; } bool match(const std::map& matched_operators) const { int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 2 || input_rank == 3; } void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { Attribute running_mean; Attribute running_var; Attribute weight; Attribute bias; for (const auto& x : captured_attrs) { if (x.first.substr(0, 8) == "op_mean.") running_mean = x.second; if (x.first.substr(0, 7) == "op_var.") running_var = x.second; if (x.first.substr(0, 10) == "op_weight.") weight = x.second; if (x.first.substr(0, 8) == "op_bias.") bias = x.second; } op->params["num_features"] = running_mean.shape[0]; op->params["eps"] = captured_params.at("eps"); op->params["affine"] = true; op->attrs["running_mean"] = running_mean; op->attrs["running_var"] = running_var; op->attrs["weight"] = weight; op->attrs["bias"] = bias; } }; class fuse_static_Fbatchnorm_pass_2d : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input pnnx.Attribute op_mean 0 1 running_mean @qwq pnnx.Attribute op_var 0 1 running_var @qwq F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "nn.BatchNorm2d"; } const char* name_str() const { return "batchnorm"; } bool match(const std::map& matched_operators) const { int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 4; } void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { Attribute running_mean; Attribute running_var; for (const auto& x : captured_attrs) { if (x.first.substr(0, 8) == "op_mean.") running_mean = x.second; if (x.first.substr(0, 7) == "op_var.") running_var = x.second; } op->params["num_features"] = running_mean.shape[0]; op->params["eps"] = captured_params.at("eps"); op->params["affine"] = false; op->attrs["running_mean"] = running_mean; op->attrs["running_var"] = running_var; } }; class fuse_static_Fbatchnorm_pass_2d_1 : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input pnnx.Attribute op_mean 0 1 running_mean @qwq pnnx.Attribute op_var 0 1 running_var @qwq pnnx.Attribute op_weight 0 1 weight @qwq pnnx.Attribute op_bias 0 1 bias @qwq F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "nn.BatchNorm2d"; } const char* name_str() const { return "batchnorm"; } bool match(const std::map& matched_operators) const { int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 4; } void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { Attribute running_mean; Attribute running_var; Attribute weight; Attribute bias; for (const auto& x : captured_attrs) { if (x.first.substr(0, 8) == "op_mean.") running_mean = x.second; if (x.first.substr(0, 7) == "op_var.") running_var = x.second; if (x.first.substr(0, 10) == "op_weight.") weight = x.second; if (x.first.substr(0, 8) == "op_bias.") bias = x.second; } op->params["num_features"] = running_mean.shape[0]; op->params["eps"] = captured_params.at("eps"); op->params["affine"] = true; op->attrs["running_mean"] = running_mean; op->attrs["running_var"] = running_var; op->attrs["weight"] = weight; op->attrs["bias"] = bias; } }; class fuse_static_Fbatchnorm_pass_3d : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 5 4 pnnx.Input input 0 1 input pnnx.Attribute op_mean 0 1 running_mean @qwq pnnx.Attribute op_var 0 1 running_var @qwq F.batchnorm op_0 3 1 input running_mean running_var out weight=None bias=None eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "nn.BatchNorm3d"; } const char* name_str() const { return "batchnorm"; } bool match(const std::map& matched_operators) const { int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 5; } void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { Attribute running_mean; Attribute running_var; for (const auto& x : captured_attrs) { if (x.first.substr(0, 8) == "op_mean.") running_mean = x.second; if (x.first.substr(0, 7) == "op_var.") running_var = x.second; } op->params["num_features"] = running_mean.shape[0]; op->params["eps"] = captured_params.at("eps"); op->params["affine"] = false; op->attrs["running_mean"] = running_mean; op->attrs["running_var"] = running_var; } }; class fuse_static_Fbatchnorm_pass_3d_1 : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 7 6 pnnx.Input input 0 1 input pnnx.Attribute op_mean 0 1 running_mean @qwq pnnx.Attribute op_var 0 1 running_var @qwq pnnx.Attribute op_weight 0 1 weight @qwq pnnx.Attribute op_bias 0 1 bias @qwq F.batch_norm op_0 5 1 input running_mean running_var weight bias out eps=%eps pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "nn.BatchNorm3d"; } const char* name_str() const { return "batchnorm"; } bool match(const std::map& matched_operators) const { int input_rank = matched_operators.at("op_0")->inputs[0]->shape.size(); return input_rank == 5; } void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const { Attribute running_mean; Attribute running_var; Attribute weight; Attribute bias; for (const auto& x : captured_attrs) { if (x.first.substr(0, 8) == "op_mean.") running_mean = x.second; if (x.first.substr(0, 7) == "op_var.") running_var = x.second; if (x.first.substr(0, 10) == "op_weight.") weight = x.second; if (x.first.substr(0, 8) == "op_bias.") bias = x.second; } op->params["num_features"] = running_mean.shape[0]; op->params["eps"] = captured_params.at("eps"); op->params["affine"] = true; op->attrs["running_mean"] = running_mean; op->attrs["running_var"] = running_var; op->attrs["weight"] = weight; op->attrs["bias"] = bias; } }; void fuse_static_batchnorm(Graph& graph) { fuse_static_Fbatchnorm_pass_1d a; fuse_static_Fbatchnorm_pass_2d b; fuse_static_Fbatchnorm_pass_3d c; fuse_static_Fbatchnorm_pass_1d_1 a1; fuse_static_Fbatchnorm_pass_2d_1 b1; fuse_static_Fbatchnorm_pass_3d_1 c1; int opindex = 0; pnnx_graph_rewrite(graph, &a, opindex); pnnx_graph_rewrite(graph, &b, opindex); pnnx_graph_rewrite(graph, &c, opindex); pnnx_graph_rewrite(graph, &a1, opindex); pnnx_graph_rewrite(graph, &b1, opindex); pnnx_graph_rewrite(graph, &c1, opindex); } } // namespace pnnx