| @@ -857,8 +857,8 @@ int Reduction::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| int axis = axes_ptr[i]; | |||
| // handle negative axis | |||
| if (axis < 0) | |||
| axis += dims + 1; | |||
| axes_flag[axis - 1] = 1; | |||
| axis += dims; | |||
| axes_flag[axis] = 1; | |||
| } | |||
| if (dims == 1) | |||
| @@ -1474,11 +1474,11 @@ int main(int argc, char** argv) | |||
| for (int i = 0; i < (int)v.size(); i++) | |||
| { | |||
| if (v[i] == 1) | |||
| fprintf(pp, ",2"); | |||
| fprintf(pp, ",1"); | |||
| if (v[i] == 2) | |||
| fprintf(pp, ",3"); | |||
| fprintf(pp, ",2"); | |||
| if (v[i] == 3) | |||
| fprintf(pp, ",1"); | |||
| fprintf(pp, ",0"); | |||
| } | |||
| fprintf(pp, " 4=%d", keep_dims); | |||
| } | |||
| @@ -2334,9 +2334,9 @@ int main(int argc, char** argv) | |||
| fprintf(pp, " -23303=%zd", axis.size()); | |||
| for (size_t j = 0; j < axis.size(); j++) | |||
| { | |||
| if (axis[j] == 0 || axis[j] > 3 || axis[j] < -3) | |||
| if (axis[j] == 0 || axis[j] > 4 || axis[j] < -3) | |||
| fprintf(stderr, "Unsupported reduction axis !\n"); | |||
| fprintf(pp, ",%d", axis[j]); | |||
| fprintf(pp, ",%d", axis[j] > 0 ? axis[j] - 1 : axis[j]); | |||
| } | |||
| } | |||
| fprintf(pp, " 4=%d", keepdims); | |||
| @@ -5372,9 +5372,9 @@ int main(int argc, char** argv) | |||
| fprintf(pp, " -23303=%zu", axes.size()); | |||
| for (size_t j = 0; j < axes.size(); j++) | |||
| { | |||
| if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3) | |||
| if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) | |||
| fprintf(stderr, "Unsupported reduction axes !\n"); | |||
| fprintf(pp, ",%d", axes[j]); | |||
| fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); | |||
| } | |||
| } | |||
| else | |||
| @@ -5772,9 +5772,9 @@ int main(int argc, char** argv) | |||
| fprintf(pp, " -23303=%zu", axes.size()); | |||
| for (int i = 0; i < (int)axes.size(); i++) | |||
| { | |||
| if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3) | |||
| if (axes[i] == 0 || axes[i] > 4 || axes[i] < -3) | |||
| fprintf(stderr, "Unsupported squeeze axes !\n"); | |||
| fprintf(pp, ",%d", axes[i]); | |||
| fprintf(pp, ",%d", axes[i] > 0 ? axes[i] - 1 : axes[i]); | |||
| } | |||
| } | |||
| } | |||
| @@ -5932,7 +5932,7 @@ int main(int argc, char** argv) | |||
| { | |||
| if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4) | |||
| fprintf(stderr, "Unsupported unsqueeze axes !\n"); | |||
| fprintf(pp, ",%d", axes[i]); | |||
| fprintf(pp, ",%d", axes[i] > 0 ? axes[i] - 1 : axes[i]); | |||
| } | |||
| } | |||
| else | |||
| @@ -170,6 +170,8 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/Tensor_select.cpp | |||
| pass_level2/Tensor_slice.cpp | |||
| pass_level2/Tensor_view.cpp | |||
| pass_level2/torch_amax.cpp | |||
| pass_level2/torch_amin.cpp | |||
| pass_level2/torch_argmax.cpp | |||
| pass_level2/torch_argmin.cpp | |||
| pass_level2/torch_cat.cpp | |||
| @@ -178,12 +180,14 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_clone.cpp | |||
| pass_level2/torch_dequantize.cpp | |||
| pass_level2/torch_flatten.cpp | |||
| pass_level2/torch_logsumexp.cpp | |||
| pass_level2/torch_mean.cpp | |||
| pass_level2/torch_normal.cpp | |||
| pass_level2/torch_prod.cpp | |||
| pass_level2/torch_quantize_per_tensor.cpp | |||
| pass_level2/torch_sum.cpp | |||
| pass_level2/torch_split.cpp | |||
| pass_level2/torch_squeeze.cpp | |||
| pass_level2/torch_sum.cpp | |||
| pass_level2/torch_permute.cpp | |||
| pass_level2/torch_transpose.cpp | |||
| pass_level2/torch_unsqueeze.cpp | |||
| @@ -358,12 +362,17 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/Tensor_reshape.cpp | |||
| pass_ncnn/Tensor_slice.cpp | |||
| pass_ncnn/Tensor_view.cpp | |||
| pass_ncnn/torch_amax.cpp | |||
| pass_ncnn/torch_amin.cpp | |||
| pass_ncnn/torch_clamp.cpp | |||
| pass_ncnn/torch_clone.cpp | |||
| pass_ncnn/torch_flatten.cpp | |||
| pass_ncnn/torch_logsumexp.cpp | |||
| pass_ncnn/torch_mean.cpp | |||
| pass_ncnn/torch_permute.cpp | |||
| pass_ncnn/torch_prod.cpp | |||
| pass_ncnn/torch_squeeze.cpp | |||
| pass_ncnn/torch_sum.cpp | |||
| pass_ncnn/torch_transpose.cpp | |||
| pass_ncnn/torch_unsqueeze.cpp | |||
| ) | |||
| @@ -0,0 +1,42 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_amax : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dim | |||
| prim::Constant op_0 0 1 keepdim value=%keepdim | |||
| aten::amax op_1 3 1 input dim keepdim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.amax"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_amax, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,42 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_amin : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dim | |||
| prim::Constant op_0 0 1 keepdim value=%keepdim | |||
| aten::amin op_1 3 1 input dim keepdim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.amin"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_amin, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,42 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_logsumexp : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dim | |||
| prim::Constant op_0 0 1 keepdim value=%keepdim | |||
| aten::logsumexp op_1 3 1 input dim keepdim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.logsumexp"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_logsumexp, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,43 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_prod : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dim | |||
| prim::Constant op_0 0 1 keepdim value=%keepdim | |||
| prim::Constant op_1 0 1 dtype value=* | |||
| aten::prod op_2 4 1 input dim keepdim dtype out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.prod"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod, 20) | |||
| } // namespace pnnx | |||
| @@ -31,7 +31,7 @@ void convert_torch_split(Graph& graph) | |||
| op->name = std::string("split_") + std::to_string(op_index++); | |||
| const Parameter& split_size_or_sections = op->params.at("split_size_or_sections"); | |||
| if (split_size_or_sections.type != 1 && split_size_or_sections.type != 5) | |||
| if (split_size_or_sections.type != 2 && split_size_or_sections.type != 5) | |||
| { | |||
| fprintf(stderr, "malformed split split_size_or_sections type %d\n", split_size_or_sections.type); | |||
| continue; | |||
| @@ -55,7 +55,7 @@ void convert_torch_split(Graph& graph) | |||
| if (axis > batch_index) | |||
| axis -= 1; | |||
| if (split_size_or_sections.type == 1) | |||
| if (split_size_or_sections.type == 2) | |||
| { | |||
| const size_t output_size = op->outputs.size(); | |||
| op->params["0"].type = 5; | |||
| @@ -0,0 +1,72 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_amax : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.amax op_0 1 1 input out dim=%dim keepdim=%keepdim | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Reduction"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "amax"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::vector<int>& dims = captured_params.at("dim").ai; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| // drop batch index | |||
| std::vector<int> new_dims; | |||
| for (int i = 0; i < (int)dims.size(); i++) | |||
| { | |||
| if (dims[i] == batch_index) | |||
| continue; | |||
| int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; | |||
| new_dims.push_back(new_dim); | |||
| } | |||
| op->params["0"] = 4; | |||
| op->params["1"] = 0; | |||
| op->params["3"] = new_dims; | |||
| op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_amax, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,72 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_amin : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.amin op_0 1 1 input out dim=%dim keepdim=%keepdim | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Reduction"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "amin"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::vector<int>& dims = captured_params.at("dim").ai; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| // drop batch index | |||
| std::vector<int> new_dims; | |||
| for (int i = 0; i < (int)dims.size(); i++) | |||
| { | |||
| if (dims[i] == batch_index) | |||
| continue; | |||
| int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; | |||
| new_dims.push_back(new_dim); | |||
| } | |||
| op->params["0"] = 5; | |||
| op->params["1"] = 0; | |||
| op->params["3"] = new_dims; | |||
| op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_amin, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,72 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_logsumexp : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.logsumexp op_0 1 1 input out dim=%dim keepdim=%keepdim | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Reduction"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "logsumexp"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::vector<int>& dims = captured_params.at("dim").ai; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| // drop batch index | |||
| std::vector<int> new_dims; | |||
| for (int i = 0; i < (int)dims.size(); i++) | |||
| { | |||
| if (dims[i] == batch_index) | |||
| continue; | |||
| int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; | |||
| new_dims.push_back(new_dim); | |||
| } | |||
| op->params["0"] = 10; | |||
| op->params["1"] = 0; | |||
| op->params["3"] = new_dims; | |||
| op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_logsumexp, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -26,38 +26,6 @@ public: | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.mean op_0 1 1 input out dim=(2,3) keepdim=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Pooling"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "gap"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const | |||
| { | |||
| op->params["0"] = 1; | |||
| op->params["4"] = 1; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) | |||
| class torch_mean_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.mean op_0 1 1 input out dim=%dim keepdim=%keepdim | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -79,7 +47,7 @@ pnnx.Output output 1 0 out | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| // drop mean batch index | |||
| // drop batch index | |||
| std::vector<int> new_dims; | |||
| for (int i = 0; i < (int)dims.size(); i++) | |||
| { | |||
| @@ -97,7 +65,7 @@ pnnx.Output output 1 0 out | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean_1, 20) | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_mean, 20) | |||
| } // namespace ncnn | |||
| @@ -0,0 +1,69 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_prod : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.prod op_0 1 1 input out dim=%dim keepdim=%keepdim | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Reduction"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "prod"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| int dim = captured_params.at("dim").i; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| if (dim == batch_index) | |||
| { | |||
| fprintf(stderr, "prod along batch axis is not supported\n"); | |||
| return; | |||
| } | |||
| int new_dim = dim > batch_index ? dim - 1 : dim; | |||
| op->params["0"] = 6; | |||
| op->params["1"] = 0; | |||
| op->params["3"] = std::vector<int>{new_dim}; | |||
| op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_prod, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,72 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_sum : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input | |||
| torch.sum op_0 1 1 input out dim=%dim keepdim=%keepdim | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Reduction"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "sum"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const std::vector<int>& dims = captured_params.at("dim").ai; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| // drop batch index | |||
| std::vector<int> new_dims; | |||
| for (int i = 0; i < (int)dims.size(); i++) | |||
| { | |||
| if (dims[i] == batch_index) | |||
| continue; | |||
| int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; | |||
| new_dims.push_back(new_dim); | |||
| } | |||
| op->params["0"] = 0; | |||
| op->params["1"] = 0; | |||
| op->params["3"] = new_dims; | |||
| op->params["4"] = captured_params.at("keepdim").b ? 1 : 0; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_sum, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -161,6 +161,8 @@ pnnx_add_test(Tensor_select) | |||
| pnnx_add_test(Tensor_slice) | |||
| pnnx_add_test(Tensor_view) | |||
| pnnx_add_test(torch_amax) | |||
| pnnx_add_test(torch_amin) | |||
| pnnx_add_test(torch_argmax) | |||
| pnnx_add_test(torch_argmin) | |||
| pnnx_add_test(torch_cat) | |||
| @@ -168,8 +170,10 @@ pnnx_add_test(torch_chunk) | |||
| pnnx_add_test(torch_clamp) | |||
| pnnx_add_test(torch_clone) | |||
| pnnx_add_test(torch_flatten) | |||
| pnnx_add_test(torch_logsumexp) | |||
| pnnx_add_test(torch_mean) | |||
| pnnx_add_test(torch_permute) | |||
| pnnx_add_test(torch_prod) | |||
| pnnx_add_test(torch_sum) | |||
| pnnx_add_test(torch_split) | |||
| pnnx_add_test(torch_squeeze) | |||
| @@ -122,11 +122,17 @@ pnnx_ncnn_add_test(Tensor_reshape) | |||
| pnnx_ncnn_add_test(Tensor_slice) | |||
| pnnx_ncnn_add_test(Tensor_view) | |||
| pnnx_ncnn_add_test(torch_amax) | |||
| pnnx_ncnn_add_test(torch_amin) | |||
| pnnx_ncnn_add_test(torch_cat) | |||
| pnnx_ncnn_add_test(torch_chunk) | |||
| pnnx_ncnn_add_test(torch_clamp) | |||
| pnnx_ncnn_add_test(torch_clone) | |||
| pnnx_ncnn_add_test(torch_logsumexp) | |||
| pnnx_ncnn_add_test(torch_mean) | |||
| pnnx_ncnn_add_test(torch_permute) | |||
| pnnx_ncnn_add_test(torch_prod) | |||
| pnnx_ncnn_add_test(torch_sum) | |||
| pnnx_ncnn_add_test(torch_squeeze) | |||
| pnnx_ncnn_add_test(torch_transpose) | |||
| pnnx_ncnn_add_test(torch_unsqueeze) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x = torch.amax(x, dim=0, keepdim=False) | |||
| y = torch.amax(y, dim=(1,2), keepdim=True) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_amax.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_amax.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_amax_ncnn | |||
| b = test_torch_amax_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x = torch.amin(x, dim=0, keepdim=False) | |||
| y = torch.amin(y, dim=(1,2), keepdim=True) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_amin.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_amin.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_amin_ncnn | |||
| b = test_torch_amin_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x = torch.logsumexp(x, dim=0, keepdim=False) | |||
| y = torch.logsumexp(y, dim=(1,2), keepdim=True) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_logsumexp.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_logsumexp.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_logsumexp_ncnn | |||
| b = test_torch_logsumexp_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x = torch.mean(x, dim=0, keepdim=False) | |||
| y = torch.mean(y, dim=(1,2), keepdim=True) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_mean.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_mean.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_mean_ncnn | |||
| b = test_torch_mean_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x = torch.prod(x, dim=0, keepdim=False) | |||
| y = torch.prod(y, dim=1, keepdim=True) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_prod.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_prod.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_prod_ncnn | |||
| b = test_torch_prod_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x = torch.sum(x, dim=0, keepdim=False) | |||
| y = torch.sum(y, dim=(1,2), keepdim=True) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_sum.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_sum.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_sum_ncnn | |||
| b = test_torch_sum_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.amax(x, dim=1, keepdim=False) | |||
| y = torch.amax(y, dim=(2,3), keepdim=False) | |||
| z = torch.amax(z, dim=0, keepdim=True) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_amax.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_amax.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_amax_pnnx | |||
| b = test_torch_amax_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.amin(x, dim=1, keepdim=False) | |||
| y = torch.amin(y, dim=(2,3), keepdim=False) | |||
| z = torch.amin(z, dim=0, keepdim=True) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_amin.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_amin.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_amin_pnnx | |||
| b = test_torch_amin_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.logsumexp(x, dim=1, keepdim=False) | |||
| y = torch.logsumexp(y, dim=(2,3), keepdim=False) | |||
| z = torch.logsumexp(z, dim=0, keepdim=True) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_logsumexp.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_logsumexp.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_logsumexp_pnnx | |||
| b = test_torch_logsumexp_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.prod(x, dim=1, keepdim=False) | |||
| y = torch.prod(y, dim=2, keepdim=False) | |||
| z = torch.prod(z, dim=0, keepdim=True) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_prod.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_prod.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_prod_pnnx | |||
| b = test_torch_prod_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||