* convert torch.unbind * torch.ones torch.ones_like * torch.full torch.full_like * torch.randn_like * torch.empty torch.empty_liketags/20220420
| @@ -185,16 +185,23 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_clamp.cpp | |||
| pass_level2/torch_clone.cpp | |||
| pass_level2/torch_dequantize.cpp | |||
| pass_level2/torch_empty.cpp | |||
| pass_level2/torch_empty_like.cpp | |||
| pass_level2/torch_flatten.cpp | |||
| pass_level2/torch_flip.cpp | |||
| pass_level2/torch_full.cpp | |||
| pass_level2/torch_full_like.cpp | |||
| pass_level2/torch_logsumexp.cpp | |||
| pass_level2/torch_matmul.cpp | |||
| pass_level2/torch_mean.cpp | |||
| pass_level2/torch_norm.cpp | |||
| pass_level2/torch_normal.cpp | |||
| pass_level2/torch_ones.cpp | |||
| pass_level2/torch_ones_like.cpp | |||
| pass_level2/torch_prod.cpp | |||
| pass_level2/torch_quantize_per_tensor.cpp | |||
| pass_level2/torch_randn.cpp | |||
| pass_level2/torch_randn_like.cpp | |||
| pass_level2/torch_roll.cpp | |||
| pass_level2/torch_split.cpp | |||
| pass_level2/torch_squeeze.cpp | |||
| @@ -263,6 +270,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/convert_torch_cat.cpp | |||
| pass_ncnn/convert_torch_chunk.cpp | |||
| pass_ncnn/convert_torch_split.cpp | |||
| pass_ncnn/convert_torch_unbind.cpp | |||
| pass_ncnn/eliminate_output.cpp | |||
| pass_ncnn/expand_expression.cpp | |||
| pass_ncnn/insert_split.cpp | |||
| @@ -1735,6 +1735,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| { | |||
| fprintf(pyfp, "%s=", it.first.c_str()); | |||
| } | |||
| else if (op->inputs.empty() && i == 0) | |||
| { | |||
| fprintf(pyfp, "%s=", it.first.c_str()); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, ", %s=", it.first.c_str()); | |||
| @@ -0,0 +1,45 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_empty : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input_0 0 1 size | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| prim::Constant op_4 0 1 memory_format value=* | |||
| aten::empty op_5 6 1 size dtype layout device requires_grad memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.empty"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_empty, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,45 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_empty_like : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| prim::Constant op_4 0 1 memory_format value=* | |||
| aten::empty_like op_5 6 1 input dtype layout device requires_grad memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.empty_like"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_empty_like, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,45 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_full : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input_0 0 1 size | |||
| pnnx.Input input_1 0 1 fill_value | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| aten::full op_4 6 1 size fill_value dtype layout device requires_grad out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.full"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,46 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_full_like : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 9 8 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 fill_value | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| prim::Constant op_4 0 1 memory_format value=* | |||
| aten::full_like op_5 7 1 input fill_value dtype layout device requires_grad memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.full_like"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,44 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_ones : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input_0 0 1 size | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| aten::ones op_4 5 1 size dtype layout device requires_grad out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.ones"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_ones, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,45 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_ones_like : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| prim::Constant op_4 0 1 memory_format value=* | |||
| aten::ones_like op_5 6 1 input dtype layout device requires_grad memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.ones_like"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_ones_like, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,45 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_randn_like : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| prim::Constant op_4 0 1 memory_format value=* | |||
| aten::randn_like op_5 6 1 input dtype layout device requires_grad memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.randn_like"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_randn_like, 20) | |||
| } // namespace pnnx | |||
| @@ -21,6 +21,7 @@ | |||
| #include "pass_ncnn/convert_torch_cat.h" | |||
| #include "pass_ncnn/convert_torch_chunk.h" | |||
| #include "pass_ncnn/convert_torch_split.h" | |||
| #include "pass_ncnn/convert_torch_unbind.h" | |||
| #include "pass_ncnn/eliminate_output.h" | |||
| #include "pass_ncnn/expand_expression.h" | |||
| #include "pass_ncnn/insert_split.h" | |||
| @@ -80,6 +81,11 @@ void pass_ncnn(Graph& g) | |||
| ncnn::convert_half_to_float(g); | |||
| ncnn::convert_torch_cat(g); | |||
| ncnn::convert_torch_chunk(g); | |||
| ncnn::convert_torch_split(g); | |||
| ncnn::convert_torch_unbind(g); | |||
| int opindex = 0; | |||
| for (auto x : g_global_pnnx_ncnn_graph_rewriter_passes) | |||
| { | |||
| @@ -89,10 +95,6 @@ void pass_ncnn(Graph& g) | |||
| } | |||
| } | |||
| ncnn::convert_torch_cat(g); | |||
| ncnn::convert_torch_chunk(g); | |||
| ncnn::convert_torch_split(g); | |||
| ncnn::insert_split(g); | |||
| ncnn::eliminate_noop(g); | |||
| @@ -0,0 +1,97 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "convert_torch_unbind.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void convert_torch_unbind(Graph& graph) | |||
| { | |||
| int op_index = 0; | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (Operator* op : graph.ops) | |||
| { | |||
| if (op->type != "torch.unbind") | |||
| continue; | |||
| matched = true; | |||
| op->type = "Slice"; | |||
| op->name = std::string("unbind_") + std::to_string(op_index++); | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| int axis = op->params.at("dim").i; | |||
| if (axis == batch_index) | |||
| { | |||
| fprintf(stderr, "unbind along batch axis %d is not supported\n", batch_index); | |||
| continue; | |||
| } | |||
| if (axis < 0) | |||
| { | |||
| int input_rank = op->inputs[0]->shape.size(); | |||
| axis = input_rank + axis; | |||
| } | |||
| int output_size = (int)op->outputs.size(); | |||
| if (axis > batch_index) | |||
| axis -= 1; | |||
| op->params["0"].type = 5; | |||
| op->params["0"].ai.resize(output_size, -233); | |||
| op->params["1"] = axis; | |||
| op->params.erase("dim"); | |||
| // reshape for each output, squeezing the unbind dim | |||
| for (int i = 0; i < output_size; i++) | |||
| { | |||
| Operand* out = op->outputs[i]; | |||
| Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape" + std::to_string(i), op); | |||
| Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape" + std::to_string(i) + "_in"); | |||
| reshape->inputs.push_back(reshape_in); | |||
| reshape->outputs.push_back(out); | |||
| op->outputs[i] = reshape_in; | |||
| out->producer = reshape; | |||
| reshape_in->producer = op; | |||
| reshape_in->consumers.push_back(reshape); | |||
| reshape->params["shape"] = out->shape; | |||
| } | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,25 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void convert_torch_unbind(Graph& graph); | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -171,10 +171,14 @@ pnnx_add_test(torch_chunk) | |||
| pnnx_add_test(torch_clamp) | |||
| pnnx_add_test(torch_clone) | |||
| pnnx_add_test(torch_flatten) | |||
| pnnx_add_test(torch_full) | |||
| pnnx_add_test(torch_full_like) | |||
| pnnx_add_test(torch_logsumexp) | |||
| pnnx_add_test(torch_matmul) | |||
| pnnx_add_test(torch_mean) | |||
| pnnx_add_test(torch_norm) | |||
| pnnx_add_test(torch_ones) | |||
| pnnx_add_test(torch_ones_like) | |||
| pnnx_add_test(torch_permute) | |||
| pnnx_add_test(torch_prod) | |||
| pnnx_add_test(torch_sum) | |||
| @@ -184,6 +188,8 @@ pnnx_add_test(torch_stack) | |||
| pnnx_add_test(torch_transpose) | |||
| pnnx_add_test(torch_unbind) | |||
| pnnx_add_test(torch_unsqueeze) | |||
| pnnx_add_test(torch_zeros) | |||
| pnnx_add_test(torch_zeros_like) | |||
| pnnx_add_test(mobilenet_v2) | |||
| pnnx_add_test(mobilenet_v3_small) | |||
| @@ -139,6 +139,7 @@ pnnx_ncnn_add_test(torch_prod) | |||
| pnnx_ncnn_add_test(torch_sum) | |||
| pnnx_ncnn_add_test(torch_squeeze) | |||
| pnnx_ncnn_add_test(torch_transpose) | |||
| pnnx_ncnn_add_test(torch_unbind) | |||
| pnnx_ncnn_add_test(torch_unsqueeze) | |||
| pnnx_ncnn_add_test(mobilenet_v2) | |||
| @@ -0,0 +1,71 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y): | |||
| x0, x1, x2 = torch.unbind(x, dim=0) | |||
| y0, y1, y2, y3, y4, y5, y6, y7, y8 = torch.unbind(y, dim=1) | |||
| x0 = F.relu(x0) | |||
| x1 = F.relu(x1) | |||
| y0 = F.relu(y0) | |||
| y1 = F.relu(y1) | |||
| y2 = F.relu(y2) | |||
| y3 = F.relu(y3) | |||
| y4 = F.relu(y4) | |||
| y5 = F.relu(y5) | |||
| y6 = F.relu(y6) | |||
| y7 = F.relu(y7) | |||
| y8 = F.relu(y8) | |||
| return x0, x1, y0, y1, y2, y3, y4, y5, y6, y7, y8 | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| a = net(x, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y)) | |||
| mod.save("test_torch_unbind.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_unbind.pt inputshape=[3,16],[5,9,11]") | |||
| # ncnn inference | |||
| import test_torch_unbind_ncnn | |||
| b = test_torch_unbind_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.full(x.size(), 1.5) | |||
| y = torch.full(y.size(), 3) | |||
| z = torch.full(z.size(), -2.2) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_full.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_full.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_full_pnnx | |||
| b = test_torch_full_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.full_like(x, 1.5) | |||
| y = torch.full_like(y, 3) | |||
| z = torch.full_like(z, -2.2) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_full_like.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_full_like.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_full_like_pnnx | |||
| b = test_torch_full_like_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.ones(x.size()) | |||
| y = torch.ones(y.size()) | |||
| z = torch.ones(z.size()) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_ones.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_ones.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_ones_pnnx | |||
| b = test_torch_ones_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.ones_like(x) | |||
| y = torch.ones_like(y) | |||
| z = torch.ones_like(z) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_ones_like.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_ones_like.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_ones_like_pnnx | |||
| b = test_torch_ones_like_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.zeros(x.size()) | |||
| y = torch.zeros(y.size()) | |||
| z = torch.zeros(z.size()) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_zeros.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_zeros.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_zeros_pnnx | |||
| b = test_torch_zeros_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.zeros_like(x) | |||
| y = torch.zeros_like(y) | |||
| z = torch.zeros_like(z) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_zeros_like.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_zeros_like.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_zeros_like_pnnx | |||
| b = test_torch_zeros_like_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||