diff --git a/src/layer/reduction.cpp b/src/layer/reduction.cpp index 859974f53..4c84ef129 100644 --- a/src/layer/reduction.cpp +++ b/src/layer/reduction.cpp @@ -857,8 +857,8 @@ int Reduction::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int axis = axes_ptr[i]; // handle negative axis if (axis < 0) - axis += dims + 1; - axes_flag[axis - 1] = 1; + axis += dims; + axes_flag[axis] = 1; } if (dims == 1) diff --git a/tools/mlir/mlir2ncnn.cpp b/tools/mlir/mlir2ncnn.cpp index 5f2307438..82ebe7673 100644 --- a/tools/mlir/mlir2ncnn.cpp +++ b/tools/mlir/mlir2ncnn.cpp @@ -1474,11 +1474,11 @@ int main(int argc, char** argv) for (int i = 0; i < (int)v.size(); i++) { if (v[i] == 1) - fprintf(pp, ",2"); + fprintf(pp, ",1"); if (v[i] == 2) - fprintf(pp, ",3"); + fprintf(pp, ",2"); if (v[i] == 3) - fprintf(pp, ",1"); + fprintf(pp, ",0"); } fprintf(pp, " 4=%d", keep_dims); } diff --git a/tools/mxnet/mxnet2ncnn.cpp b/tools/mxnet/mxnet2ncnn.cpp index 71a59ae8e..61ce8d72d 100644 --- a/tools/mxnet/mxnet2ncnn.cpp +++ b/tools/mxnet/mxnet2ncnn.cpp @@ -2334,9 +2334,9 @@ int main(int argc, char** argv) fprintf(pp, " -23303=%zd", axis.size()); for (size_t j = 0; j < axis.size(); j++) { - if (axis[j] == 0 || axis[j] > 3 || axis[j] < -3) + if (axis[j] == 0 || axis[j] > 4 || axis[j] < -3) fprintf(stderr, "Unsupported reduction axis !\n"); - fprintf(pp, ",%d", axis[j]); + fprintf(pp, ",%d", axis[j] > 0 ? axis[j] - 1 : axis[j]); } } fprintf(pp, " 4=%d", keepdims); diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index 9be98054b..fd1702988 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -5372,9 +5372,9 @@ int main(int argc, char** argv) fprintf(pp, " -23303=%zu", axes.size()); for (size_t j = 0; j < axes.size(); j++) { - if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3) + if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) fprintf(stderr, "Unsupported reduction axes !\n"); - fprintf(pp, ",%d", axes[j]); + fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); } } else @@ -5772,9 +5772,9 @@ int main(int argc, char** argv) fprintf(pp, " -23303=%zu", axes.size()); for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3) + if (axes[i] == 0 || axes[i] > 4 || axes[i] < -3) fprintf(stderr, "Unsupported squeeze axes !\n"); - fprintf(pp, ",%d", axes[i]); + fprintf(pp, ",%d", axes[i] > 0 ? axes[i] - 1 : axes[i]); } } } @@ -5932,7 +5932,7 @@ int main(int argc, char** argv) { if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4) fprintf(stderr, "Unsupported unsqueeze axes !\n"); - fprintf(pp, ",%d", axes[i]); + fprintf(pp, ",%d", axes[i] > 0 ? axes[i] - 1 : axes[i]); } } else diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 4b67b695f..ad0786a02 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -170,6 +170,8 @@ set(pnnx_pass_level2_SRCS pass_level2/Tensor_select.cpp pass_level2/Tensor_slice.cpp pass_level2/Tensor_view.cpp + pass_level2/torch_amax.cpp + pass_level2/torch_amin.cpp pass_level2/torch_argmax.cpp pass_level2/torch_argmin.cpp pass_level2/torch_cat.cpp @@ -178,12 +180,14 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_clone.cpp pass_level2/torch_dequantize.cpp pass_level2/torch_flatten.cpp + pass_level2/torch_logsumexp.cpp pass_level2/torch_mean.cpp pass_level2/torch_normal.cpp + pass_level2/torch_prod.cpp pass_level2/torch_quantize_per_tensor.cpp - pass_level2/torch_sum.cpp pass_level2/torch_split.cpp pass_level2/torch_squeeze.cpp + pass_level2/torch_sum.cpp pass_level2/torch_permute.cpp pass_level2/torch_transpose.cpp pass_level2/torch_unsqueeze.cpp @@ -358,12 +362,17 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/Tensor_reshape.cpp pass_ncnn/Tensor_slice.cpp pass_ncnn/Tensor_view.cpp + pass_ncnn/torch_amax.cpp + pass_ncnn/torch_amin.cpp pass_ncnn/torch_clamp.cpp pass_ncnn/torch_clone.cpp pass_ncnn/torch_flatten.cpp + pass_ncnn/torch_logsumexp.cpp pass_ncnn/torch_mean.cpp pass_ncnn/torch_permute.cpp + pass_ncnn/torch_prod.cpp pass_ncnn/torch_squeeze.cpp + pass_ncnn/torch_sum.cpp pass_ncnn/torch_transpose.cpp pass_ncnn/torch_unsqueeze.cpp ) diff --git a/tools/pnnx/src/pass_level2/torch_amax.cpp b/tools/pnnx/src/pass_level2/torch_amax.cpp new file mode 100644 index 000000000..2c8636c65 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_amax.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_amax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 keepdim value=%keepdim +aten::amax op_1 3 1 input dim keepdim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.amax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_amax, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_amin.cpp b/tools/pnnx/src/pass_level2/torch_amin.cpp new file mode 100644 index 000000000..46201c714 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_amin.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_amin : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 keepdim value=%keepdim +aten::amin op_1 3 1 input dim keepdim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.amin"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_amin, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_logsumexp.cpp b/tools/pnnx/src/pass_level2/torch_logsumexp.cpp new file mode 100644 index 000000000..c00a26e3e --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_logsumexp.cpp @@ -0,0 +1,42 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_logsumexp : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 keepdim value=%keepdim +aten::logsumexp op_1 3 1 input dim keepdim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.logsumexp"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_logsumexp, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_prod.cpp b/tools/pnnx/src/pass_level2/torch_prod.cpp new file mode 100644 index 000000000..bd3e49b8c --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_prod.cpp @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_prod : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +prim::Constant op_0 0 1 keepdim value=%keepdim +prim::Constant op_1 0 1 dtype value=* +aten::prod op_2 4 1 input dim keepdim dtype out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.prod"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp b/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp index bd18ca9a8..3c50b6f42 100644 --- a/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp +++ b/tools/pnnx/src/pass_ncnn/convert_torch_split.cpp @@ -31,7 +31,7 @@ void convert_torch_split(Graph& graph) op->name = std::string("split_") + std::to_string(op_index++); const Parameter& split_size_or_sections = op->params.at("split_size_or_sections"); - if (split_size_or_sections.type != 1 && split_size_or_sections.type != 5) + if (split_size_or_sections.type != 2 && split_size_or_sections.type != 5) { fprintf(stderr, "malformed split split_size_or_sections type %d\n", split_size_or_sections.type); continue; @@ -55,7 +55,7 @@ void convert_torch_split(Graph& graph) if (axis > batch_index) axis -= 1; - if (split_size_or_sections.type == 1) + if (split_size_or_sections.type == 2) { const size_t output_size = op->outputs.size(); op->params["0"].type = 5; diff --git a/tools/pnnx/src/pass_ncnn/torch_amax.cpp b/tools/pnnx/src/pass_ncnn/torch_amax.cpp new file mode 100644 index 000000000..020ccc172 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_amax.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_amax : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.amax op_0 1 1 input out dim=%dim keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "amax"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dim").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + // drop batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + op->params["0"] = 4; + op->params["1"] = 0; + op->params["3"] = new_dims; + op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_amax, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_amin.cpp b/tools/pnnx/src/pass_ncnn/torch_amin.cpp new file mode 100644 index 000000000..2d4233e86 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_amin.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_amin : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.amin op_0 1 1 input out dim=%dim keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "amin"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dim").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + // drop batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + op->params["0"] = 5; + op->params["1"] = 0; + op->params["3"] = new_dims; + op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_amin, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_logsumexp.cpp b/tools/pnnx/src/pass_ncnn/torch_logsumexp.cpp new file mode 100644 index 000000000..b72dd687b --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_logsumexp.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_logsumexp : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.logsumexp op_0 1 1 input out dim=%dim keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "logsumexp"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dim").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + // drop batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + op->params["0"] = 10; + op->params["1"] = 0; + op->params["3"] = new_dims; + op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_logsumexp, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_mean.cpp b/tools/pnnx/src/pass_ncnn/torch_mean.cpp index e0eecce29..28217375a 100644 --- a/tools/pnnx/src/pass_ncnn/torch_mean.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_mean.cpp @@ -26,38 +26,6 @@ public: return R"PNNXIR(7767517 3 2 pnnx.Input input 0 1 input -torch.mean op_0 1 1 input out dim=(2,3) keepdim=False -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* type_str() const - { - return "Pooling"; - } - - const char* name_str() const - { - return "gap"; - } - - void write(Operator* op, const std::map& /*captured_params*/) const - { - op->params["0"] = 1; - op->params["4"] = 1; - } -}; - -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) - -class torch_mean_1 : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -3 2 -pnnx.Input input 0 1 input torch.mean op_0 1 1 input out dim=%dim keepdim=%keepdim pnnx.Output output 1 0 out )PNNXIR"; @@ -79,7 +47,7 @@ pnnx.Output output 1 0 out const int batch_index = op->inputs[0]->params["__batch_index"].i; - // drop mean batch index + // drop batch index std::vector new_dims; for (int i = 0; i < (int)dims.size(); i++) { @@ -97,7 +65,7 @@ pnnx.Output output 1 0 out } }; -REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean_1, 20) +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) } // namespace ncnn diff --git a/tools/pnnx/src/pass_ncnn/torch_prod.cpp b/tools/pnnx/src/pass_ncnn/torch_prod.cpp new file mode 100644 index 000000000..ffb693ef9 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_prod.cpp @@ -0,0 +1,69 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_prod : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.prod op_0 1 1 input out dim=%dim keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "prod"; + } + + void write(Operator* op, const std::map& captured_params) const + { + int dim = captured_params.at("dim").i; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + if (dim == batch_index) + { + fprintf(stderr, "prod along batch axis is not supported\n"); + return; + } + + int new_dim = dim > batch_index ? dim - 1 : dim; + + op->params["0"] = 6; + op->params["1"] = 0; + op->params["3"] = std::vector{new_dim}; + op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_prod, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_sum.cpp b/tools/pnnx/src/pass_ncnn/torch_sum.cpp new file mode 100644 index 000000000..0baa01b31 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_sum.cpp @@ -0,0 +1,72 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_sum : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +torch.sum op_0 1 1 input out dim=%dim keepdim=%keepdim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Reduction"; + } + + const char* name_str() const + { + return "sum"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dim").ai; + + const int batch_index = op->inputs[0]->params["__batch_index"].i; + + // drop batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + op->params["0"] = 0; + op->params["1"] = 0; + op->params["3"] = new_dims; + op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_sum, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 8a1acb65e..9aaa4533e 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -161,6 +161,8 @@ pnnx_add_test(Tensor_select) pnnx_add_test(Tensor_slice) pnnx_add_test(Tensor_view) +pnnx_add_test(torch_amax) +pnnx_add_test(torch_amin) pnnx_add_test(torch_argmax) pnnx_add_test(torch_argmin) pnnx_add_test(torch_cat) @@ -168,8 +170,10 @@ pnnx_add_test(torch_chunk) pnnx_add_test(torch_clamp) pnnx_add_test(torch_clone) pnnx_add_test(torch_flatten) +pnnx_add_test(torch_logsumexp) pnnx_add_test(torch_mean) pnnx_add_test(torch_permute) +pnnx_add_test(torch_prod) pnnx_add_test(torch_sum) pnnx_add_test(torch_split) pnnx_add_test(torch_squeeze) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 78a061ac1..86100d603 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -122,11 +122,17 @@ pnnx_ncnn_add_test(Tensor_reshape) pnnx_ncnn_add_test(Tensor_slice) pnnx_ncnn_add_test(Tensor_view) +pnnx_ncnn_add_test(torch_amax) +pnnx_ncnn_add_test(torch_amin) pnnx_ncnn_add_test(torch_cat) pnnx_ncnn_add_test(torch_chunk) pnnx_ncnn_add_test(torch_clamp) pnnx_ncnn_add_test(torch_clone) +pnnx_ncnn_add_test(torch_logsumexp) +pnnx_ncnn_add_test(torch_mean) pnnx_ncnn_add_test(torch_permute) +pnnx_ncnn_add_test(torch_prod) +pnnx_ncnn_add_test(torch_sum) pnnx_ncnn_add_test(torch_squeeze) pnnx_ncnn_add_test(torch_transpose) pnnx_ncnn_add_test(torch_unsqueeze) diff --git a/tools/pnnx/tests/ncnn/test_torch_amax.py b/tools/pnnx/tests/ncnn/test_torch_amax.py new file mode 100644 index 000000000..041d7deb2 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_amax.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.amax(x, dim=0, keepdim=False) + y = torch.amax(y, dim=(1,2), keepdim=True) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_amax.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_amax.pt inputshape=[3,16],[5,9,11]") + + # ncnn inference + import test_torch_amax_ncnn + b = test_torch_amax_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_amin.py b/tools/pnnx/tests/ncnn/test_torch_amin.py new file mode 100644 index 000000000..26b485398 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_amin.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.amin(x, dim=0, keepdim=False) + y = torch.amin(y, dim=(1,2), keepdim=True) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_amin.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_amin.pt inputshape=[3,16],[5,9,11]") + + # ncnn inference + import test_torch_amin_ncnn + b = test_torch_amin_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_logsumexp.py b/tools/pnnx/tests/ncnn/test_torch_logsumexp.py new file mode 100644 index 000000000..6b50cd471 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_logsumexp.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.logsumexp(x, dim=0, keepdim=False) + y = torch.logsumexp(y, dim=(1,2), keepdim=True) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_logsumexp.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_logsumexp.pt inputshape=[3,16],[5,9,11]") + + # ncnn inference + import test_torch_logsumexp_ncnn + b = test_torch_logsumexp_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_mean.py b/tools/pnnx/tests/ncnn/test_torch_mean.py new file mode 100644 index 000000000..1dbd09da6 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_mean.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.mean(x, dim=0, keepdim=False) + y = torch.mean(y, dim=(1,2), keepdim=True) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_mean.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_mean.pt inputshape=[3,16],[5,9,11]") + + # ncnn inference + import test_torch_mean_ncnn + b = test_torch_mean_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_prod.py b/tools/pnnx/tests/ncnn/test_torch_prod.py new file mode 100644 index 000000000..37b2b2750 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_prod.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.prod(x, dim=0, keepdim=False) + y = torch.prod(y, dim=1, keepdim=True) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_prod.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_prod.pt inputshape=[3,16],[5,9,11]") + + # ncnn inference + import test_torch_prod_ncnn + b = test_torch_prod_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_sum.py b/tools/pnnx/tests/ncnn/test_torch_sum.py new file mode 100644 index 000000000..f08c891ac --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_sum.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + x = torch.sum(x, dim=0, keepdim=False) + y = torch.sum(y, dim=(1,2), keepdim=True) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_sum.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_sum.pt inputshape=[3,16],[5,9,11]") + + # ncnn inference + import test_torch_sum_ncnn + b = test_torch_sum_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_amax.py b/tools/pnnx/tests/test_torch_amax.py new file mode 100644 index 000000000..7f94f9625 --- /dev/null +++ b/tools/pnnx/tests/test_torch_amax.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.amax(x, dim=1, keepdim=False) + y = torch.amax(y, dim=(2,3), keepdim=False) + z = torch.amax(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_amax.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_amax.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_amax_pnnx + b = test_torch_amax_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_amin.py b/tools/pnnx/tests/test_torch_amin.py new file mode 100644 index 000000000..b18d02bc5 --- /dev/null +++ b/tools/pnnx/tests/test_torch_amin.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.amin(x, dim=1, keepdim=False) + y = torch.amin(y, dim=(2,3), keepdim=False) + z = torch.amin(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_amin.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_amin.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_amin_pnnx + b = test_torch_amin_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_logsumexp.py b/tools/pnnx/tests/test_torch_logsumexp.py new file mode 100644 index 000000000..c3f1f5ff8 --- /dev/null +++ b/tools/pnnx/tests/test_torch_logsumexp.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.logsumexp(x, dim=1, keepdim=False) + y = torch.logsumexp(y, dim=(2,3), keepdim=False) + z = torch.logsumexp(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_logsumexp.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_logsumexp.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_logsumexp_pnnx + b = test_torch_logsumexp_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_prod.py b/tools/pnnx/tests/test_torch_prod.py new file mode 100644 index 000000000..c04b53833 --- /dev/null +++ b/tools/pnnx/tests/test_torch_prod.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.prod(x, dim=1, keepdim=False) + y = torch.prod(y, dim=2, keepdim=False) + z = torch.prod(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_prod.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_prod.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_prod_pnnx + b = test_torch_prod_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)