From 1be043aad55feb874fe78239cd7ea3164146aff5 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 20 Jan 2022 19:24:55 +0800 Subject: [PATCH] convert torch mean/sum/prod reduction with no args --- tools/pnnx/src/pass_ncnn/torch_mean.cpp | 33 +++++++++++++++++++++++++ tools/pnnx/src/pass_ncnn/torch_prod.cpp | 33 +++++++++++++++++++++++++ tools/pnnx/src/pass_ncnn/torch_sum.cpp | 33 +++++++++++++++++++++++++ 3 files changed, 99 insertions(+) diff --git a/tools/pnnx/src/pass_ncnn/torch_mean.cpp b/tools/pnnx/src/pass_ncnn/torch_mean.cpp index 28217375a..65577ff87 100644 --- a/tools/pnnx/src/pass_ncnn/torch_mean.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_mean.cpp @@ -67,6 +67,39 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) +class torch_mean_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.mean op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "mean"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 3; + op->params["1"] = 1; + op->params["4"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean_1, 20) + } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_prod.cpp b/tools/pnnx/src/pass_ncnn/torch_prod.cpp index ffb693ef9..a2e45f681 100644 --- a/tools/pnnx/src/pass_ncnn/torch_prod.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_prod.cpp @@ -64,6 +64,39 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_prod, 20) +class torch_prod_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.prod op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "prod"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 6; + op->params["1"] = 1; + op->params["4"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_prod_1, 20) + } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_sum.cpp b/tools/pnnx/src/pass_ncnn/torch_sum.cpp index 0baa01b31..ebd1c2f8f 100644 --- a/tools/pnnx/src/pass_ncnn/torch_sum.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_sum.cpp @@ -67,6 +67,39 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_sum, 20) +class torch_sum_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.sum op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "sum"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = 0; + op->params["1"] = 1; + op->params["4"] = 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_sum_1, 20) + } // namespace ncnn } // namespace pnnx