Browse Source

pnnx fold constant (#3521)

tags/20220216
nihui GitHub 4 years ago
parent
commit
340b4e673e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1016 additions and 334 deletions
  1. +7
    -3
      tools/pnnx/src/CMakeLists.txt
  2. +29
    -7
      tools/pnnx/src/ir.cpp
  3. +3
    -5
      tools/pnnx/src/main.cpp
  4. +2
    -2
      tools/pnnx/src/pass_level0.cpp
  5. +2
    -1
      tools/pnnx/src/pass_level0.h
  6. +217
    -62
      tools/pnnx/src/pass_level0/shape_inference.cpp
  7. +3
    -1
      tools/pnnx/src/pass_level0/shape_inference.h
  8. +41
    -0
      tools/pnnx/src/pass_level2/torch_flip.cpp
  9. +44
    -0
      tools/pnnx/src/pass_level2/torch_randn.cpp
  10. +41
    -0
      tools/pnnx/src/pass_level2/torch_unbind.cpp
  11. +5
    -5
      tools/pnnx/src/pass_level3.cpp
  12. +0
    -27
      tools/pnnx/src/pass_level3/assign_unique_name.cpp
  13. +297
    -0
      tools/pnnx/src/pass_level3/eliminate_noop_math.cpp
  14. +2
    -2
      tools/pnnx/src/pass_level3/eliminate_noop_math.h
  15. +0
    -197
      tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp
  16. +3
    -3
      tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp
  17. +1
    -1
      tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h
  18. +42
    -4
      tools/pnnx/src/pass_level3/fuse_expression.cpp
  19. +1
    -1
      tools/pnnx/src/pass_level4.cpp
  20. +7
    -1
      tools/pnnx/src/pass_level5.cpp
  21. +1
    -1
      tools/pnnx/src/pass_level5.h
  22. +7
    -7
      tools/pnnx/src/pass_level5/eliminate_dropout.cpp
  23. +50
    -0
      tools/pnnx/src/pass_level5/fold_constants.cpp
  24. +21
    -0
      tools/pnnx/src/pass_level5/fold_constants.h
  25. +3
    -0
      tools/pnnx/tests/CMakeLists.txt
  26. +4
    -4
      tools/pnnx/tests/test_Tensor_index.py
  27. +62
    -0
      tools/pnnx/tests/test_pnnx_eliminate_noop_math.py
  28. +60
    -0
      tools/pnnx/tests/test_pnnx_fold_constant.py
  29. +61
    -0
      tools/pnnx/tests/test_torch_unbind.py

+ 7
- 3
tools/pnnx/src/CMakeLists.txt View File

@@ -59,7 +59,7 @@ set(pnnx_pass_level1_SRCS
pass_level1/nn_MaxPool1d.cpp
pass_level1/nn_MaxPool2d.cpp
pass_level1/nn_MaxPool3d.cpp
pass_level1/nn_maxunpool2d.cpp
#pass_level1/nn_maxunpool2d.cpp
pass_level1/nn_Mish.cpp
pass_level1/nn_MultiheadAttention.cpp
pass_level1/nn_PixelShuffle.cpp
@@ -186,6 +186,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_clone.cpp
pass_level2/torch_dequantize.cpp
pass_level2/torch_flatten.cpp
pass_level2/torch_flip.cpp
pass_level2/torch_logsumexp.cpp
pass_level2/torch_matmul.cpp
pass_level2/torch_mean.cpp
@@ -193,6 +194,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_normal.cpp
pass_level2/torch_prod.cpp
pass_level2/torch_quantize_per_tensor.cpp
pass_level2/torch_randn.cpp
pass_level2/torch_roll.cpp
pass_level2/torch_split.cpp
pass_level2/torch_squeeze.cpp
@@ -200,6 +202,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_sum.cpp
pass_level2/torch_permute.cpp
pass_level2/torch_transpose.cpp
pass_level2/torch_unbind.cpp
pass_level2/torch_unsqueeze.cpp
pass_level2/torch_var.cpp
pass_level2/torch_zeros.cpp
@@ -210,11 +213,11 @@ set(pnnx_pass_level2_SRCS

set(pnnx_pass_level3_SRCS
pass_level3/assign_unique_name.cpp
pass_level3/eliminate_noop_math.cpp
pass_level3/eliminate_tuple_pair.cpp
pass_level3/expand_quantization_modules.cpp
pass_level3/fuse_attribute_expression.cpp
pass_level3/fuse_cat_stack_tensors.cpp
pass_level3/fuse_chunk_split_unpack.cpp
pass_level3/fuse_chunk_split_unbind_unpack.cpp
pass_level3/fuse_expression.cpp
pass_level3/fuse_index_expression.cpp
pass_level3/fuse_rnn_unpack.cpp
@@ -235,6 +238,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_slice.cpp
pass_level5/eliminate_view_reshape.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_channel_shuffle.cpp
pass_level5/fuse_constant_expression.cpp
pass_level5/fuse_conv1d_batchnorm1d.cpp


+ 29
- 7
tools/pnnx/src/ir.cpp View File

@@ -1343,7 +1343,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
fprintf(pyfp, ",");
}

fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
if (attr.type == 1 || attr.type == 2 || attr.type == 3)
{
fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
}
else
{
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
}
}
}

@@ -1373,11 +1380,11 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

if (is_running_mean_var)
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), key.c_str(), sanitize_identifier(op->name).c_str(), key.c_str());
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}
else
{
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), key.c_str(), sanitize_identifier(op->name).c_str(), key.c_str());
fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str());
}

for (size_t i = 0; i < attr.shape.size(); i++)
@@ -1387,7 +1394,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
fprintf(pyfp, ",");
}

fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
if (attr.type == 1 || attr.type == 2 || attr.type == 3)
{
fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type));
}
else
{
fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type));
}
}

fprintf(pyfp, " archive.close()\n");
@@ -1452,7 +1466,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
else if (op->type == "pnnx.Attribute")
{
const std::string& key = op->attrs.begin()->first;
fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), key.c_str());
fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str());
}
else if (op->type == "Tensor.slice")
{
@@ -1463,8 +1477,16 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
else if (op->type == "Tensor.index")
{
// index expr
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
if (op->inputs.size() == 2)
{
std::string expanded_expr = expand_expression(op->inputs[1]->producer);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str());
}
else
{
std::string index_expr = make_index_expression(op);
fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str());
}
}
else if (op->type == "Tensor.view" || op->type == "Tensor.reshape")
{


+ 3
- 5
tools/pnnx/src/main.cpp View File

@@ -279,9 +279,6 @@ int main(int argc, char** argv)
fprintf(stderr, "\n");
}

// at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);
// torch::autograd::AutoGradMode guard(false);

for (auto m : customop_modules)
{
fprintf(stderr, "load custom module %s\n", m.c_str());
@@ -339,7 +336,8 @@ int main(int argc, char** argv)

fprintf(stderr, "############# pass_level0\n");

pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators);
std::map<std::string, pnnx::Attribute> foldable_constants;
pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants);

// g->dump();

@@ -373,7 +371,7 @@ int main(int argc, char** argv)
{
fprintf(stderr, "############# pass_level5\n");

pnnx::pass_level5(pnnx_graph);
pnnx::pass_level5(pnnx_graph, foldable_constants);
}

pnnx_graph.save(pnnxparampath, pnnxbinpath);


+ 2
- 2
tools/pnnx/src/pass_level0.cpp View File

@@ -20,7 +20,7 @@

namespace pnnx {

void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators)
void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants)
{
inline_block(g, module_operators);

@@ -28,7 +28,7 @@ void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Grap

if (!input_tensors.empty())
{
shape_inference(mod, g, input_tensors, input_tensors2);
shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants);
}
}



+ 2
- 1
tools/pnnx/src/pass_level0.h View File

@@ -16,10 +16,11 @@
#define PNNX_PASS_LEVEL0_H

#include <torch/script.h>
#include "ir.h"

namespace pnnx {

void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators);
void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx



+ 217
- 62
tools/pnnx/src/pass_level0/shape_inference.cpp View File

@@ -13,74 +13,233 @@
// specific language governing permissions and limitations under the License.

#include "shape_inference.h"
#include <unordered_set>

#include "pass_level0/constant_unpooling.h"
#include "pass_level0/inline_block.h"
#include "pass_level0/shape_inference.h"

namespace pnnx {

void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2)
static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs)
{
// collect all intermediate output tensors
std::vector<torch::jit::Value*> values;
for (const auto& n : graph->nodes())
for (auto x : inputs)
{
for (const auto& on : n->outputs())
{
auto tensor_type = on->type()->cast<torch::jit::TensorType>();
if (!tensor_type)
continue;
if (v == x)
return true;
}

for (size_t i = 0; i < v->node()->inputs().size(); i++)
{
bool link = value_link_input(v->node()->inputs()[i], inputs);
if (link)
return true;
}

return false;
}

static bool value_link_output(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& outputs)
{
for (auto x : outputs)
{
if (v == x)
return true;
}

values.push_back(on);
for (size_t i = 0; i < v->uses().size(); i++)
{
auto node = v->uses()[i].user;
for (auto x : node->outputs())
{
bool link = value_link_output(x, outputs);
if (link)
return true;
}
}

// set new graph output
auto old_output = graph->outputs()[0];
return false;
}

torch::jit::Node* new_return_node = graph->createTuple(at::ArrayRef<torch::jit::Value*>(values));
void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants)
{
// collect all intermediate output tensors
std::vector<std::unordered_set<std::string> > more_value_names;
std::vector<std::vector<torch::jit::Value*> > more_values;
{
std::unordered_set<std::string> value_names;
std::vector<torch::jit::Value*> values;
for (const auto& n : graph->nodes())
{
for (const auto& v : n->outputs())
{
auto tensor_type = v->type()->cast<torch::jit::TensorType>();
if (!tensor_type)
continue;

graph->appendNode(new_return_node);
value_names.insert(v->debugName());
values.push_back(v);
}

graph->eraseOutput(0);
graph->registerOutput(new_return_node->outputs()[0]);
// too many intermediate blobs in one inference results oom
if (value_names.size() >= 1000)
{
more_value_names.push_back(value_names);
value_names.clear();

more_values.push_back(values);
values.clear();
}
}

if (value_names.size() > 0)
{
more_value_names.push_back(value_names);
more_values.push_back(values);
}
}

// collect graph inputs outputs
std::vector<torch::jit::Value*> g_inputs;
for (size_t i = 1; i < graph->inputs().size(); i++)
{
g_inputs.push_back(graph->inputs()[i]);
}
std::vector<torch::jit::Value*> g_outputs;
for (size_t i = 0; i < graph->outputs().size(); i++)
{
g_outputs.push_back(graph->outputs()[i]);
}

// inference for all tensors
std::vector<torch::jit::IValue> inputs;
for (size_t i = 0; i < input_tensors.size(); i++)
{
const at::Tensor& it = input_tensors[i];

inputs.push_back(it);
graph->inputs()[1 + i]->setType(c10::TensorType::create(it));
}

auto outputs = mod.copy().forward(inputs).toTuple();
std::vector<torch::jit::IValue> inputs2;
for (size_t i = 0; i < input_tensors2.size(); i++)
{
const at::Tensor& it = input_tensors2[i];
inputs2.push_back(it);
}

if (input_tensors2.empty())
std::map<torch::jit::Value*, at::Tensor> output_tensors;

for (size_t p = 0; p < more_value_names.size(); p++)
{
// assign shape info
int index = 0;
for (auto e : outputs->elements())
std::unordered_set<std::string>& value_names = more_value_names[p];
std::vector<torch::jit::Value*>& values = more_values[p];

// auto mod2 = mod.deepcopy();

torch::jit::Module mod2 = torch::jit::load(ptpath);
mod2.eval();

auto graph2 = mod2.get_method("forward").graph();

inline_block(graph2, module_operators);

constant_unpooling(graph2);

std::vector<torch::jit::Value*> values2;
for (auto n : graph2->nodes())
{
values[index]->setType(c10::TensorType::create(e.toTensor()));
for (const auto& v : n->outputs())
{
auto tensor_type = v->type()->cast<torch::jit::TensorType>();
if (!tensor_type)
continue;

index++;
if (value_names.find(v->debugName()) != value_names.end())
{
values2.push_back(v);
fprintf(stderr, "%s ", v->debugName().c_str());
}
}
}
}
else
{
std::vector<torch::jit::IValue> inputs2;
for (size_t i = 0; i < input_tensors2.size(); i++)
fprintf(stderr, "\n----------------\n\n");

// set new graph output
torch::jit::Node* new_return_node = graph2->createTuple(at::ArrayRef<torch::jit::Value*>(values2));

graph2->appendNode(new_return_node);

graph2->eraseOutput(0);
graph2->registerOutput(new_return_node->outputs()[0]);

// inference for all tensors
auto outputs = mod2.copy().forward(inputs).toTuple();

if (input_tensors2.empty())
{
const at::Tensor& it = input_tensors2[i];
// assign shape info
for (size_t i = 0; i < values2.size(); i++)
{
auto v = values[i];
auto t = outputs->elements()[i].toTensor();

v->setType(c10::TensorType::create(t));

inputs2.push_back(it);
graph->inputs()[1 + i]->setType(c10::TensorType::create(it));
// check if value that does not depend on inputs
if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs))
{
output_tensors[v] = t;
}
}
}
else
{
// assign dynamic shape info
auto outputs2 = mod2.copy().forward(inputs2).toTuple();

fprintf(stderr, "assign dynamic shape info\n");

auto outputs2 = mod.copy().forward(inputs2).toTuple();
for (size_t i = 0; i < values2.size(); i++)
{
auto v = values[i];
auto t = outputs->elements()[i].toTensor();
auto t2 = outputs2->elements()[i].toTensor();

fprintf(stderr, "assign dynamic shape info\n");
auto type1 = c10::TensorType::create(t);
auto type2 = c10::TensorType::create(t2);

std::vector<c10::ShapeSymbol> sizes1 = type1->symbolic_sizes().sizes().value();
std::vector<c10::ShapeSymbol> sizes2 = type2->symbolic_sizes().sizes().value();

for (size_t i = 0; i < sizes1.size(); i++)
{
if (sizes1[i] == sizes2[i])
continue;

sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1);
}

auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1));

v->setType(finaltype);

// check if value that does not depend on inputs
if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs))
{
output_tensors[v] = t;
}
}
}
}

if (input_tensors2.empty())
{
for (size_t i = 0; i < input_tensors.size(); i++)
{
auto type = c10::TensorType::create(input_tensors[i]);

// assign dynamic shape info
graph->inputs()[1 + i]->setType(type);
}
}
else
{
for (size_t i = 0; i < input_tensors.size(); i++)
{
auto type1 = c10::TensorType::create(input_tensors[i]);
@@ -101,38 +260,34 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::

graph->inputs()[1 + i]->setType(finaltype);
}
}

int index = 0;
for (auto e : outputs->elements())
{
auto type1 = c10::TensorType::create(e.toTensor());
auto type2 = c10::TensorType::create(outputs2->elements()[index].toTensor());

std::vector<c10::ShapeSymbol> sizes1 = type1->symbolic_sizes().sizes().value();
std::vector<c10::ShapeSymbol> sizes2 = type2->symbolic_sizes().sizes().value();
for (auto xx : output_tensors)
{
auto v = xx.first;
auto tensor = xx.second;

for (size_t i = 0; i < sizes1.size(); i++)
bool link_to_output = false;
for (size_t i = 0; i < v->uses().size(); i++)
{
auto node = v->uses()[i].user;
for (auto x : node->outputs())
{
if (sizes1[i] == sizes2[i])
continue;

sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1);
if (output_tensors.find(x) == output_tensors.end())
{
link_to_output = true;
break;
}
}
}

auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1));
values[index]->setType(finaltype);
index++;
const int ndim = (int)tensor.dim();
if (link_to_output && ndim > 0)
{
fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
foldable_constants[v->debugName()] = Attribute(tensor);
}
}

// restore old graph output
graph->eraseOutput(0);
graph->registerOutput(old_output);

new_return_node->removeAllInputs();
new_return_node->destroy();
}

} // namespace pnnx

+ 3
- 1
tools/pnnx/src/pass_level0/shape_inference.h View File

@@ -13,9 +13,11 @@
// specific language governing permissions and limitations under the License.

#include <torch/script.h>
#include <map>
#include "ir.h"

namespace pnnx {

void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2);
void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx

+ 41
- 0
tools/pnnx/src/pass_level2/torch_flip.cpp View File

@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Graph rewriter pass description for flip: supplies the pattern text
// containing an aten::flip node and the type string "torch.flip".
// How GraphRewriterPass consumes these two values is defined outside this file.
class torch_flip : public GraphRewriterPass
{
public:
    // Type string associated with this pass.
    const char* type_str() const
    {
        return "torch.flip";
    }

    // Pattern text: two pnnx.Input entries (input, dims) feed one aten::flip
    // entry whose result "out" goes to the pnnx.Output entry.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dims
aten::flip op_0 2 1 input dims out
pnnx.Output output 1 0 out
)PNNXIR";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 20)

} // namespace pnnx

+ 44
- 0
tools/pnnx/src/pass_level2/torch_randn.cpp View File

@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Graph rewriter pass description for randn: supplies the pattern text
// containing an aten::randn node and the type string "torch.randn".
// How GraphRewriterPass consumes these two values is defined outside this file.
class torch_randn : public GraphRewriterPass
{
public:
    // Type string associated with this pass.
    const char* type_str() const
    {
        return "torch.randn";
    }

    // Pattern text: one pnnx.Input entry (size) plus four prim::Constant
    // entries (dtype, layout, device, requires_grad, each with value=*) feed
    // one aten::randn entry whose result "out" goes to the pnnx.Output entry.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 size
prim::Constant op_0 0 1 dtype value=*
prim::Constant op_1 0 1 layout value=*
prim::Constant op_2 0 1 device value=*
prim::Constant op_3 0 1 requires_grad value=*
aten::randn op_4 5 1 size dtype layout device requires_grad out
pnnx.Output output 1 0 out
)PNNXIR";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_randn, 20)

} // namespace pnnx

+ 41
- 0
tools/pnnx/src/pass_level2/torch_unbind.cpp View File

@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Graph rewriter pass description for unbind: supplies the pattern text
// containing an aten::unbind node and the type string "torch.unbind".
// How GraphRewriterPass consumes these two values is defined outside this file.
class torch_unbind : public GraphRewriterPass
{
public:
    // Type string associated with this pass.
    const char* type_str() const
    {
        return "torch.unbind";
    }

    // Pattern text: two pnnx.Input entries (input, dim) feed one aten::unbind
    // entry whose result "out" goes to the pnnx.Output entry.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dim
aten::unbind op_0 2 1 input dim out
pnnx.Output output 1 0 out
)PNNXIR";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unbind, 20)

} // namespace pnnx

+ 5
- 5
tools/pnnx/src/pass_level3.cpp View File

@@ -15,11 +15,11 @@
#include "pass_level3.h"

#include "pass_level3/assign_unique_name.h"
#include "pass_level3/eliminate_noop_math.h"
#include "pass_level3/eliminate_tuple_pair.h"
#include "pass_level3/expand_quantization_modules.h"
#include "pass_level3/fuse_attribute_expression.h"
#include "pass_level3/fuse_cat_stack_tensors.h"
#include "pass_level3/fuse_chunk_split_unpack.h"
#include "pass_level3/fuse_chunk_split_unbind_unpack.h"
#include "pass_level3/fuse_expression.h"
#include "pass_level3/fuse_index_expression.h"
#include "pass_level3/fuse_rnn_unpack.h"
@@ -39,14 +39,12 @@ void pass_level3(Graph& g)

fuse_cat_stack_tensors(g);

fuse_chunk_split_unpack(g);
fuse_chunk_split_unbind_unpack(g);

fuse_rnn_unpack(g);

expand_quantization_modules(g);

fuse_attribute_expression(g);

eliminate_tuple_pair(g);

rename_F_conv_transposend(g);
@@ -55,6 +53,8 @@ void pass_level3(Graph& g)

rename_F_dropoutnd(g);

eliminate_noop_math(g);

fuse_expression(g);

fuse_index_expression(g);


+ 0
- 27
tools/pnnx/src/pass_level3/assign_unique_name.cpp View File

@@ -45,33 +45,6 @@ void assign_unique_name(Graph& graph)
}
}
}

// assign unique name for all operands
{
std::unordered_set<std::string> names;
int make_unique_index = 0;

for (size_t i = 0; i < graph.operands.size(); i++)
{
Operand* operand = graph.operands[i];
const std::string& name = operand->name;

if (names.find(name) == names.end())
{
names.insert(name);
}
else
{
// duplicated found
std::string new_name = std::string("pnnx_unique_") + std::to_string(make_unique_index);
fprintf(stderr, "assign unique operand name %s to %s\n", new_name.c_str(), name.c_str());
operand->name = new_name;
names.insert(new_name);

make_unique_index++;
}
}
}
}

} // namespace pnnx

+ 297
- 0
tools/pnnx/src/pass_level3/eliminate_noop_math.cpp View File

@@ -0,0 +1,297 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_noop_math.h"

#include <algorithm>
#include "pass_level2.h"
#include "pass_level4/dead_code_elimination.h"

namespace pnnx {

static bool constant_is_all_constant(const Operator* op_constant, float vf, int vi)
{
const Parameter& param = op_constant->params.at("value");

if (param.type == 2)
{
if (param.i != vi)
return false;
}
else if (param.type == 3)
{
if (param.f != vf)
return false;
}
else
{
// unsupported data type
return false;
}

return true;
}

// Returns true when each of the first `count` elements of type T located at
// `raw` compares equal to `expected` (vacuously true when count <= 0).
template<typename T, typename V>
static bool every_element_equals(const void* raw, int count, V expected)
{
    const T* elems = (const T*)raw;
    for (int k = 0; k < count; k++)
    {
        if (elems[k] != expected)
            return false;
    }

    return true;
}

// Tests whether every element of the first attribute of op_attr equals the
// expected constant. The element count is the product of attr.shape; the
// buffer attr.data is reinterpreted according to attr.type:
//   1 -> float, 2 -> double (both compared with vf)
//   4 -> int, 5 -> int64_t, 7 -> signed char, 8 -> unsigned char (compared with vi)
// An empty shape is reported on stderr and yields false; any other type code
// yields false as unsupported.
static bool attribute_is_all_constant(const Operator* op_attr, float vf, int vi)
{
    const Attribute& attr = op_attr->attrs.begin()->second;

    if (attr.shape.empty())
    {
        fprintf(stderr, "shape empty!\n");
        return false;
    }

    // total number of elements described by the shape
    int count = 1;
    for (size_t d = 0; d < attr.shape.size(); d++)
    {
        count *= attr.shape[d];
    }

    const void* raw = attr.data.data();

    switch (attr.type)
    {
    case 1:
        return every_element_equals<float>(raw, count, vf);
    case 2:
        return every_element_equals<double>(raw, count, vf);
    case 4:
        return every_element_equals<int>(raw, count, vi);
    case 5:
        return every_element_equals<int64_t>(raw, count, vi);
    case 7:
        return every_element_equals<signed char>(raw, count, vi);
    case 8:
        return every_element_equals<unsigned char>(raw, count, vi);
    default:
        // unsupported data type
        return false;
    }
}

// Dispatches the "is everything equal to vf / vi" check by operator type:
// pnnx.Attribute operators go to attribute_is_all_constant, prim::Constant
// operators go to constant_is_all_constant, every other type yields false.
static bool operator_is_all_constant(const Operator* op, float vf, int vi)
{
    const bool is_attribute = (op->type == "pnnx.Attribute");
    const bool is_constant = (op->type == "prim::Constant");

    if (!is_attribute && !is_constant)
        return false;

    return is_attribute ? attribute_is_all_constant(op, vf, vi)
                        : constant_is_all_constant(op, vf, vi);
}

// Removes arithmetic operators whose result is one of their own inputs
// because another operand is an all-zero / all-one constant
// (prim::Constant or pnnx.Attribute, as judged by operator_is_all_constant).
//
// Handled operator types: aten::add, aten::add_, aten::sub, aten::rsub,
// aten::mul, aten::div, aten::div_, aten::pow.
//
// Each round scans graph.ops from last to first, removes the first no-op
// found (rewiring every consumer of its output to the surviving input, then
// deleting the output operand and the operator) and restarts the scan,
// because the erase invalidates the indices of the current scan. The loop
// ends when a full scan finds nothing; dead_code_elimination(graph) is then
// called once.
void eliminate_noop_math(Graph& graph)
{
    for (;;)
    {
        bool need_eliminate = false;

        // scan operators in reverse order
        for (int i = (int)graph.ops.size() - 1; i >= 0; i--)
        {
            Operator* op = graph.ops[i];

            // Guard every inputs[k] access below: an operator carrying fewer
            // inputs than the form expected here is simply left alone instead
            // of being indexed out of range.
            const size_t input_count = op->inputs.size();

            // index of the input that replaces the operator output
            int identity_input_id = 0;
            if ((op->type == "aten::add" || op->type == "aten::add_") && input_count >= 3)
            {
                Operator* op0 = op->inputs[0]->producer;
                Operator* op1 = op->inputs[1]->producer;
                Operator* op2 = op->inputs[2]->producer;

                if (operator_is_all_constant(op1, 0.f, 0))
                {
                    // x <= a + 0 * c
                    need_eliminate = true;
                    identity_input_id = 0;
                }
                else if (operator_is_all_constant(op2, 0.f, 0))
                {
                    // x <= a + b * 0
                    need_eliminate = true;
                    identity_input_id = 0;
                }
                else if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op2, 1.f, 1))
                {
                    // x <= 0 + b * 1
                    // the scale factor is input 2; testing input 0 against both
                    // zero and one could never succeed
                    need_eliminate = true;
                    identity_input_id = 1;
                }
            }
            if (op->type == "aten::sub" && input_count >= 3)
            {
                Operator* op1 = op->inputs[1]->producer;
                Operator* op2 = op->inputs[2]->producer;

                if (operator_is_all_constant(op1, 0.f, 0))
                {
                    // x <= a - 0 * c
                    need_eliminate = true;
                    identity_input_id = 0;
                }
                else if (operator_is_all_constant(op2, 0.f, 0))
                {
                    // x <= a - b * 0
                    need_eliminate = true;
                    identity_input_id = 0;
                }
            }
            if (op->type == "aten::rsub" && input_count >= 3)
            {
                Operator* op0 = op->inputs[0]->producer;
                Operator* op1 = op->inputs[1]->producer;
                Operator* op2 = op->inputs[2]->producer;

                if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op2, 1.f, 1))
                {
                    // x <= b * 1 - 0
                    need_eliminate = true;
                    identity_input_id = 1;
                }
                else if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op1, 1.f, 1))
                {
                    // x <= 1 * c - 0
                    // NOTE(review): PyTorch documents rsub as other - alpha * input;
                    // under that reading input 2 is the scale factor rather than
                    // the result, so forwarding it looks suspect - confirm the
                    // intended operand order before relying on this branch.
                    need_eliminate = true;
                    identity_input_id = 2;
                }
            }
            if (op->type == "aten::mul" && input_count >= 2)
            {
                Operator* op0 = op->inputs[0]->producer;
                Operator* op1 = op->inputs[1]->producer;

                if (operator_is_all_constant(op0, 1.f, 1))
                {
                    // x <= 1 * b
                    need_eliminate = true;
                    identity_input_id = 1;
                }
                // deliberately not else-if: when both operands are all-one,
                // input 0 is the one that is kept
                if (operator_is_all_constant(op1, 1.f, 1))
                {
                    // x <= a * 1
                    need_eliminate = true;
                    identity_input_id = 0;
                }
            }
            if ((op->type == "aten::div" || op->type == "aten::div_") && input_count >= 2)
            {
                Operator* op1 = op->inputs[1]->producer;

                if (operator_is_all_constant(op1, 1.f, 1))
                {
                    // x <= a / 1
                    need_eliminate = true;
                    identity_input_id = 0;
                }
            }
            if (op->type == "aten::pow" && input_count >= 2)
            {
                Operator* op1 = op->inputs[1]->producer;

                if (operator_is_all_constant(op1, 1.f, 1))
                {
                    // x <= x ^ 1
                    need_eliminate = true;
                    identity_input_id = 0;
                }
            }

            if (!need_eliminate)
                continue;

            fprintf(stderr, "eliminate_noop_math %s %s\n", op->type.c_str(), op->name.c_str());

            // detach the operator from all of its inputs
            for (auto& x : op->inputs)
            {
                x->remove_consumer(op);
            }

            Operand* math_out = op->outputs[0];
            Operand* kept_input = op->inputs[identity_input_id];

            // every consumer of the removed output now reads the kept input
            for (auto& x : math_out->consumers)
            {
                for (size_t j = 0; j < x->inputs.size(); j++)
                {
                    if (x->inputs[j] == math_out)
                        x->inputs[j] = kept_input;
                }

                kept_input->consumers.push_back(x);
            }

            math_out->producer = 0;
            math_out->consumers.clear();

            graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), math_out));
            delete math_out;

            op->inputs.clear();
            op->outputs.clear();

            graph.ops.erase(graph.ops.begin() + i);
            delete op;

            // graph.ops changed: restart the scan from the outer loop
            break;
        }

        if (!need_eliminate)
            break;
    }

    // dce
    dead_code_elimination(graph);
}

} // namespace pnnx

tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.h → tools/pnnx/src/pass_level3/eliminate_noop_math.h View File

@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
@@ -16,6 +16,6 @@

namespace pnnx {

void fuse_chunk_split_unpack(Graph& graph);
void eliminate_noop_math(Graph& graph);

} // namespace pnnx

+ 0
- 197
tools/pnnx/src/pass_level3/fuse_attribute_expression.cpp View File

@@ -1,197 +0,0 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_attribute_expression.h"
#include <math.h>
#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_attribute_expression(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "pnnx.Attribute")
continue;

if (op->outputs.size() != 1)
continue;

if (op->outputs[0]->consumers.size() != 1)
continue;

Operator* op2 = op->outputs[0]->consumers[0];
Operator* op3 = 0;
Operator* op4 = 0;

float y = 0.f;
float z = 0.f;

if (op2->type == "aten::add" || op2->type == "aten::sub")
{
if (op2->inputs[0] != op->outputs[0])
continue;

op3 = op2->inputs[1]->producer;
if (op3->type != "prim::Constant")
continue;

if (op3->params["value"].type == 2)
{
y = op3->params["value"].i;
}
else if (op3->params["value"].type == 3)
{
y = op3->params["value"].f;
}
else
{
// not a scalar
continue;
}

op4 = op2->inputs[2]->producer;
if (op4->type != "prim::Constant")
continue;

if (op4->params["value"].type == 2)
{
z = op4->params["value"].i;
}
else if (op4->params["value"].type == 3)
{
z = op4->params["value"].f;
}
else
{
// not a scalar
continue;
}
}
else if (op2->type == "aten::mul" || op2->type == "aten::div" || op2->type == "aten::pow")
{
if (op2->inputs[0] != op->outputs[0])
continue;

op3 = op2->inputs[1]->producer;
if (op3->type != "prim::Constant")
continue;

if (op3->params["value"].type == 2)
{
y = op3->params["value"].i;
}
else if (op3->params["value"].type == 3)
{
y = op3->params["value"].f;
}
else
{
// not a scalar
continue;
}
}
else
{
// todo more operator type
continue;
}

matched = true;

// apply mul
{
auto it = op->attrs.begin();
std::string attr_key = it->first;
const Attribute& attr = it->second;

float* weight = (float*)attr.data.data();
const int weight_size = attr.data.size() / sizeof(float);

if (op2->type == "aten::add")
{
for (int i = 0; i < weight_size; i++)
weight[i] += y * z;
}
else if (op2->type == "aten::sub")
{
for (int i = 0; i < weight_size; i++)
weight[i] -= y * z;
}
else if (op2->type == "aten::mul")
{
for (int i = 0; i < weight_size; i++)
weight[i] *= y;
}
else if (op2->type == "aten::div")
{
for (int i = 0; i < weight_size; i++)
weight[i] /= y;
}
else if (op2->type == "aten::pow")
{
for (int i = 0; i < weight_size; i++)
weight[i] = (float)pow(weight[i], y);
}

op->attrs[attr_key] = attr;
}

op2->outputs[0]->producer = op;

for (auto& x : op2->inputs)
{
x->producer = 0;
x->remove_consumer(op2);
}

op->outputs = op2->outputs;

op2->inputs.clear();
op2->outputs.clear();

graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2));

delete op2;

if (op3 && op3->outputs[0]->consumers.empty())
{
graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op3));

delete op3;
}

if (op4 && op4->outputs[0]->consumers.empty())
{
graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op4));

delete op4;
}

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

tools/pnnx/src/pass_level3/fuse_chunk_split_unpack.cpp → tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.cpp View File

@@ -12,13 +12,13 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_chunk_split_unpack.h"
#include "fuse_chunk_split_unbind_unpack.h"
#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void fuse_chunk_split_unpack(Graph& graph)
void fuse_chunk_split_unbind_unpack(Graph& graph)
{
while (1)
{
@@ -28,7 +28,7 @@ void fuse_chunk_split_unpack(Graph& graph)
{
Operator* op = graph.ops[i];

if (op->type != "torch.chunk" && op->type != "torch.split")
if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind")
continue;

if (op->outputs.size() != 1)

tools/pnnx/src/pass_level3/fuse_attribute_expression.h → tools/pnnx/src/pass_level3/fuse_chunk_split_unbind_unpack.h View File

@@ -16,6 +16,6 @@

namespace pnnx {

void fuse_attribute_expression(Graph& graph);
void fuse_chunk_split_unbind_unpack(Graph& graph);

} // namespace pnnx

+ 42
- 4
tools/pnnx/src/pass_level3/fuse_expression.cpp View File

@@ -18,13 +18,13 @@

namespace pnnx {

static bool operand_maybe_tensor(Operand* operand)
static bool operand_maybe_tensor(const Operand* operand)
{
Operator* op = operand->producer;
const Operator* op = operand->producer;

if (op->type == "prim::Constant")
{
const Parameter& param = op->params["value"];
const Parameter& param = op->params.at("value");
if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4)
{
return false;
@@ -83,9 +83,25 @@ static bool operand_maybe_tensor(Operand* operand)
return true;
}

static bool operand_is_foldable(const Operand* operand)
{
const Operator* op = operand->producer;

if (op->type == "pnnx.Input")
return false;

for (auto x : op->inputs)
{
if (!operand_is_foldable(x))
return false;
}

return true;
}

static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, bool checksubgraph = true)
{
// fprintf(stderr, "fuse_expression %s\n", operand->name.c_str());
// fprintf(stderr, "fuse_expression %s\n", operand->name.c_str());

Operator* op = operand->producer;

@@ -164,6 +180,28 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
}
}
}
else if (checksubgraph && operand_maybe_tensor(operand) && operand_is_foldable(operand))
{
// fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str());

auto it = std::find(inputs.begin(), inputs.end(), operand);
if (it == inputs.end())
{
// tensor
char tmp[32];
sprintf(tmp, "@%d", (int)inputs.size());
expr += tmp;

inputs.push_back(operand);
}
else
{
// tensor
char tmp[32];
sprintf(tmp, "@%d", (int)(it - inputs.begin()));
expr += tmp;
}
}
else if (op->type == "prim::NumToTensor")
{
fuse_expression(graph, op->inputs[0], expr, inputs);


+ 1
- 1
tools/pnnx/src/pass_level4.cpp View File

@@ -26,7 +26,7 @@ void pass_level4(Graph& g)

dead_code_elimination(g);

canonicalize(g);
//canonicalize(g);
}

} // namespace pnnx

+ 7
- 1
tools/pnnx/src/pass_level5.cpp View File

@@ -14,6 +14,7 @@

#include "pass_level5.h"

#include "pass_level5/fold_constants.h"
#include "pass_level5/eliminate_dropout.h"
#include "pass_level5/eliminate_slice.h"
#include "pass_level5/eliminate_view_reshape.h"
@@ -29,10 +30,11 @@
#include "pass_level5/fuse_slice_indices.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/canonicalize.h"
#include "pass_level3/fuse_index_expression.h"

namespace pnnx {

void pass_level5(Graph& g)
void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_constants)
{
eval_expression(g);

@@ -60,6 +62,10 @@ void pass_level5(Graph& g)

fuse_channel_shuffle(g);

fold_constants(g, foldable_constants);

fuse_index_expression(g);

dead_code_elimination(g);

canonicalize(g);


+ 1
- 1
tools/pnnx/src/pass_level5.h View File

@@ -19,7 +19,7 @@

namespace pnnx {

void pass_level5(Graph& g);
void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx



+ 7
- 7
tools/pnnx/src/pass_level5/eliminate_dropout.cpp View File

@@ -40,24 +40,24 @@ void eliminate_dropout(Graph& graph)
x->remove_consumer(op);
}

Operand* slice_out = op->outputs[0];
Operand* dropout_out = op->outputs[0];

for (auto& x : slice_out->consumers)
for (auto& x : dropout_out->consumers)
{
for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j] == slice_out)
if (x->inputs[j] == dropout_out)
x->inputs[j] = op->inputs[0];
}

op->inputs[0]->consumers.push_back(x);
}

slice_out->producer = 0;
slice_out->consumers.clear();
dropout_out->producer = 0;
dropout_out->consumers.clear();

graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), slice_out));
delete slice_out;
graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), dropout_out));
delete dropout_out;

op->inputs.clear();
op->outputs.clear();


+ 50
- 0
tools/pnnx/src/pass_level5/fold_constants.cpp View File

@@ -0,0 +1,50 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fold_constants.h"
#include <unordered_set>

#include "pass_level4/dead_code_elimination.h"

namespace pnnx {

// Replace the producers of pre-computed operands with constant attributes.
//
// For every graph operand whose name is a key of foldable_constants and whose
// producer is not already a pnnx.Attribute, a new pnnx.Attribute operator named
// "pnnx_fold_<operand name>" is inserted before the old producer. It stores the
// mapped Attribute under that same name and becomes the operand's producer. The
// old producer has its output list cleared, and the closing dead code
// elimination pass is run over the graph afterwards.
void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldable_constants)
{
    for (size_t idx = 0; idx < graph.operands.size(); idx++)
    {
        Operand* target = graph.operands[idx];

        auto found = foldable_constants.find(target->name);
        if (found == foldable_constants.end())
            continue;

        Operator* old_producer = target->producer;
        if (old_producer->type == "pnnx.Attribute")
            continue;

        const std::string folded_name = std::string("pnnx_fold_") + target->name;

        // replace producer with attribute
        Operator* attr_op = graph.new_operator_before("pnnx.Attribute", folded_name, old_producer);

        attr_op->attrs[folded_name] = found->second;
        attr_op->outputs.push_back(target);
        target->producer = attr_op;

        // detach the old producer from everything it produced
        old_producer->outputs.clear();
    }

    // dce
    dead_code_elimination(graph);
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/fold_constants.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

// Replace the producer of each operand named in foldable_constants with a
// pnnx.Attribute operator holding the mapped Attribute, then run dead code
// elimination on the graph. Operands already produced by a pnnx.Attribute are
// left untouched.
void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldable_constants);

} // namespace pnnx

+ 3
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -181,6 +181,7 @@ pnnx_add_test(torch_split)
pnnx_add_test(torch_squeeze)
pnnx_add_test(torch_stack)
pnnx_add_test(torch_transpose)
pnnx_add_test(torch_unbind)
pnnx_add_test(torch_unsqueeze)

pnnx_add_test(mobilenet_v2)
@@ -192,6 +193,8 @@ pnnx_add_test(squeezenet1_1)
# TODO enable end2end quantization model test
#pnnx_add_test(quantization_shufflenet_v2_x1_0)

pnnx_add_test(pnnx_eliminate_noop_math)
pnnx_add_test(pnnx_fold_constant)
pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d)


+ 4
- 4
tools/pnnx/tests/test_Tensor_index.py View File

@@ -39,15 +39,15 @@ def test():

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_Tensor_slice.pt")
mod.save("test_Tensor_index.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_slice.pt inputshape=[3,6],[5,9,2],[2,4,5,10]")
os.system("../src/pnnx test_Tensor_index.pt inputshape=[3,6],[5,9,2],[2,4,5,10]")

# pnnx inference
import test_Tensor_slice_pnnx
b = test_Tensor_slice_pnnx.test_inference()
import test_Tensor_index_pnnx
b = test_Tensor_index_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):


+ 62
- 0
tools/pnnx/tests/test_pnnx_eliminate_noop_math.py View File

@@ -0,0 +1,62 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Module whose forward pass applies only identity arithmetic to its input."""

    def __init__(self):
        super().__init__()

        # w0 is all zeros, w1 and w2 are all ones
        self.w0 = nn.Parameter(torch.zeros(1, 12, 52))
        self.w1 = nn.Parameter(torch.ones(1, 12, 52))
        self.w2 = nn.Parameter(torch.ones(1, 12, 52))

    def forward(self, x):
        # the expressions are kept exactly as written: each one is a pattern
        # that gets recorded when the module is traced
        x = x + 0
        x = x * 1 / 1
        x = 0 + 1 * x
        # adds zeros * ones, then multiplies by ones
        x = x + self.w0 * self.w1
        x = x * self.w2
        return x

def test():
    """Trace Model, convert it with pnnx and compare both inference results.

    Returns True when the output of the generated pnnx module is exactly equal
    to the output of the original module, False otherwise (including when the
    pnnx command reports a failure).
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 12, 52)

    a = net(x)

    # export torchscript
    mod = torch.jit.trace(net, x)
    mod.save("test_pnnx_eliminate_noop_math.pt")

    # torchscript to pnnx
    import os
    status = os.system("../src/pnnx test_pnnx_eliminate_noop_math.pt inputshape=[1,12,52]")
    if status != 0:
        # conversion failed: do not fall through to a generated module that
        # may be left over from an earlier run
        return False

    # pnnx inference
    import test_pnnx_eliminate_noop_math_pnnx
    b = test_pnnx_eliminate_noop_math_pnnx.test_inference()

    return torch.equal(a, b)

if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise
    exit(0 if test() else 1)

+ 60
- 0
tools/pnnx/tests/test_pnnx_fold_constant.py View File

@@ -0,0 +1,60 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Module that combines its input with an expression built only from parameters."""

    def __init__(self):
        super().__init__()

        # creation order is kept so the random initial values stay the same
        self.w0 = nn.Parameter(torch.rand(1, 12, 52))
        self.w1 = nn.Parameter(torch.rand(1, 12, 52))
        # w2 has a trailing dimension of 1 and is broadcast in forward
        self.w2 = nn.Parameter(torch.rand(1, 12, 1))
        self.w3 = nn.Parameter(torch.rand(1, 12, 52))

    def forward(self, x):
        # b depends on parameters and literals only, never on x
        b = (self.w0 + self.w1 + 0.22) + self.w2 * 0.1
        x = x + b - self.w3 / 2
        return x

def test():
    """Trace Model, convert it with pnnx and compare both inference results.

    Returns True when the output of the generated pnnx module is exactly equal
    to the output of the original module, False otherwise (including when the
    pnnx command reports a failure).
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 12, 52)

    a = net(x)

    # export torchscript
    mod = torch.jit.trace(net, x)
    mod.save("test_pnnx_fold_constant.pt")

    # torchscript to pnnx
    import os
    status = os.system("../src/pnnx test_pnnx_fold_constant.pt inputshape=[1,12,52]")
    if status != 0:
        # conversion failed: do not fall through to a generated module that
        # may be left over from an earlier run
        return False

    # pnnx inference
    import test_pnnx_fold_constant_pnnx
    b = test_pnnx_fold_constant_pnnx.test_inference()

    return torch.equal(a, b)

if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise
    exit(0 if test() else 1)

+ 61
- 0
tools/pnnx/tests/test_torch_unbind.py View File

@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Module that splits three inputs with torch.unbind along different dims."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z):
        # x is unbound along dim 1, y along dim 2, z along dim 0
        x0, x1, x2 = torch.unbind(x, dim=1)
        y0, y1, y2, y3, y4, y5, y6, y7, y8 = torch.unbind(y, dim=2)
        z0, z1, z2, z3 = torch.unbind(z, dim=0)
        # NOTE(review): x2 is unpacked but not returned - confirm this is intentional
        return x0, x1, y0, y1, y2, y3, y4, y5, y6, y7, y8, z0, z1, z2, z3

def test():
    """Trace Model, convert it with pnnx and compare both inference results.

    Returns True when the generated pnnx module yields the same number of
    outputs as the original module and every pair is exactly equal, False
    otherwise (including when the pnnx command reports a failure).
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 3, 16)
    y = torch.rand(1, 5, 9, 11)
    z = torch.rand(4, 8, 5, 9, 10)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_unbind.pt")

    # torchscript to pnnx
    import os
    status = os.system("../src/pnnx test_torch_unbind.pt inputshape=[1,3,16],[1,5,9,11],[4,8,5,9,10]")
    if status != 0:
        # conversion failed: do not fall through to a generated module that
        # may be left over from an earlier run
        return False

    # pnnx inference
    import test_torch_unbind_pnnx
    b = test_torch_unbind_pnnx.test_inference()

    # zip() stops at the shorter sequence, so a missing output would otherwise
    # go unnoticed
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True

if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise
    exit(0 if test() else 1)

Loading…
Cancel
Save