Tensor.fill, Tensor.index_put, Tensor.to, Tensor.type_as, torch.topk, fmod; call Tensor member functions with inputnames; static shape_as_tensor; nn.Linear dynamic bias; eliminate noop type_as; convert two-dim nn.Linear to ncnn gemm; convert torch.stack to ncnn concat+reshape; ignore torch einsum path input (tags/20230816)
| @@ -180,7 +180,9 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/Tensor_copy.cpp | |||
| pass_level2/Tensor_expand.cpp | |||
| pass_level2/Tensor_expand_as.cpp | |||
| pass_level2/Tensor_fill.cpp | |||
| pass_level2/Tensor_index.cpp | |||
| pass_level2/Tensor_index_put.cpp | |||
| pass_level2/Tensor_masked_fill.cpp | |||
| pass_level2/Tensor_new_empty.cpp | |||
| pass_level2/Tensor_new_ones.cpp | |||
| @@ -189,6 +191,8 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/Tensor_reshape.cpp | |||
| pass_level2/Tensor_select.cpp | |||
| pass_level2/Tensor_slice.cpp | |||
| pass_level2/Tensor_to.cpp | |||
| pass_level2/Tensor_type_as.cpp | |||
| pass_level2/Tensor_view.cpp | |||
| pass_level2/torch_addmm.cpp | |||
| pass_level2/torch_amax.cpp | |||
| @@ -252,6 +256,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_sum.cpp | |||
| pass_level2/torch_permute.cpp | |||
| pass_level2/torch_tensor_split.cpp | |||
| pass_level2/torch_topk.cpp | |||
| pass_level2/torch_transpose.cpp | |||
| pass_level2/torch_unbind.cpp | |||
| pass_level2/torch_unsqueeze.cpp | |||
| @@ -320,6 +325,7 @@ set(pnnx_pass_level5_SRCS | |||
| pass_level5/eliminate_noop_slice.cpp | |||
| pass_level5/eliminate_noop_view_reshape.cpp | |||
| pass_level5/eliminate_reshape_shape_expression.cpp | |||
| pass_level5/eliminate_type_as.cpp | |||
| pass_level5/eval_expression.cpp | |||
| pass_level5/fold_constants.cpp | |||
| pass_level5/fuse_adjacent_reshape.cpp | |||
| @@ -361,6 +367,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/convert_torch_chunk.cpp | |||
| pass_ncnn/convert_torch_einsum.cpp | |||
| pass_ncnn/convert_torch_split.cpp | |||
| pass_ncnn/convert_torch_stack.cpp | |||
| pass_ncnn/convert_torch_tensor_split.cpp | |||
| pass_ncnn/convert_torch_unbind.cpp | |||
| pass_ncnn/convert_Tensor_select.cpp | |||
| @@ -1297,10 +1297,12 @@ static std::string expand_expression(const Operator* op) | |||
| exprstack.push(r); | |||
| } | |||
| else if (t == "atan2" | |||
| || t == "fmod" | |||
| || t == "pow") | |||
| { | |||
| std::string binaryop; | |||
| if (t == "atan2") binaryop = "torch.atan2"; | |||
| if (t == "fmod") binaryop = "torch.fmod"; | |||
| if (t == "pow") binaryop = "torch.pow"; | |||
| std::string a = exprstack.top(); | |||
| @@ -1311,7 +1313,7 @@ static std::string expand_expression(const Operator* op) | |||
| std::string r = binaryop + "(" + a + ", " + b + ")"; | |||
| exprstack.push(r); | |||
| } | |||
| else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") | |||
| else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "remainder" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") | |||
| { | |||
| std::string binaryop; | |||
| if (t == "add") binaryop = "+"; | |||
| @@ -1319,6 +1321,7 @@ static std::string expand_expression(const Operator* op) | |||
| if (t == "mul") binaryop = "*"; | |||
| if (t == "div") binaryop = "/"; | |||
| if (t == "floor_divide") binaryop = "//"; | |||
| if (t == "remainder") binaryop = "%"; | |||
| if (t == "and") binaryop = "&"; | |||
| if (t == "or") binaryop = "|"; | |||
| if (t == "xor") binaryop = "^"; | |||
| @@ -2152,11 +2155,39 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| if (op->type.substr(0, 7) == "Tensor.") | |||
| { | |||
| fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); | |||
| if (op->type == "Tensor.fill") | |||
| { | |||
| fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str()); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); | |||
| } | |||
| if (op->inputnames.size() == op->inputs.size()) | |||
| { | |||
| for (size_t i = 1; i < op->inputs.size(); i++) | |||
| { | |||
| if (!op->inputnames[i].empty()) | |||
| continue; | |||
| for (size_t i = 1; i < op->inputs.size(); i++) | |||
| fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); | |||
| } | |||
| for (size_t i = 1; i < op->inputs.size(); i++) | |||
| { | |||
| if (op->inputnames[i].empty()) | |||
| continue; | |||
| fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str()); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); | |||
| for (size_t i = 1; i < op->inputs.size(); i++) | |||
| { | |||
| fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); | |||
| } | |||
| } | |||
| } | |||
| else | |||
| @@ -39,7 +39,8 @@ static bool value_link_input(const torch::jit::Value* v, const std::vector<torch | |||
| || optype == "aten::empty_like" | |||
| || optype == "aten::full_like" | |||
| || optype == "aten::ones_like" | |||
| || optype == "aten::zeros_like") | |||
| || optype == "aten::zeros_like" | |||
| || optype == "aten::_shape_as_tensor") | |||
| return false; | |||
| } | |||
| @@ -39,10 +39,10 @@ public: | |||
| op->params["in_features"] = weight.size(1); | |||
| op->params["out_features"] = weight.size(0); | |||
| op->params["bias"] = mod.hasattr("bias"); | |||
| op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor(); | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| if (mod.hasattr("bias") && mod.attr("bias").isTensor()) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| } | |||
| @@ -0,0 +1,41 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Level-2 graph rewriter pass for aten::fill.
// The pattern is two graph inputs (input, value) feeding one aten::fill node whose
// single result is the graph output; type_str() reports "Tensor.fill" for the match.
// NOTE(review): how GraphRewriterPass consumes these two strings is not visible here.
| class Tensor_fill : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Returns the textual pattern graph to match.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 value | |||
| aten::fill op_0 2 1 input value out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Returns the operator type name associated with a match.
| const char* type_str() const | |||
| { | |||
| return "Tensor.fill"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_fill, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,43 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Level-2 graph rewriter pass for aten::index_put.
// The pattern takes three graph inputs (input, indices, values) plus a prim::Constant
// whose value is captured as %accumulate; all four feed aten::index_put.
// No write() override is defined here, so what happens to the captured "accumulate"
// value is decided by the base class — presumably it is stored as an operator
// parameter; TODO confirm against GraphRewriterPass.
| class Tensor_index_put : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Returns the textual pattern graph to match.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 6 5 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 indices | |||
| pnnx.Input input_2 0 1 values | |||
| prim::Constant op_0 0 1 accumulate value=%accumulate | |||
| aten::index_put op_1 4 1 input indices values accumulate out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Returns the operator type name associated with a match.
| const char* type_str() const | |||
| { | |||
| return "Tensor.index_put"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_index_put, 20) | |||
| } // namespace pnnx | |||
| @@ -26,7 +26,7 @@ public: | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 mask | |||
| pnnx.Input input_2 0 1 value | |||
| aten::masked_fill op_1 3 1 input mask value out | |||
| aten::masked_fill op_0 3 1 input mask value out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| @@ -0,0 +1,89 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Level-2 graph rewriter pass for the five-operand form of aten::to:
//   to(input, dtype, non_blocking, copy, memory_format)
// dtype, copy and memory_format are captured from prim::Constant nodes; non_blocking
// uses "value=*" — presumably a wildcard that matches any constant; TODO confirm.
| class Tensor_to : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Returns the textual pattern graph to match.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 dtype value=%dtype | |||
| prim::Constant op_1 0 1 non_blocking value=* | |||
| prim::Constant op_2 0 1 copy value=%copy | |||
| prim::Constant op_3 0 1 memory_format value=%memory_format | |||
| aten::to op_4 5 1 input dtype non_blocking copy memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Returns the operator type name associated with a match.
| const char* type_str() const | |||
| { | |||
| return "Tensor.to"; | |||
| } | |||
// Translates the captured integer codes into Python-style string parameters on op.
//   dtype         : codes 0..11 are mapped to a "torch.<dtype>" name
//   copy          : forwarded unchanged
//   memory_format : codes 0..2 are mapped to a "torch.<format>" name
// Any other dtype / memory_format code leaves the corresponding parameter unset,
// because there is no fallback branch.
// NOTE(review): the numbering looks like torch's ScalarType / MemoryFormat enums —
// confirm; codes above 11 (e.g. bfloat16) and a None dtype are not handled here.
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; | |||
| if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; | |||
| if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; | |||
| if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; | |||
| if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; | |||
| if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; | |||
| if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; | |||
| if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; | |||
| if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; | |||
| if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; | |||
| if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; | |||
| if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; | |||
| op->params["copy"] = captured_params.at("copy"); | |||
| if (captured_params.at("memory_format").i == 0) | |||
| op->params["memory_format"] = "torch.contiguous_format"; | |||
| if (captured_params.at("memory_format").i == 1) | |||
| op->params["memory_format"] = "torch.preserve_format"; | |||
| if (captured_params.at("memory_format").i == 2) | |||
| op->params["memory_format"] = "torch.channels_last"; | |||
| } | |||
| }; | |||
// Variant of Tensor_to for the six-operand form of aten::to, which carries an extra
// leading "device" constant:
//   to(input, device, dtype, non_blocking, copy, memory_format)
// Only the pattern is overridden; type_str() and write() are inherited from Tensor_to.
// The device constant uses "value=*" and is not captured.
| class Tensor_to_1 : public Tensor_to | |||
| { | |||
| public: | |||
// Returns the textual pattern graph to match.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 8 7 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 device value=* | |||
| prim::Constant op_1 0 1 dtype value=%dtype | |||
| prim::Constant op_2 0 1 non_blocking value=* | |||
| prim::Constant op_3 0 1 copy value=%copy | |||
| prim::Constant op_4 0 1 memory_format value=%memory_format | |||
| aten::to op_5 6 1 input device dtype non_blocking copy memory_format out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20) | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,41 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Level-2 graph rewriter pass for aten::type_as.
// The pattern is two graph inputs (input, other) feeding one aten::type_as node whose
// single result is the graph output; type_str() reports "Tensor.type_as" for the match.
| class Tensor_type_as : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Returns the textual pattern graph to match.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 other | |||
| aten::type_as op_0 2 1 input other out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Returns the operator type name associated with a match.
| const char* type_str() const | |||
| { | |||
| return "Tensor.type_as"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_type_as, 20) | |||
| } // namespace pnnx | |||
| @@ -47,7 +47,7 @@ public: | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 equation | |||
| pnnx.Input input_1 0 1 operands | |||
| prim::Constant op_0 0 1 path value=None | |||
| pnnx.Input input_2 0 1 path | |||
| aten::einsum op_1 3 1 equation operands path out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| @@ -57,6 +57,13 @@ pnnx.Output output 1 0 out | |||
| { | |||
| return "torch.einsum"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const | |||
| { | |||
| // drop path input | |||
| op->inputs[2]->remove_consumer(op); | |||
| op->inputs.resize(2); | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_einsum_1, 20) | |||
| @@ -0,0 +1,44 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Level-2 graph rewriter pass for aten::topk.
// The pattern takes five graph inputs (input, k, dim, largest, sorted) feeding one
// aten::topk node that produces two results (values, indices), both of which are
// graph outputs; type_str() reports "torch.topk" for the match.
// All of k / dim / largest / sorted are matched as plain inputs, not as captured constants.
| class torch_topk : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Returns the textual pattern graph to match.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 7 7 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 k | |||
| pnnx.Input input_2 0 1 dim | |||
| pnnx.Input input_3 0 1 largest | |||
| pnnx.Input input_4 0 1 sorted | |||
| aten::topk op_0 5 2 input k dim largest sorted values indices | |||
| pnnx.Output output 2 0 values indices | |||
| )PNNXIR"; | |||
| } | |||
// Returns the operator type name associated with a match.
| const char* type_str() const | |||
| { | |||
| return "torch.topk"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_topk, 20) | |||
| } // namespace pnnx | |||
| @@ -100,6 +100,7 @@ static bool operand_maybe_tensor(const Operand* operand) | |||
| if (op->type == "aten::atan2" | |||
| || op->type == "aten::div" | |||
| || op->type == "aten::floor_divide" | |||
| || op->type == "aten::fmod" | |||
| || op->type == "aten::mul" | |||
| || op->type == "aten::pow" | |||
| || op->type == "aten::remainder") | |||
| @@ -363,7 +364,35 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s | |||
| fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); | |||
| expr += ")"; | |||
| } | |||
| else if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit") | |||
| else if (op->type == "Tensor.to") | |||
| { | |||
| bool noop_type_cast = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type); | |||
| if (noop_type_cast) | |||
| { | |||
| fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); | |||
| } | |||
| else | |||
| { | |||
| auto it = std::find(inputs.begin(), inputs.end(), operand); | |||
| if (it == inputs.end()) | |||
| { | |||
| // tensor | |||
| char tmp[32]; | |||
| sprintf(tmp, "@%d", (int)inputs.size()); | |||
| expr += tmp; | |||
| inputs.push_back(operand); | |||
| } | |||
| else | |||
| { | |||
| // tensor | |||
| char tmp[32]; | |||
| sprintf(tmp, "@%d", (int)(it - inputs.begin())); | |||
| expr += tmp; | |||
| } | |||
| } | |||
| } | |||
| else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") | |||
| { | |||
| fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); | |||
| } | |||
| @@ -402,8 +431,8 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s | |||
| expr += ")"; | |||
| } | |||
| else if (op->type == "aten::atan2" | |||
| || op->type == "aten::div" | |||
| || op->type == "aten::floor_divide" | |||
| || op->type == "aten::fmod" | |||
| || op->type == "aten::mul" | |||
| || op->type == "aten::pow" | |||
| || op->type == "aten::remainder") | |||
| @@ -484,6 +513,27 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s | |||
| fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); | |||
| expr += ")"; | |||
| } | |||
| else if (op->type == "aten::div") | |||
| { | |||
| std::string rounding_mode; | |||
| if (op->inputs.size() == 3) | |||
| fuse_expression(graph, op->inputs[2], rounding_mode, inputs, foldable_constants, zip); | |||
| if (rounding_mode == "trunc") | |||
| { | |||
| expr += "floor_divide"; | |||
| } | |||
| else | |||
| { | |||
| expr += "div"; | |||
| } | |||
| expr += "("; | |||
| fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip); | |||
| expr += ","; | |||
| fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip); | |||
| expr += ")"; | |||
| } | |||
| else | |||
| { | |||
| auto it = std::find(inputs.begin(), inputs.end(), operand); | |||
| @@ -542,7 +592,13 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan | |||
| { | |||
| need_fuse = true; | |||
| } | |||
| if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit") | |||
| if (op->type == "Tensor.to") | |||
| { | |||
| // fuse noop type cast only | |||
| bool noop_to = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type); | |||
| need_fuse = noop_to; | |||
| } | |||
| if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit") | |||
| { | |||
| need_fuse = true; | |||
| } | |||
| @@ -562,6 +618,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan | |||
| || op->type == "aten::exp" | |||
| || op->type == "aten::floor" | |||
| || op->type == "aten::floor_divide" | |||
| || op->type == "aten::fmod" | |||
| || op->type == "aten::log" | |||
| || op->type == "aten::log10" | |||
| || op->type == "aten::mul" | |||
| @@ -27,6 +27,7 @@ | |||
| #include "pass_level5/eliminate_noop_slice.h" | |||
| #include "pass_level5/eliminate_noop_view_reshape.h" | |||
| #include "pass_level5/eliminate_reshape_shape_expression.h" | |||
| #include "pass_level5/eliminate_type_as.h" | |||
| #include "pass_level5/eval_expression.h" | |||
| #include "pass_level5/fuse_adjacent_reshape.h" | |||
| #include "pass_level5/fuse_channel_shuffle.h" | |||
| @@ -112,6 +113,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons | |||
| eliminate_noop_cat(g); | |||
| eliminate_dropout(g); | |||
| eliminate_type_as(g); | |||
| eliminate_noop_upsample(g); | |||
| @@ -0,0 +1,84 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "eliminate_type_as.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void eliminate_type_as(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "Tensor.type_as") | |||
| continue; | |||
| if (op->inputs[0]->type == 0 || op->outputs[0]->type == 0) | |||
| continue; | |||
| if (op->inputs[0]->type != op->outputs[0]->type) | |||
| continue; | |||
| // delete noop-like type_as | |||
| matched = true; | |||
| for (auto& x : op->inputs) | |||
| { | |||
| x->remove_consumer(op); | |||
| } | |||
| Operand* type_as_out = op->outputs[0]; | |||
| for (auto& x : type_as_out->consumers) | |||
| { | |||
| for (size_t j = 0; j < x->inputs.size(); j++) | |||
| { | |||
| if (x->inputs[j] == type_as_out) | |||
| x->inputs[j] = op->inputs[0]; | |||
| } | |||
| op->inputs[0]->consumers.push_back(x); | |||
| } | |||
| op->inputs[0]->name = type_as_out->name; | |||
| type_as_out->producer = 0; | |||
| type_as_out->consumers.clear(); | |||
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), type_as_out)); | |||
| delete type_as_out; | |||
| op->inputs.clear(); | |||
| op->outputs.clear(); | |||
| graph.ops.erase(graph.ops.begin() + i); | |||
| delete op; | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
// Removes Tensor.type_as operators whose first input and output operand types are
// both non-zero and equal, rewiring the output's consumers to the input operand.
| void eliminate_type_as(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -342,6 +342,7 @@ static std::string eval_expression(const Operator* op) | |||
| || t == "mul" | |||
| || t == "div" | |||
| || t == "floor_divide" | |||
| || t == "fmod" | |||
| || t == "pow" | |||
| || t == "remainder") | |||
| { | |||
| @@ -380,6 +381,11 @@ static std::string eval_expression(const Operator* op) | |||
| float r = af / bf; | |||
| exprstack.push(std::to_string(r)); | |||
| } | |||
| if (t == "fmod") | |||
| { | |||
| float r = fmod(af, bf); | |||
| exprstack.push(std::to_string(r)); | |||
| } | |||
| if (t == "floor_divide") | |||
| { | |||
| int r = (int)af / (int)bf; | |||
| @@ -37,6 +37,12 @@ void fuse_select_to_unbind(Graph& graph) | |||
| if (input_rank == 0) | |||
| continue; | |||
| if (input_rank == 1) | |||
| { | |||
| // skip select scalar | |||
| continue; | |||
| } | |||
| int dim = op->params.at("dim").i; | |||
| const int select_dimsize = op_in->shape[dim]; | |||
| @@ -22,6 +22,7 @@ | |||
| #include "pass_ncnn/convert_torch_chunk.h" | |||
| #include "pass_ncnn/convert_torch_einsum.h" | |||
| #include "pass_ncnn/convert_torch_split.h" | |||
| #include "pass_ncnn/convert_torch_stack.h" | |||
| #include "pass_ncnn/convert_torch_tensor_split.h" | |||
| #include "pass_ncnn/convert_torch_unbind.h" | |||
| #include "pass_ncnn/convert_Tensor_select.h" | |||
| @@ -96,6 +97,7 @@ void pass_ncnn(Graph& g) | |||
| ncnn::convert_torch_cat(g); | |||
| ncnn::convert_torch_chunk(g); | |||
| ncnn::convert_torch_stack(g); | |||
| ncnn::convert_torch_split(g); | |||
| ncnn::convert_torch_unbind(g); | |||
| ncnn::convert_torch_tensor_split(g); | |||
| @@ -0,0 +1,91 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "convert_torch_stack.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void convert_torch_stack(Graph& graph) | |||
| { | |||
| int op_index = 0; | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (Operator* op : graph.ops) | |||
| { | |||
| if (op->type != "torch.stack") | |||
| continue; | |||
| matched = true; | |||
| op->type = "Concat"; | |||
| op->name = std::string("stack_") + std::to_string(op_index++); | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| int axis = op->params.at("dim").i; | |||
| if (axis == batch_index) | |||
| { | |||
| fprintf(stderr, "stack along batch axis %d is not supported\n", batch_index); | |||
| continue; | |||
| } | |||
| if (axis < 0) | |||
| { | |||
| int input_rank = op->inputs[0]->shape.size(); | |||
| axis = input_rank + axis; | |||
| } | |||
| if (axis > batch_index) | |||
| axis -= 1; | |||
| op->params["0"] = axis; | |||
| op->params.erase("dim"); | |||
| // reshape for output, expand the stack dim | |||
| { | |||
| Operand* out = op->outputs[0]; | |||
| Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op); | |||
| Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in"); | |||
| reshape->inputs.push_back(reshape_in); | |||
| reshape->outputs.push_back(out); | |||
| op->outputs[0] = reshape_in; | |||
| out->producer = reshape; | |||
| reshape_in->producer = op; | |||
| reshape_in->consumers.push_back(reshape); | |||
| reshape->params["shape"] = out->shape; | |||
| } | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,25 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
// Rewrites each torch.stack operator into a Concat operator followed by a
// Tensor.reshape to the original output shape.
| void convert_torch_stack(Graph& graph); | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -178,7 +178,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| op_unary->inputs.push_back(op_unary_in); | |||
| op_unary->outputs.push_back(op_unary_out); | |||
| } | |||
| else if (t == "add" || t == "sub" || t == "mul" || t == "div" || /*t == "floor_divide" || */ t == "pow" || t == "atan2") | |||
| else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "fmod" || t == "remainder" || t == "pow" || t == "atan2") | |||
| { | |||
| std::string a = exprstack.top(); | |||
| exprstack.pop(); | |||
| @@ -190,10 +190,16 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| Operator* op_binary = graph.new_operator_before("BinaryOp", t + "_" + std::to_string(pnnx_expr_index++), op); | |||
| // default todo type mark :[ | |||
| op_binary->params["0"] = -1; | |||
| if (t == "add") op_binary->params["0"] = 0; | |||
| if (t == "sub") op_binary->params["0"] = 1; | |||
| if (t == "mul") op_binary->params["0"] = 2; | |||
| if (t == "div") op_binary->params["0"] = 3; | |||
| if (t == "floor_divide") fprintf(stderr, "BinaryOp floor_divide not supported yet\n"); // TODO | |||
| if (t == "fmod") fprintf(stderr, "BinaryOp fmod not supported yet\n"); // TODO | |||
| if (t == "remainder") fprintf(stderr, "BinaryOp remainder not supported yet\n"); // TODO | |||
| if (t == "pow") op_binary->params["0"] = 6; | |||
| if (t == "atan2") op_binary->params["0"] = 10; | |||
| @@ -94,19 +94,26 @@ void insert_reshape_linear(Graph& graph) | |||
| reshape_h *= linear_in->shape[j]; | |||
| } | |||
| std::vector<int> reshape0_shape; | |||
| std::vector<int> reshape0_out_shape; | |||
| std::vector<int> reshape1_in_shape; | |||
| if (batch_index == 0 && batch_index != 233) | |||
| { | |||
| reshape0_shape = {1, reshape_h, linear_in->shape[input_rank - 1]}; | |||
| reshape0_out_shape = {1, reshape_h, linear_in->shape[input_rank - 1]}; | |||
| reshape1_in_shape = {1, reshape_h, linear_out->shape[input_rank - 1]}; | |||
| } | |||
| else | |||
| { | |||
| reshape0_shape = {reshape_h, linear_in->shape[input_rank - 1]}; | |||
| reshape0_out_shape = {reshape_h, linear_in->shape[input_rank - 1]}; | |||
| reshape1_in_shape = {reshape_h, linear_out->shape[input_rank - 1]}; | |||
| } | |||
| std::vector<int> reshape1_shape = linear_out->shape; | |||
| reshape0->params["shape"] = reshape0_shape; | |||
| reshape1->params["shape"] = reshape1_shape; | |||
| std::vector<int> reshape1_out_shape = linear_out->shape; | |||
| reshape0->params["shape"] = reshape0_out_shape; | |||
| reshape1->params["shape"] = reshape1_out_shape; | |||
| reshape0_out->type = linear_in->type; | |||
| reshape0_out->shape = reshape0_out_shape; | |||
| reshape1_in->type = linear_out->type; | |||
| reshape1_in->shape = reshape1_in_shape; | |||
| break; | |||
| } | |||
| @@ -18,6 +18,152 @@ namespace pnnx { | |||
| namespace ncnn { | |||
| class nn_Linear_0 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input #input=(1,%m,%in_features)f32 | |||
| nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Gemm"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "gemm"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| op->params["2"] = 0; | |||
| op->params["3"] = 1; | |||
| op->params["4"] = 0; | |||
| op->params["5"] = 1; | |||
| op->params["6"] = 1; | |||
| op->params["7"] = captured_params.at("m"); | |||
| op->params["8"] = captured_params.at("out_features"); | |||
| op->params["9"] = captured_params.at("in_features"); | |||
| op->params["10"] = captured_params.at("bias").b ? 4 : -1; | |||
| op->attrs["0"] = Attribute(); | |||
| op->attrs["0"].data = {0, 0, 0, 0}; | |||
| op->attrs["1"] = captured_attrs.at("op_0.weight"); | |||
| if (captured_params.at("bias").b) | |||
| { | |||
| op->attrs["2"] = Attribute(); | |||
| op->attrs["2"].data = {0, 0, 0, 0}; | |||
| op->attrs["3"] = captured_attrs.at("op_0.bias"); | |||
| } | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_0, 19) | |||
| class nn_Linear_01 : public nn_Linear_0 | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 3 2 | |||
| pnnx.Input input 0 1 input #input=(%m,%in_features)f32 | |||
| nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const int m = captured_params.at("m").i; | |||
| if (m == 1) | |||
| return false; | |||
| return true; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_01, 19) | |||
| class nn_Linear_10 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input #input=(1,%m,%in_features)f32 | |||
| pnnx.Input input_1 0 1 bias | |||
| nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Gemm"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "gemm"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| op->params["2"] = 0; | |||
| op->params["3"] = 1; | |||
| op->params["4"] = 0; | |||
| op->params["5"] = 1; | |||
| op->params["6"] = 0; | |||
| op->params["7"] = captured_params.at("m"); | |||
| op->params["8"] = captured_params.at("out_features"); | |||
| op->params["9"] = captured_params.at("in_features"); | |||
| op->params["10"] = 4; | |||
| op->attrs["0"] = Attribute(); | |||
| op->attrs["0"].data = {0, 0, 0, 0}; | |||
| op->attrs["1"] = captured_attrs.at("op_0.weight"); | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_10, 19) | |||
| class nn_Linear_11 : public nn_Linear_10 | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input #input=(%m,%in_features)f32 | |||
| pnnx.Input input_1 0 1 bias | |||
| nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| const int m = captured_params.at("m").i; | |||
| if (m == 1) | |||
| return false; | |||
| return true; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_11, 19) | |||
| class nn_Linear : public GraphRewriterPass | |||
| { | |||
| public: | |||
| @@ -57,6 +203,52 @@ pnnx.Output output 1 0 out | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear, 20) | |||
| class nn_Linear_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 bias | |||
| nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* replace_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 bias | |||
| InnerProduct linear 1 1 input a | |||
| BinaryOp bias 2 1 a bias out 0=0 | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| GraphRewriterPass::write(ops, captured_params, captured_attrs); | |||
| const int batch_index = ops.at("linear")->inputs[0]->params["__batch_index"].i; | |||
| ops.at("linear")->params["0"] = captured_params.at("out_features"); | |||
| ops.at("linear")->params["1"] = 0; | |||
| ops.at("linear")->params["2"] = captured_attrs.at("op_0.weight").elemcount(); | |||
| ops.at("linear")->attrs["0"] = Attribute(); | |||
| ops.at("linear")->attrs["0"].data = {0, 0, 0, 0}; | |||
| ops.at("linear")->attrs["1"] = captured_attrs.at("op_0.weight"); | |||
| ops.at("linear")->outputs[0]->params["__batch_index"] = batch_index; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_1, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -162,7 +162,9 @@ pnnx_add_test(nn_ZeroPad2d) | |||
| pnnx_add_test(Tensor_contiguous) | |||
| pnnx_add_test(Tensor_expand) | |||
| pnnx_add_test(Tensor_fill) | |||
| pnnx_add_test(Tensor_index) | |||
| pnnx_add_test(Tensor_index_put) | |||
| pnnx_add_test(Tensor_masked_fill) | |||
| pnnx_add_test(Tensor_new_empty) | |||
| pnnx_add_test(Tensor_new_full) | |||
| @@ -173,6 +175,8 @@ pnnx_add_test(Tensor_reshape) | |||
| pnnx_add_test(Tensor_select) | |||
| pnnx_add_test(Tensor_slice) | |||
| pnnx_add_test(Tensor_slice_copy) | |||
| pnnx_add_test(Tensor_to) | |||
| pnnx_add_test(Tensor_type_as) | |||
| pnnx_add_test(Tensor_view) | |||
| pnnx_add_test(torch_addmm) | |||
| @@ -221,6 +225,7 @@ pnnx_add_test(torch_squeeze) | |||
| pnnx_add_test(torch_stack) | |||
| pnnx_add_test(torch_std) | |||
| pnnx_add_test(torch_tensor_split) | |||
| pnnx_add_test(torch_topk) | |||
| pnnx_add_test(torch_transpose) | |||
| pnnx_add_test(torch_unbind) | |||
| pnnx_add_test(torch_unsqueeze) | |||
| @@ -295,6 +300,7 @@ pnnx_add_test(pnnx_eliminate_noop_cat) | |||
| pnnx_add_test(pnnx_eliminate_noop_expand) | |||
| pnnx_add_test(pnnx_eliminate_noop_math) | |||
| pnnx_add_test(pnnx_eliminate_noop_upsample) | |||
| pnnx_add_test(pnnx_expression) | |||
| pnnx_add_test(pnnx_fold_constant) | |||
| pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) | |||
| @@ -154,6 +154,7 @@ pnnx_ncnn_add_test(torch_permute) | |||
| pnnx_ncnn_add_test(torch_prod) | |||
| pnnx_ncnn_add_test(torch_sum) | |||
| pnnx_ncnn_add_test(torch_squeeze) | |||
| pnnx_ncnn_add_test(torch_stack) | |||
| pnnx_ncnn_add_test(torch_tensor_split) | |||
| pnnx_ncnn_add_test(torch_transpose) | |||
| pnnx_ncnn_add_test(torch_unbind) | |||
| @@ -0,0 +1,60 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Stacks (x, y) along dim 0 and (z, w) along dim 2, then applies in-place ReLU."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z, w):
        out0 = torch.stack((x, y), dim=0)
        out1 = torch.stack((z, w), dim=2)
        # ReLU is applied in place on the freshly stacked tensors
        out0.relu_()
        out1.relu_()
        return out0, out1
def test():
    """Run Model in PyTorch and through the converted ncnn script, then compare.

    Traces the model, saves it as test_torch_stack.pt, invokes ../../src/pnnx
    on it and imports test_torch_stack_ncnn (expected to be emitted by the
    converter) to obtain the reference outputs.

    Returns True when both outputs are exactly equal to the PyTorch results.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 16)
    y = torch.rand(3, 16)
    z = torch.rand(5, 9, 3)
    w = torch.rand(5, 9, 3)

    a0, a1 = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_torch_stack.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_torch_stack.pt inputshape=[3,16],[3,16],[5,9,3],[5,9,3]")

    # ncnn inference
    import test_torch_stack_ncnn
    b0, b1 = test_torch_stack_ncnn.test_inference()

    return torch.equal(a0, b0) and torch.equal(a1, b1)
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise; SystemExit is raised
    # directly instead of relying on the site-provided exit() helper
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,57 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Exercises Tensor.fill_ on slices and on a whole tensor.

    Note that forward mutates both x and y in place.
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        # first two rows of x take the value of z[0]
        x[:2, :].fill_(z[0])
        # first row of y is set to 0.22, then the whole of y is overwritten with 7
        y[:1, :].fill_(0.22)
        return x + y.fill_(7)
def test():
    """Run Model in PyTorch and through the converted pnnx script, then compare.

    Traces the model, saves it as test_Tensor_fill.pt, invokes ../src/pnnx on
    it and imports test_Tensor_fill_pnnx (expected to be emitted by the
    converter) to obtain the reference output.

    Returns True when the output is exactly equal to the PyTorch result.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(6, 16)
    y = torch.rand(6, 16)
    z = torch.rand(1)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_Tensor_fill.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_Tensor_fill.pt inputshape=[6,16],[6,16],[1]")

    # pnnx inference
    import test_Tensor_fill_pnnx
    b = test_Tensor_fill_pnnx.test_inference()

    return torch.equal(a, b)
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise; SystemExit is raised
    # directly instead of relying on the site-provided exit() helper
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,63 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Exercises Tensor.index_put (out of place) and Tensor.index_put_ (in place).

    Both inputs are cloned first, so the caller's x and z are left untouched.
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z, w):
        x = x.clone()
        z = z.clone()
        # overwrite x[10] and x[2] with the two values of y
        x = x.index_put(indices=[torch.tensor([10,2])], values=y, accumulate=False)
        # add w onto z[1,3], z[0,2] and z[0,1]
        z.index_put_(indices=[torch.tensor([1,0,0]), torch.tensor([3,2,1])], values=w, accumulate=True)
        return x, z
def test():
    """Run Model in PyTorch and through the converted pnnx script, then compare.

    Traces the model, saves it as test_Tensor_index_put.pt, invokes ../src/pnnx
    on it and imports test_Tensor_index_put_pnnx (expected to be emitted by the
    converter) to obtain the reference outputs.

    Returns True when every output is exactly equal to the PyTorch result.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(12)
    y = torch.rand(2)
    z = torch.rand(6,9)
    w = torch.rand(3)

    a = net(x, y, z, w)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w))
    mod.save("test_Tensor_index_put.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_Tensor_index_put.pt inputshape=[12],[2],[6,9],[3]")

    # pnnx inference
    import test_Tensor_index_put_pnnx
    b = test_Tensor_index_put_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise; SystemExit is raised
    # directly instead of relying on the site-provided exit() helper
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,63 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Exercises Tensor.to with dtype, memory_format, device and copy arguments."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y):
        x = x * 10
        y = y * 13
        # y takes the dtype of x; x is then converted to torch.int on the cpu
        y = y.to(dtype=x.dtype, memory_format=torch.contiguous_format)
        x = x.to(device='cpu', dtype=torch.int, copy=True)
        x = x + 1
        y = y - 2
        return x, y
def test():
    """Run Model in PyTorch and through the converted pnnx script, then compare.

    Traces the model, saves it as test_Tensor_to.pt, invokes ../src/pnnx on it
    and imports test_Tensor_to_pnnx (expected to be emitted by the converter)
    to obtain the reference outputs.

    Returns True when every output is exactly equal to the PyTorch result.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 16)
    y = torch.randint(10, (1, 13), dtype=torch.int)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_Tensor_to.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_Tensor_to.pt inputshape=[3,16],[1,13]i32")

    # pnnx inference
    import test_Tensor_to_pnnx
    b = test_Tensor_to_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise; SystemExit is raised
    # directly instead of relying on the site-provided exit() helper
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,65 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Exercises Tensor.type_as, including a call whose target already has the same dtype."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        x = x * 100
        z = z * 200
        # x takes the dtype of y, goes through ReLU, then takes the dtype of z
        x = x.type_as(y)
        x = F.relu(x)
        x = x.type_as(z)
        z = F.relu(z)
        # x already carries z's dtype at this point
        z = z.type_as(x)
        return x, z
def test():
    """Run Model in PyTorch and through the converted pnnx script, then compare.

    Traces the model, saves it as test_Tensor_type_as.pt, invokes ../src/pnnx
    on it and imports test_Tensor_type_as_pnnx (expected to be emitted by the
    converter) to obtain the reference outputs.

    Returns True when every output is exactly equal to the PyTorch result.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(3, 16)
    y = torch.randint(10, (1, 13), dtype=torch.int)
    z = torch.rand(8, 5, 9, 10)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_Tensor_type_as.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_Tensor_type_as.pt inputshape=[3,16],[1,13]i32,[8,5,9,10]")

    # pnnx inference
    import test_Tensor_type_as_pnnx
    b = test_Tensor_type_as_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise; SystemExit is raised
    # directly instead of relying on the site-provided exit() helper
    raise SystemExit(0 if test() else 1)
| @@ -0,0 +1,75 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from packaging import version | |||
class Model(nn.Module):
    """Chains arithmetic operators against six (12, 15) parameters, then integer bit operators."""

    def __init__(self):
        super(Model, self).__init__()

        self.w0 = nn.Parameter(torch.rand(12, 15))
        self.w1 = nn.Parameter(torch.rand(12, 15))
        self.w2 = nn.Parameter(torch.rand(12, 15))
        self.w3 = nn.Parameter(torch.rand(12, 15))
        self.w4 = nn.Parameter(torch.rand(12, 15))
        self.w5 = nn.Parameter(torch.rand(12, 15))

    def forward(self, x):
        x0 = x * 10
        x = x + self.w0 + x0
        x = x - self.w1 + x0.float()
        x = x * self.w2 + x0
        x = x / self.w3 + x0
        x = x // self.w4 + x0
        # PyTorch 2.0 and newer use the % operator; older versions call torch.fmod
        if version.parse(torch.__version__) >= version.parse('2.0'):
            x = x % self.w5 + x0
        else:
            x = torch.fmod(x, self.w5) + x0
        y = x.int()
        return x, y & 3, y | 3, y ^ 3, y << 3, y >> 3
def test():
    """Run Model in PyTorch and through the converted pnnx script, then compare.

    Traces the model, saves it as test_pnnx_expression.pt, invokes ../src/pnnx
    on it and imports test_pnnx_expression_pnnx (expected to be emitted by the
    converter) to obtain the reference outputs.

    Returns True when every output is exactly equal to the PyTorch result.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(12, 15)

    a = net(x)

    # export torchscript
    mod = torch.jit.trace(net, x)
    mod.save("test_pnnx_expression.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_pnnx_expression.pt inputshape=[12,15]")

    # pnnx inference
    import test_pnnx_expression_pnnx
    b = test_pnnx_expression_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise; SystemExit is raised
    # directly instead of relying on the site-provided exit() helper
    raise SystemExit(0 if test() else 1)
| @@ -148,7 +148,10 @@ def test(): | |||
| b = test_torch_einsum_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| # allclose may auto broadcast compare | |||
| if a0.shape != b0.shape: | |||
| return False | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Exercises torch.topk with default, smallest-first and unsorted settings."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z):
        # four largest values of x, indices discarded
        x, _ = torch.topk(x, 4)
        # single smallest value of y along dim 2
        y, _ = torch.topk(y, k=1, dim=2, largest=False)
        # three largest values of z along the last dim, with sorted=False, plus their indices
        z, indices = torch.topk(z, k=3, dim=-1, sorted=False)
        return x, y, z, indices
def test():
    """Run Model in PyTorch and through the converted pnnx script, then compare.

    Traces the model, saves it as test_torch_topk.pt, invokes ../src/pnnx on it
    and imports test_torch_topk_pnnx (expected to be emitted by the converter)
    to obtain the reference outputs.

    Returns True when every output is exactly equal to the PyTorch result.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 3, 16)
    y = torch.rand(1, 5, 9, 11)
    z = torch.rand(14, 8, 5, 9, 10)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_topk.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_torch_topk.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")

    # pnnx inference
    import test_torch_topk_pnnx
    b = test_torch_topk_pnnx.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||