diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index d0114ce42..35fc31683 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -59,7 +59,7 @@ set(pnnx_pass_level1_SRCS pass_level1/nn_MaxPool1d.cpp pass_level1/nn_MaxPool2d.cpp pass_level1/nn_MaxPool3d.cpp - pass_level1/nn_maxunpool2d.cpp + #pass_level1/nn_maxunpool2d.cpp pass_level1/nn_Mish.cpp pass_level1/nn_MultiheadAttention.cpp pass_level1/nn_PixelShuffle.cpp @@ -186,6 +186,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_clone.cpp pass_level2/torch_dequantize.cpp pass_level2/torch_flatten.cpp + pass_level2/torch_flip.cpp pass_level2/torch_logsumexp.cpp pass_level2/torch_matmul.cpp pass_level2/torch_mean.cpp @@ -193,6 +194,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_normal.cpp pass_level2/torch_prod.cpp pass_level2/torch_quantize_per_tensor.cpp + pass_level2/torch_randn.cpp pass_level2/torch_roll.cpp pass_level2/torch_split.cpp pass_level2/torch_squeeze.cpp @@ -200,6 +202,7 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_sum.cpp pass_level2/torch_permute.cpp pass_level2/torch_transpose.cpp + pass_level2/torch_unbind.cpp pass_level2/torch_unsqueeze.cpp pass_level2/torch_var.cpp pass_level2/torch_zeros.cpp @@ -210,11 +213,11 @@ set(pnnx_pass_level2_SRCS set(pnnx_pass_level3_SRCS pass_level3/assign_unique_name.cpp + pass_level3/eliminate_noop_math.cpp pass_level3/eliminate_tuple_pair.cpp pass_level3/expand_quantization_modules.cpp - pass_level3/fuse_attribute_expression.cpp pass_level3/fuse_cat_stack_tensors.cpp - pass_level3/fuse_chunk_split_unpack.cpp + pass_level3/fuse_chunk_split_unbind_unpack.cpp pass_level3/fuse_expression.cpp pass_level3/fuse_index_expression.cpp pass_level3/fuse_rnn_unpack.cpp @@ -235,6 +238,7 @@ set(pnnx_pass_level5_SRCS pass_level5/eliminate_slice.cpp pass_level5/eliminate_view_reshape.cpp pass_level5/eval_expression.cpp + pass_level5/fold_constants.cpp pass_level5/fuse_channel_shuffle.cpp 
pass_level5/fuse_constant_expression.cpp pass_level5/fuse_conv1d_batchnorm1d.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index a83771368..d776f73c9 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1343,7 +1343,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, ","); } - fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + if (attr.type == 1 || attr.type == 2 || attr.type == 3) + { + fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + } + else + { + fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); + } } } @@ -1373,11 +1380,11 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) if (is_running_mean_var) { - fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), key.c_str(), sanitize_identifier(op->name).c_str(), key.c_str()); + fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); } else { - fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), key.c_str(), sanitize_identifier(op->name).c_str(), key.c_str()); + fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); } for (size_t i = 0; i < attr.shape.size(); i++) @@ -1387,7 +1394,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, ","); } - fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + if (attr.type == 1 || attr.type == 2 || attr.type == 3) + { + fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + } + else + { + fprintf(pyfp, "), '%s', requires_grad=False)\n", 
type_to_numpy_string(attr.type)); + } } fprintf(pyfp, " archive.close()\n"); @@ -1452,7 +1466,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) else if (op->type == "pnnx.Attribute") { const std::string& key = op->attrs.begin()->first; - fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), key.c_str()); + fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); } else if (op->type == "Tensor.slice") { @@ -1463,8 +1477,16 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) else if (op->type == "Tensor.index") { // index expr - std::string index_expr = make_index_expression(op); - fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + if (op->inputs.size() == 2) + { + std::string expanded_expr = expand_expression(op->inputs[1]->producer); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); + } + else + { + std::string index_expr = make_index_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + } } else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") { diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 8dc7c1815..07730dfa7 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -279,9 +279,6 @@ int main(int argc, char** argv) fprintf(stderr, "\n"); } - // at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); - // torch::autograd::AutoGradMode guard(false); - for (auto m : customop_modules) { fprintf(stderr, "load custom module %s\n", m.c_str()); @@ -339,7 
+336,8 @@ int main(int argc, char** argv) fprintf(stderr, "############# pass_level0\n"); - pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators); + std::map foldable_constants; + pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants); // g->dump(); @@ -373,7 +371,7 @@ int main(int argc, char** argv) { fprintf(stderr, "############# pass_level5\n"); - pnnx::pass_level5(pnnx_graph); + pnnx::pass_level5(pnnx_graph, foldable_constants); } pnnx_graph.save(pnnxparampath, pnnxbinpath); diff --git a/tools/pnnx/src/pass_level0.cpp b/tools/pnnx/src/pass_level0.cpp index d09842397..d50f71bbe 100644 --- a/tools/pnnx/src/pass_level0.cpp +++ b/tools/pnnx/src/pass_level0.cpp @@ -20,7 +20,7 @@ namespace pnnx { -void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators) +void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants) { inline_block(g, module_operators); @@ -28,7 +28,7 @@ void pass_level0(const torch::jit::Module& mod, std::shared_ptr +#include "ir.h" namespace pnnx { -void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators); +void pass_level0(const torch::jit::Module& mod, std::shared_ptr& g, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index 326f87995..51ed9a9ff 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ 
b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -13,74 +13,233 @@ // specific language governing permissions and limitations under the License. #include "shape_inference.h" +#include + +#include "pass_level0/constant_unpooling.h" +#include "pass_level0/inline_block.h" +#include "pass_level0/shape_inference.h" namespace pnnx { -void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2) +static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs) { - // collect all intermediate output tensors - std::vector values; - for (const auto& n : graph->nodes()) + for (auto x : inputs) { - for (const auto& on : n->outputs()) - { - auto tensor_type = on->type()->cast(); - if (!tensor_type) - continue; + if (v == x) + return true; + } + + for (size_t i = 0; i < v->node()->inputs().size(); i++) + { + bool link = value_link_input(v->node()->inputs()[i], inputs); + if (link) + return true; + } + + return false; +} + +static bool value_link_output(const torch::jit::Value* v, const std::vector& outputs) +{ + for (auto x : outputs) + { + if (v == x) + return true; + } - values.push_back(on); + for (size_t i = 0; i < v->uses().size(); i++) + { + auto node = v->uses()[i].user; + for (auto x : node->outputs()) + { + bool link = value_link_output(x, outputs); + if (link) + return true; } } - // set new graph output - auto old_output = graph->outputs()[0]; + return false; +} - torch::jit::Node* new_return_node = graph->createTuple(at::ArrayRef(values)); +void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants) +{ + // collect all intermediate output tensors + std::vector > more_value_names; + std::vector > more_values; + { + std::unordered_set value_names; + std::vector values; + for (const auto& n 
: graph->nodes()) + { + for (const auto& v : n->outputs()) + { + auto tensor_type = v->type()->cast(); + if (!tensor_type) + continue; - graph->appendNode(new_return_node); + value_names.insert(v->debugName()); + values.push_back(v); + } - graph->eraseOutput(0); - graph->registerOutput(new_return_node->outputs()[0]); + // too many intermediate blobs in one inference results oom + if (value_names.size() >= 1000) + { + more_value_names.push_back(value_names); + value_names.clear(); + + more_values.push_back(values); + values.clear(); + } + } + + if (value_names.size() > 0) + { + more_value_names.push_back(value_names); + more_values.push_back(values); + } + } + + // collect graph inputs outputs + std::vector g_inputs; + for (size_t i = 1; i < graph->inputs().size(); i++) + { + g_inputs.push_back(graph->inputs()[i]); + } + std::vector g_outputs; + for (size_t i = 0; i < graph->outputs().size(); i++) + { + g_outputs.push_back(graph->outputs()[i]); + } - // inference for all tensors std::vector inputs; for (size_t i = 0; i < input_tensors.size(); i++) { const at::Tensor& it = input_tensors[i]; - inputs.push_back(it); - graph->inputs()[1 + i]->setType(c10::TensorType::create(it)); } - auto outputs = mod.copy().forward(inputs).toTuple(); + std::vector inputs2; + for (size_t i = 0; i < input_tensors2.size(); i++) + { + const at::Tensor& it = input_tensors2[i]; + inputs2.push_back(it); + } - if (input_tensors2.empty()) + std::map output_tensors; + + for (size_t p = 0; p < more_value_names.size(); p++) { - // assign shape info - int index = 0; - for (auto e : outputs->elements()) + std::unordered_set& value_names = more_value_names[p]; + std::vector& values = more_values[p]; + + // auto mod2 = mod.deepcopy(); + + torch::jit::Module mod2 = torch::jit::load(ptpath); + mod2.eval(); + + auto graph2 = mod2.get_method("forward").graph(); + + inline_block(graph2, module_operators); + + constant_unpooling(graph2); + + std::vector values2; + for (auto n : graph2->nodes()) { - 
values[index]->setType(c10::TensorType::create(e.toTensor())); + for (const auto& v : n->outputs()) + { + auto tensor_type = v->type()->cast(); + if (!tensor_type) + continue; - index++; + if (value_names.find(v->debugName()) != value_names.end()) + { + values2.push_back(v); + fprintf(stderr, "%s ", v->debugName().c_str()); + } + } } - } - else - { - std::vector inputs2; - for (size_t i = 0; i < input_tensors2.size(); i++) + fprintf(stderr, "\n----------------\n\n"); + + // set new graph output + torch::jit::Node* new_return_node = graph2->createTuple(at::ArrayRef(values2)); + + graph2->appendNode(new_return_node); + + graph2->eraseOutput(0); + graph2->registerOutput(new_return_node->outputs()[0]); + + // inference for all tensors + auto outputs = mod2.copy().forward(inputs).toTuple(); + + if (input_tensors2.empty()) { - const at::Tensor& it = input_tensors2[i]; + // assign shape info + for (size_t i = 0; i < values2.size(); i++) + { + auto v = values[i]; + auto t = outputs->elements()[i].toTensor(); + + v->setType(c10::TensorType::create(t)); - inputs2.push_back(it); - graph->inputs()[1 + i]->setType(c10::TensorType::create(it)); + // check if value that does not depend on inputs + if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) + { + output_tensors[v] = t; + } + } } + else + { + // assign dynamic shape info + auto outputs2 = mod2.copy().forward(inputs2).toTuple(); + + fprintf(stderr, "assign dynamic shape info\n"); - auto outputs2 = mod.copy().forward(inputs2).toTuple(); + for (size_t i = 0; i < values2.size(); i++) + { + auto v = values[i]; + auto t = outputs->elements()[i].toTensor(); + auto t2 = outputs2->elements()[i].toTensor(); - fprintf(stderr, "assign dynamic shape info\n"); + auto type1 = c10::TensorType::create(t); + auto type2 = c10::TensorType::create(t2); + + std::vector sizes1 = type1->symbolic_sizes().sizes().value(); + std::vector sizes2 = type2->symbolic_sizes().sizes().value(); + + for (size_t i = 0; i < sizes1.size(); 
i++) + { + if (sizes1[i] == sizes2[i]) + continue; + + sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); + } + + auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); + + v->setType(finaltype); + + // check if value that does not depend on inputs + if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) + { + output_tensors[v] = t; + } + } + } + } + + if (input_tensors2.empty()) + { + for (size_t i = 0; i < input_tensors.size(); i++) + { + auto type = c10::TensorType::create(input_tensors[i]); - // assign dynamic shape info + graph->inputs()[1 + i]->setType(type); + } + } + else + { for (size_t i = 0; i < input_tensors.size(); i++) { auto type1 = c10::TensorType::create(input_tensors[i]); @@ -101,38 +260,34 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrinputs()[1 + i]->setType(finaltype); } + } - int index = 0; - for (auto e : outputs->elements()) - { - auto type1 = c10::TensorType::create(e.toTensor()); - auto type2 = c10::TensorType::create(outputs2->elements()[index].toTensor()); - - std::vector sizes1 = type1->symbolic_sizes().sizes().value(); - std::vector sizes2 = type2->symbolic_sizes().sizes().value(); + for (auto xx : output_tensors) + { + auto v = xx.first; + auto tensor = xx.second; - for (size_t i = 0; i < sizes1.size(); i++) + bool link_to_output = false; + for (size_t i = 0; i < v->uses().size(); i++) + { + auto node = v->uses()[i].user; + for (auto x : node->outputs()) { - if (sizes1[i] == sizes2[i]) - continue; - - sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); + if (output_tensors.find(x) == output_tensors.end()) + { + link_to_output = true; + break; + } } + } - auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); - - values[index]->setType(finaltype); - - index++; + const int ndim = (int)tensor.dim(); + if (link_to_output && ndim > 0) + { + fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); + foldable_constants[v->debugName()] = Attribute(tensor); } } - 
- // restore old graph output - graph->eraseOutput(0); - graph->registerOutput(old_output); - - new_return_node->removeAllInputs(); - new_return_node->destroy(); } } // namespace pnnx diff --git a/tools/pnnx/src/pass_level0/shape_inference.h b/tools/pnnx/src/pass_level0/shape_inference.h index b1feba801..cf80ade7a 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.h +++ b/tools/pnnx/src/pass_level0/shape_inference.h @@ -13,9 +13,11 @@ // specific language governing permissions and limitations under the License. #include +#include +#include "ir.h" namespace pnnx { -void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2); +void shape_inference(const torch::jit::Module& mod, std::shared_ptr& graph, const std::vector& input_tensors, const std::vector& input_tensors2, const std::vector& module_operators, const std::string& ptpath, std::map& foldable_constants); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_flip.cpp b/tools/pnnx/src/pass_level2/torch_flip.cpp new file mode 100644 index 000000000..3bf95d2b8 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_flip.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_level2.h" + +namespace pnnx { + +class torch_flip : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dims +aten::flip op_0 2 1 input dims out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.flip"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_randn.cpp b/tools/pnnx/src/pass_level2/torch_randn.cpp new file mode 100644 index 000000000..18cc83d04 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_randn.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_level2.h" + +namespace pnnx { + +class torch_randn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 size +prim::Constant op_0 0 1 dtype value=* +prim::Constant op_1 0 1 layout value=* +prim::Constant op_2 0 1 device value=* +prim::Constant op_3 0 1 requires_grad value=* +aten::randn op_4 5 1 size dtype layout device requires_grad out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.randn"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_randn, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_unbind.cpp b/tools/pnnx/src/pass_level2/torch_unbind.cpp new file mode 100644 index 000000000..c973b904b --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_unbind.cpp @@ -0,0 +1,41 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_level2.h" + +namespace pnnx { + +class torch_unbind : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 dim +aten::unbind op_0 2 1 input dim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.unbind"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unbind, 20) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3.cpp b/tools/pnnx/src/pass_level3.cpp index 959dcc253..970d9068b 100644 --- a/tools/pnnx/src/pass_level3.cpp +++ b/tools/pnnx/src/pass_level3.cpp @@ -15,11 +15,11 @@ #include "pass_level3.h" #include "pass_level3/assign_unique_name.h" +#include "pass_level3/eliminate_noop_math.h" #include "pass_level3/eliminate_tuple_pair.h" #include "pass_level3/expand_quantization_modules.h" -#include "pass_level3/fuse_attribute_expression.h" #include "pass_level3/fuse_cat_stack_tensors.h" -#include "pass_level3/fuse_chunk_split_unpack.h" +#include "pass_level3/fuse_chunk_split_unbind_unpack.h" #include "pass_level3/fuse_expression.h" #include "pass_level3/fuse_index_expression.h" #include "pass_level3/fuse_rnn_unpack.h" @@ -39,14 +39,12 @@ void pass_level3(Graph& g) fuse_cat_stack_tensors(g); - fuse_chunk_split_unpack(g); + fuse_chunk_split_unbind_unpack(g); fuse_rnn_unpack(g); expand_quantization_modules(g); - fuse_attribute_expression(g); - eliminate_tuple_pair(g); rename_F_conv_transposend(g); @@ -55,6 +53,8 @@ void pass_level3(Graph& g) rename_F_dropoutnd(g); + eliminate_noop_math(g); + fuse_expression(g); fuse_index_expression(g); diff --git a/tools/pnnx/src/pass_level3/assign_unique_name.cpp b/tools/pnnx/src/pass_level3/assign_unique_name.cpp index 7a789e00c..201a8a8e4 100644 --- a/tools/pnnx/src/pass_level3/assign_unique_name.cpp +++ b/tools/pnnx/src/pass_level3/assign_unique_name.cpp @@ -45,33 +45,6 @@ void assign_unique_name(Graph& graph) } } } - - // assign 
unique name for all operands - { - std::unordered_set names; - int make_unique_index = 0; - - for (size_t i = 0; i < graph.operands.size(); i++) - { - Operand* operand = graph.operands[i]; - const std::string& name = operand->name; - - if (names.find(name) == names.end()) - { - names.insert(name); - } - else - { - // duplicated found - std::string new_name = std::string("pnnx_unique_") + std::to_string(make_unique_index); - fprintf(stderr, "assign unique operand name %s to %s\n", new_name.c_str(), name.c_str()); - operand->name = new_name; - names.insert(new_name); - - make_unique_index++; - } - } - } } } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp b/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp new file mode 100644 index 000000000..5ba1e0e90 --- /dev/null +++ b/tools/pnnx/src/pass_level3/eliminate_noop_math.cpp @@ -0,0 +1,297 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+
+#include "eliminate_noop_math.h"
+
+#include <algorithm>
+#include "pass_level2.h"
+#include "pass_level4/dead_code_elimination.h"
+
+namespace pnnx {
+
+// Check whether the scalar stored in the "value" parameter of op_constant
+// equals the requested constant.
+//   vf  value to compare against when the parameter is read through .f (type 3)
+//   vi  value to compare against when the parameter is read through .i (type 2)
+// Any other parameter type is treated as "not equal".
+static bool constant_is_all_constant(const Operator* op_constant, float vf, int vi)
+{
+    const Parameter& param = op_constant->params.at("value");
+
+    if (param.type == 2)
+    {
+        if (param.i != vi)
+            return false;
+    }
+    else if (param.type == 3)
+    {
+        if (param.f != vf)
+            return false;
+    }
+    else
+    {
+        // unsupported data type
+        return false;
+    }
+
+    return true;
+}
+
+// Check whether every element of the first attribute of op_attr equals the
+// requested constant. vf is used for the element types read as float/double,
+// vi for the element types read as integers. An attribute with an empty shape
+// or an element type not handled below is reported as "not constant".
+static bool attribute_is_all_constant(const Operator* op_attr, float vf, int vi)
+{
+    const Attribute& attr = op_attr->attrs.begin()->second;
+
+    if (attr.shape.empty())
+    {
+        fprintf(stderr, "shape empty!\n");
+        return false;
+    }
+
+    // total element count is the product of all dimensions
+    int size = attr.shape[0];
+    for (size_t i = 1; i < attr.shape.size(); i++)
+    {
+        size *= attr.shape[i];
+    }
+
+    if (attr.type == 1)
+    {
+        const float* p = (const float*)attr.data.data();
+        for (int i = 0; i < size; i++)
+        {
+            if (p[i] != vf)
+                return false;
+        }
+    }
+    else if (attr.type == 2)
+    {
+        const double* p = (const double*)attr.data.data();
+        for (int i = 0; i < size; i++)
+        {
+            if (p[i] != vf)
+                return false;
+        }
+    }
+    else if (attr.type == 4)
+    {
+        const int* p = (const int*)attr.data.data();
+        for (int i = 0; i < size; i++)
+        {
+            if (p[i] != vi)
+                return false;
+        }
+    }
+    else if (attr.type == 5)
+    {
+        const int64_t* p = (const int64_t*)attr.data.data();
+        for (int i = 0; i < size; i++)
+        {
+            if (p[i] != vi)
+                return false;
+        }
+    }
+    else if (attr.type == 7)
+    {
+        const signed char* p = (const signed char*)attr.data.data();
+        for (int i = 0; i < size; i++)
+        {
+            if (p[i] != vi)
+                return false;
+        }
+    }
+    else if (attr.type == 8)
+    {
+        const unsigned char* p = (const unsigned char*)attr.data.data();
+        for (int i = 0; i < size; i++)
+        {
+            if (p[i] != vi)
+                return false;
+        }
+    }
+    else
+    {
+        // unsupported data type
+        return false;
+    }
+
+    return true;
+}
+
+// Dispatch on the producer kind: pnnx.Attribute tensors and prim::Constant
+// scalars can be tested, every other operator is never considered constant.
+static bool operator_is_all_constant(const Operator* op, float vf, int vi)
+{
+    if (op->type == "pnnx.Attribute")
+        return attribute_is_all_constant(op, vf, vi);
+
+    if (op->type == "prim::Constant")
+        return constant_is_all_constant(op, vf, vi);
+
+    return false;
+}
+
+// Remove arithmetic operators that are no-ops because one operand is an
+// all-zero / all-one constant (x + 0, x - 0, x * 1, x / 1, x ^ 1).
+// Consumers of the removed operator are rewired to the surviving input, and
+// dead_code_elimination() runs once at the end.
+//
+// NOTE(review): the checks look only at operand values; whether the constant
+// operand could broadcast the result to a larger shape is not examined here.
+void eliminate_noop_math(Graph& graph)
+{
+    // one operator is removed per outer iteration; erasing from graph.ops
+    // invalidates the scan, so it restarts until nothing matches
+    for (;;)
+    {
+        bool need_eliminate = false;
+
+        // scan operators in reverse order
+        for (int i = (int)graph.ops.size() - 1; i >= 0; i--)
+        {
+            Operator* op = graph.ops[i];
+
+            // index of the input that replaces the operator output
+            int identity_input_id = 0;
+            if (op->type == "aten::add" || op->type == "aten::add_")
+            {
+                Operator* op0 = op->inputs[0]->producer;
+                Operator* op1 = op->inputs[1]->producer;
+                Operator* op2 = op->inputs[2]->producer;
+
+                if (operator_is_all_constant(op1, 0.f, 0))
+                {
+                    // x <= a + 0 * c
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+                else if (operator_is_all_constant(op2, 0.f, 0))
+                {
+                    // x <= a + b * 0
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+                // fixed: the multiplier check must look at op2 (c); testing
+                // op0 against both 0 and 1 could never succeed
+                else if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op2, 1.f, 1))
+                {
+                    // x <= 0 + b * 1
+                    need_eliminate = true;
+                    identity_input_id = 1;
+                }
+            }
+            if (op->type == "aten::sub")
+            {
+                Operator* op1 = op->inputs[1]->producer;
+                Operator* op2 = op->inputs[2]->producer;
+
+                if (operator_is_all_constant(op1, 0.f, 0))
+                {
+                    // x <= a - 0 * c
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+                else if (operator_is_all_constant(op2, 0.f, 0))
+                {
+                    // x <= a - b * 0
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+            }
+            if (op->type == "aten::rsub")
+            {
+                Operator* op0 = op->inputs[0]->producer;
+                Operator* op1 = op->inputs[1]->producer;
+                Operator* op2 = op->inputs[2]->producer;
+
+                if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op2, 1.f, 1))
+                {
+                    // x <= b * 1 - 0
+                    need_eliminate = true;
+                    identity_input_id = 1;
+                }
+                else if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op1, 1.f, 1))
+                {
+                    // x <= 1 * c - 0
+                    // NOTE(review): torch.rsub is documented as
+                    // other - alpha * input; under that definition a == 0
+                    // and b == 1 yields b rather than c - confirm this branch
+                    need_eliminate = true;
+                    identity_input_id = 2;
+                }
+            }
+            if (op->type == "aten::mul")
+            {
+                Operator* op0 = op->inputs[0]->producer;
+                Operator* op1 = op->inputs[1]->producer;
+
+                if (operator_is_all_constant(op0, 1.f, 1))
+                {
+                    // x <= 1 * b
+                    need_eliminate = true;
+                    identity_input_id = 1;
+                }
+                // not an else-if: when both operands are all-one, input 0 wins
+                if (operator_is_all_constant(op1, 1.f, 1))
+                {
+                    // x <= a * 1
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+            }
+            if (op->type == "aten::div" || op->type == "aten::div_")
+            {
+                Operator* op1 = op->inputs[1]->producer;
+
+                if (operator_is_all_constant(op1, 1.f, 1))
+                {
+                    // x <= a / 1
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+            }
+            if (op->type == "aten::pow")
+            {
+                Operator* op1 = op->inputs[1]->producer;
+
+                if (operator_is_all_constant(op1, 1.f, 1))
+                {
+                    // x <= x ^ 1
+                    need_eliminate = true;
+                    identity_input_id = 0;
+                }
+            }
+
+            if (!need_eliminate)
+                continue;
+
+            fprintf(stderr, "eliminate_noop_math %s %s\n", op->type.c_str(), op->name.c_str());
+
+            // detach the operator from all of its inputs
+            for (auto& x : op->inputs)
+            {
+                x->remove_consumer(op);
+            }
+
+            Operand* math_out = op->outputs[0];
+
+            // every consumer of the old output now reads the identity input
+            for (auto& x : math_out->consumers)
+            {
+                for (size_t j = 0; j < x->inputs.size(); j++)
+                {
+                    if (x->inputs[j] == math_out)
+                        x->inputs[j] = op->inputs[identity_input_id];
+                }
+
+                op->inputs[identity_input_id]->consumers.push_back(x);
+            }
+
+            math_out->producer = 0;
+            math_out->consumers.clear();
+
+            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), math_out));
+            delete math_out;
+
+            op->inputs.clear();
+            op->outputs.clear();
+
+            graph.ops.erase(graph.ops.begin() + i);
+            delete op;
+
+            // graph.ops changed, restart the scan from the outer loop
+            break;
+        }
+
+        if (!need_eliminate)
+            break;
+    }
+
+    // dce: drop the constant producers that lost their last consumer
+    dead_code_elimination(graph);
+}
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h b/tools/pnnx/src/pass_level3/eliminate_noop_math.h
similarity index 86%
rename from tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h
rename to tools/pnnx/src/pass_level3/eliminate_noop_math.h
index 06949e6f0..08d0113c3 100644
--- a/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h
+++ 
b/tools/pnnx/src/pass_level3/eliminate_noop_math.h @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_chunk_split_unpack(Graph& graph); +void eliminate_noop_math(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp b/tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp deleted file mode 100644 index be409705f..000000000 --- a/tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp +++ /dev/null @@ -1,197 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -#include "fuse_attribute_expression.h" -#include -#include -#include "pass_level2.h" - -namespace pnnx { - -void fuse_attribute_expression(Graph& graph) -{ - while (1) - { - bool matched = false; - - for (size_t i = 0; i < graph.ops.size(); i++) - { - Operator* op = graph.ops[i]; - - if (op->type != "pnnx.Attribute") - continue; - - if (op->outputs.size() != 1) - continue; - - if (op->outputs[0]->consumers.size() != 1) - continue; - - Operator* op2 = op->outputs[0]->consumers[0]; - Operator* op3 = 0; - Operator* op4 = 0; - - float y = 0.f; - float z = 0.f; - - if (op2->type == "aten::add" || op2->type == "aten::sub") - { - if (op2->inputs[0] != op->outputs[0]) - continue; - - op3 = op2->inputs[1]->producer; - if (op3->type != "prim::Constant") - continue; - - if (op3->params["value"].type == 2) - { - y = op3->params["value"].i; - } - else if (op3->params["value"].type == 3) - { - y = op3->params["value"].f; - } - else - { - // not a scalar - continue; - } - - op4 = op2->inputs[2]->producer; - if (op4->type != "prim::Constant") - continue; - - if (op4->params["value"].type == 2) - { - z = op4->params["value"].i; - } - else if (op4->params["value"].type == 3) - { - z = op4->params["value"].f; - } - else - { - // not a scalar - continue; - } - } - else if (op2->type == "aten::mul" || op2->type == "aten::div" || op2->type == "aten::pow") - { - if (op2->inputs[0] != op->outputs[0]) - continue; - - op3 = op2->inputs[1]->producer; - if (op3->type != "prim::Constant") - continue; - - if (op3->params["value"].type == 2) - { - y = op3->params["value"].i; - } - else if (op3->params["value"].type == 3) - { - y = op3->params["value"].f; - } - else - { - // not a scalar - continue; - } - } - else - { - // todo more operator type - continue; - } - - matched = true; - - // apply mul - { - auto it = op->attrs.begin(); - std::string attr_key = it->first; - const Attribute& attr = it->second; - - float* weight = (float*)attr.data.data(); - const int weight_size = attr.data.size() 
/ sizeof(float); - - if (op2->type == "aten::add") - { - for (int i = 0; i < weight_size; i++) - weight[i] += y * z; - } - else if (op2->type == "aten::sub") - { - for (int i = 0; i < weight_size; i++) - weight[i] -= y * z; - } - else if (op2->type == "aten::mul") - { - for (int i = 0; i < weight_size; i++) - weight[i] *= y; - } - else if (op2->type == "aten::div") - { - for (int i = 0; i < weight_size; i++) - weight[i] /= y; - } - else if (op2->type == "aten::pow") - { - for (int i = 0; i < weight_size; i++) - weight[i] = (float)pow(weight[i], y); - } - - op->attrs[attr_key] = attr; - } - - op2->outputs[0]->producer = op; - - for (auto& x : op2->inputs) - { - x->producer = 0; - x->remove_consumer(op2); - } - - op->outputs = op2->outputs; - - op2->inputs.clear(); - op2->outputs.clear(); - - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); - - delete op2; - - if (op3 && op3->outputs[0]->consumers.empty()) - { - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op3)); - - delete op3; - } - - if (op4 && op4->outputs[0]->consumers.empty()) - { - graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op4)); - - delete op4; - } - - break; - } - - if (!matched) - break; - } -} - -} // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp b/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp similarity index 93% rename from tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp rename to tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp index a020f4dd5..ee3ebec14 100644 --- a/tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp +++ b/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp @@ -12,13 +12,13 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
-#include "fuse_chunk_split_unpack.h" +#include "fuse_chunk_split_unbind_unpack.h" #include #include "pass_level2.h" namespace pnnx { -void fuse_chunk_split_unpack(Graph& graph) +void fuse_chunk_split_unbind_unpack(Graph& graph) { while (1) { @@ -28,7 +28,7 @@ void fuse_chunk_split_unpack(Graph& graph) { Operator* op = graph.ops[i]; - if (op->type != "torch.chunk" && op->type != "torch.split") + if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind") continue; if (op->outputs.size() != 1) diff --git a/tools/pnnx/src/pass_level3/fuse_attribute_expression.h b/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h similarity index 93% rename from tools/pnnx/src/pass_level3/fuse_attribute_expression.h rename to tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h index 348542d76..da6fa0dec 100644 --- a/tools/pnnx/src/pass_level3/fuse_attribute_expression.h +++ b/tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h @@ -16,6 +16,6 @@ namespace pnnx { -void fuse_attribute_expression(Graph& graph); +void fuse_chunk_split_unbind_unpack(Graph& graph); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp index df1115b1f..6f379059a 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -18,13 +18,13 @@ namespace pnnx { -static bool operand_maybe_tensor(Operand* operand) +static bool operand_maybe_tensor(const Operand* operand) { - Operator* op = operand->producer; + const Operator* op = operand->producer; if (op->type == "prim::Constant") { - const Parameter& param = op->params["value"]; + const Parameter& param = op->params.at("value"); if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4) { return false; @@ -83,9 +83,25 @@ static bool operand_maybe_tensor(Operand* operand) return true; } +static bool operand_is_foldable(const Operand* operand) +{ 
+ const Operator* op = operand->producer; + + if (op->type == "pnnx.Input") + return false; + + for (auto x : op->inputs) + { + if (!operand_is_foldable(x)) + return false; + } + + return true; +} + static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector& inputs, bool checksubgraph = true) { - // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); + // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); Operator* op = operand->producer; @@ -164,6 +180,28 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s } } } + else if (checksubgraph && operand_maybe_tensor(operand) && operand_is_foldable(operand)) + { + // fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str()); + + auto it = std::find(inputs.begin(), inputs.end(), operand); + if (it == inputs.end()) + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)inputs.size()); + expr += tmp; + + inputs.push_back(operand); + } + else + { + // tensor + char tmp[32]; + sprintf(tmp, "@%d", (int)(it - inputs.begin())); + expr += tmp; + } + } else if (op->type == "prim::NumToTensor") { fuse_expression(graph, op->inputs[0], expr, inputs); diff --git a/tools/pnnx/src/pass_level4.cpp b/tools/pnnx/src/pass_level4.cpp index 8ebcb4bfe..0e565b46d 100644 --- a/tools/pnnx/src/pass_level4.cpp +++ b/tools/pnnx/src/pass_level4.cpp @@ -26,7 +26,7 @@ void pass_level4(Graph& g) dead_code_elimination(g); - canonicalize(g); + //canonicalize(g); } } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 6f3b2cb84..757a8f180 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -14,6 +14,7 @@ #include "pass_level5.h" +#include "pass_level5/fold_constants.h" #include "pass_level5/eliminate_dropout.h" #include "pass_level5/eliminate_slice.h" #include "pass_level5/eliminate_view_reshape.h" @@ -29,10 +30,11 @@ #include "pass_level5/fuse_slice_indices.h" #include 
"pass_level4/dead_code_elimination.h" #include "pass_level4/canonicalize.h" +#include "pass_level3/fuse_index_expression.h" namespace pnnx { -void pass_level5(Graph& g) +void pass_level5(Graph& g, const std::map& foldable_constants) { eval_expression(g); @@ -60,6 +62,10 @@ void pass_level5(Graph& g) fuse_channel_shuffle(g); + fold_constants(g, foldable_constants); + + fuse_index_expression(g); + dead_code_elimination(g); canonicalize(g); diff --git a/tools/pnnx/src/pass_level5.h b/tools/pnnx/src/pass_level5.h index 14228516d..fbf4ff486 100644 --- a/tools/pnnx/src/pass_level5.h +++ b/tools/pnnx/src/pass_level5.h @@ -19,7 +19,7 @@ namespace pnnx { -void pass_level5(Graph& g); +void pass_level5(Graph& g, const std::map& foldable_constants); } // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp index a23e2b048..b570f7de1 100644 --- a/tools/pnnx/src/pass_level5/eliminate_dropout.cpp +++ b/tools/pnnx/src/pass_level5/eliminate_dropout.cpp @@ -40,24 +40,24 @@ void eliminate_dropout(Graph& graph) x->remove_consumer(op); } - Operand* slice_out = op->outputs[0]; + Operand* dropout_out = op->outputs[0]; - for (auto& x : slice_out->consumers) + for (auto& x : dropout_out->consumers) { for (size_t j = 0; j < x->inputs.size(); j++) { - if (x->inputs[j] == slice_out) + if (x->inputs[j] == dropout_out) x->inputs[j] = op->inputs[0]; } op->inputs[0]->consumers.push_back(x); } - slice_out->producer = 0; - slice_out->consumers.clear(); + dropout_out->producer = 0; + dropout_out->consumers.clear(); - graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), slice_out)); - delete slice_out; + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), dropout_out)); + delete dropout_out; op->inputs.clear(); op->outputs.clear(); diff --git a/tools/pnnx/src/pass_level5/fold_constants.cpp b/tools/pnnx/src/pass_level5/fold_constants.cpp new file mode 100644 index 
000000000..51c8e7153 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fold_constants.cpp @@ -0,0 +1,50 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fold_constants.h" +#include + +#include "pass_level4/dead_code_elimination.h" + +namespace pnnx { + +void fold_constants(Graph& graph, const std::map& foldable_constants) +{ + for (size_t i = 0; i < graph.operands.size(); i++) + { + Operand* operand = graph.operands[i]; + const std::string& name = operand->name; + + if (foldable_constants.find(name) == foldable_constants.end()) + continue; + + Operator* op = operand->producer; + if (op->type == "pnnx.Attribute") + continue; + + // replace producer with attribute + Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op); + + op_new->attrs[std::string("pnnx_fold_") + name] = foldable_constants.at(name); + op_new->outputs.push_back(operand); + operand->producer = op_new; + + op->outputs.clear(); + } + + // dce + dead_code_elimination(graph); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fold_constants.h b/tools/pnnx/src/pass_level5/fold_constants.h new file mode 100644 index 000000000..6ebffbda0 --- /dev/null +++ b/tools/pnnx/src/pass_level5/fold_constants.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source 
community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fold_constants(Graph& graph, const std::map& foldable_constants); + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index d96cb76f8..c975b252d 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -181,6 +181,7 @@ pnnx_add_test(torch_split) pnnx_add_test(torch_squeeze) pnnx_add_test(torch_stack) pnnx_add_test(torch_transpose) +pnnx_add_test(torch_unbind) pnnx_add_test(torch_unsqueeze) pnnx_add_test(mobilenet_v2) @@ -192,6 +193,8 @@ pnnx_add_test(squeezenet1_1) # TODO enable end2end quantization model test #pnnx_add_test(quantization_shufflenet_v2_x1_0) +pnnx_add_test(pnnx_eliminate_noop_math) +pnnx_add_test(pnnx_fold_constant) pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d) pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) diff --git a/tools/pnnx/tests/test_Tensor_index.py b/tools/pnnx/tests/test_Tensor_index.py index 921304232..2156efcd6 100644 --- a/tools/pnnx/tests/test_Tensor_index.py +++ b/tools/pnnx/tests/test_Tensor_index.py @@ -39,15 +39,15 @@ def test(): # export torchscript mod = torch.jit.trace(net, (x, y, z)) - mod.save("test_Tensor_slice.pt") + mod.save("test_Tensor_index.pt") # torchscript to pnnx import os - 
os.system("../src/pnnx test_Tensor_slice.pt inputshape=[3,6],[5,9,2],[2,4,5,10]") + os.system("../src/pnnx test_Tensor_index.pt inputshape=[3,6],[5,9,2],[2,4,5,10]") # pnnx inference - import test_Tensor_slice_pnnx - b = test_Tensor_slice_pnnx.test_inference() + import test_Tensor_index_pnnx + b = test_Tensor_index_pnnx.test_inference() for a0, b0 in zip(a, b): if not torch.equal(a0, b0): diff --git a/tools/pnnx/tests/test_pnnx_eliminate_noop_math.py b/tools/pnnx/tests/test_pnnx_eliminate_noop_math.py new file mode 100644 index 000000000..37b8522b0 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_eliminate_noop_math.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.w0 = nn.Parameter(torch.zeros(1, 12, 52)) + self.w1 = nn.Parameter(torch.ones(1, 12, 52)) + self.w2 = nn.Parameter(torch.ones(1, 12, 52)) + + def forward(self, x): + x = x + 0 + x = x * 1 / 1 + x = 0 + 1 * x + x = x + self.w0 * self.w1 + x = x * self.w2 + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 52) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_eliminate_noop_math.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_eliminate_noop_math.pt inputshape=[1,12,52]") + + # pnnx inference + import test_pnnx_eliminate_noop_math_pnnx + b = test_pnnx_eliminate_noop_math_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_pnnx_fold_constant.py b/tools/pnnx/tests/test_pnnx_fold_constant.py new file mode 100644 index 000000000..f492295b3 --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fold_constant.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.w0 = nn.Parameter(torch.rand(1, 12, 52)) + self.w1 = nn.Parameter(torch.rand(1, 12, 52)) + self.w2 = nn.Parameter(torch.rand(1, 12, 1)) + self.w3 = nn.Parameter(torch.rand(1, 12, 52)) + + def forward(self, x): + b = (self.w0 + self.w1 + 0.22) + self.w2 * 0.1 + x = x + b - self.w3 / 2 + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 52) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_pnnx_fold_constant.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fold_constant.pt inputshape=[1,12,52]") + + # pnnx inference + import test_pnnx_fold_constant_pnnx + b = test_pnnx_fold_constant_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_unbind.py b/tools/pnnx/tests/test_torch_unbind.py new file mode 100644 index 000000000..c92c87b74 --- /dev/null +++ b/tools/pnnx/tests/test_torch_unbind.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x0, x1, x2 = torch.unbind(x, dim=1) + y0, y1, y2, y3, y4, y5, y6, y7, y8 = torch.unbind(y, dim=2) + z0, z1, z2, z3 = torch.unbind(z, dim=0) + return x0, x1, y0, y1, y2, y3, y4, y5, y6, y7, y8, z0, z1, z2, z3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(4, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_unbind.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_unbind.pt inputshape=[1,3,16],[1,5,9,11],[4,8,5,9,10]") + + # pnnx inference + import test_torch_unbind_pnnx + b = test_torch_unbind_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)