Browse Source

pnnx update (#4870)

Tensor.fill
Tensor.index_put
Tensor.to
Tensor.type_as
torch.topk
fmod
call Tensor member functions with inputnames
static shape_as_tensor
nn.Linear dynamic bias
eliminate noop type_as
convert two-dim nn.Linear to ncnn gemm
convert torch.stack to ncnn concat+reshape
ignore torch einsum path input
tags/20230816
nihui GitHub 2 years ago
parent
commit
669ee2f2ff
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1278 additions and 21 deletions
  1. +7
    -0
      tools/pnnx/src/CMakeLists.txt
  2. +35
    -4
      tools/pnnx/src/ir.cpp
  3. +2
    -1
      tools/pnnx/src/pass_level0/shape_inference.cpp
  4. +2
    -2
      tools/pnnx/src/pass_level1/nn_Linear.cpp
  5. +41
    -0
      tools/pnnx/src/pass_level2/Tensor_fill.cpp
  6. +43
    -0
      tools/pnnx/src/pass_level2/Tensor_index_put.cpp
  7. +1
    -1
      tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp
  8. +89
    -0
      tools/pnnx/src/pass_level2/Tensor_to.cpp
  9. +41
    -0
      tools/pnnx/src/pass_level2/Tensor_type_as.cpp
  10. +8
    -1
      tools/pnnx/src/pass_level2/torch_einsum.cpp
  11. +44
    -0
      tools/pnnx/src/pass_level2/torch_topk.cpp
  12. +60
    -3
      tools/pnnx/src/pass_level3/fuse_expression.cpp
  13. +2
    -0
      tools/pnnx/src/pass_level5.cpp
  14. +84
    -0
      tools/pnnx/src/pass_level5/eliminate_type_as.cpp
  15. +21
    -0
      tools/pnnx/src/pass_level5/eliminate_type_as.h
  16. +6
    -0
      tools/pnnx/src/pass_level5/eval_expression.cpp
  17. +6
    -0
      tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp
  18. +2
    -0
      tools/pnnx/src/pass_ncnn.cpp
  19. +91
    -0
      tools/pnnx/src/pass_ncnn/convert_torch_stack.cpp
  20. +25
    -0
      tools/pnnx/src/pass_ncnn/convert_torch_stack.h
  21. +7
    -1
      tools/pnnx/src/pass_ncnn/expand_expression.cpp
  22. +14
    -7
      tools/pnnx/src/pass_ncnn/insert_reshape_linear.cpp
  23. +192
    -0
      tools/pnnx/src/pass_ncnn/nn_Linear.cpp
  24. +6
    -0
      tools/pnnx/tests/CMakeLists.txt
  25. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  26. +60
    -0
      tools/pnnx/tests/ncnn/test_torch_stack.py
  27. +57
    -0
      tools/pnnx/tests/test_Tensor_fill.py
  28. +63
    -0
      tools/pnnx/tests/test_Tensor_index_put.py
  29. +63
    -0
      tools/pnnx/tests/test_Tensor_to.py
  30. +65
    -0
      tools/pnnx/tests/test_Tensor_type_as.py
  31. +75
    -0
      tools/pnnx/tests/test_pnnx_expression.py
  32. +4
    -1
      tools/pnnx/tests/test_torch_einsum.py
  33. +61
    -0
      tools/pnnx/tests/test_torch_topk.py

+ 7
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -180,7 +180,9 @@ set(pnnx_pass_level2_SRCS
pass_level2/Tensor_copy.cpp
pass_level2/Tensor_expand.cpp
pass_level2/Tensor_expand_as.cpp
pass_level2/Tensor_fill.cpp
pass_level2/Tensor_index.cpp
pass_level2/Tensor_index_put.cpp
pass_level2/Tensor_masked_fill.cpp
pass_level2/Tensor_new_empty.cpp
pass_level2/Tensor_new_ones.cpp
@@ -189,6 +191,8 @@ set(pnnx_pass_level2_SRCS
pass_level2/Tensor_reshape.cpp
pass_level2/Tensor_select.cpp
pass_level2/Tensor_slice.cpp
pass_level2/Tensor_to.cpp
pass_level2/Tensor_type_as.cpp
pass_level2/Tensor_view.cpp
pass_level2/torch_addmm.cpp
pass_level2/torch_amax.cpp
@@ -252,6 +256,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_sum.cpp
pass_level2/torch_permute.cpp
pass_level2/torch_tensor_split.cpp
pass_level2/torch_topk.cpp
pass_level2/torch_transpose.cpp
pass_level2/torch_unbind.cpp
pass_level2/torch_unsqueeze.cpp
@@ -320,6 +325,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_noop_slice.cpp
pass_level5/eliminate_noop_view_reshape.cpp
pass_level5/eliminate_reshape_shape_expression.cpp
pass_level5/eliminate_type_as.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_adjacent_reshape.cpp
@@ -361,6 +367,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/convert_torch_chunk.cpp
pass_ncnn/convert_torch_einsum.cpp
pass_ncnn/convert_torch_split.cpp
pass_ncnn/convert_torch_stack.cpp
pass_ncnn/convert_torch_tensor_split.cpp
pass_ncnn/convert_torch_unbind.cpp
pass_ncnn/convert_Tensor_select.cpp


+ 35
- 4
tools/pnnx/src/ir.cpp View File

@@ -1297,10 +1297,12 @@ static std::string expand_expression(const Operator* op)
exprstack.push(r);
}
else if (t == "atan2"
|| t == "fmod"
|| t == "pow")
{
std::string binaryop;
if (t == "atan2") binaryop = "torch.atan2";
if (t == "fmod") binaryop = "torch.fmod";
if (t == "pow") binaryop = "torch.pow";

std::string a = exprstack.top();
@@ -1311,7 +1313,7 @@ static std::string expand_expression(const Operator* op)
std::string r = binaryop + "(" + a + ", " + b + ")";
exprstack.push(r);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "remainder" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
{
std::string binaryop;
if (t == "add") binaryop = "+";
@@ -1319,6 +1321,7 @@ static std::string expand_expression(const Operator* op)
if (t == "mul") binaryop = "*";
if (t == "div") binaryop = "/";
if (t == "floor_divide") binaryop = "//";
if (t == "remainder") binaryop = "%";
if (t == "and") binaryop = "&";
if (t == "or") binaryop = "|";
if (t == "xor") binaryop = "^";
@@ -2152,11 +2155,39 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

if (op->type.substr(0, 7) == "Tensor.")
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
if (op->type == "Tensor.fill")
{
fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str());
}
else
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
}

if (op->inputnames.size() == op->inputs.size())
{
for (size_t i = 1; i < op->inputs.size(); i++)
{
if (!op->inputnames[i].empty())
continue;

for (size_t i = 1; i < op->inputs.size(); i++)
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}

for (size_t i = 1; i < op->inputs.size(); i++)
{
if (op->inputnames[i].empty())
continue;

fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
}
}
else
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
for (size_t i = 1; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}
}
}
else


+ 2
- 1
tools/pnnx/src/pass_level0/shape_inference.cpp View File

@@ -39,7 +39,8 @@ static bool value_link_input(const torch::jit::Value* v, const std::vector<torch
|| optype == "aten::empty_like"
|| optype == "aten::full_like"
|| optype == "aten::ones_like"
|| optype == "aten::zeros_like")
|| optype == "aten::zeros_like"
|| optype == "aten::_shape_as_tensor")
return false;
}



+ 2
- 2
tools/pnnx/src/pass_level1/nn_Linear.cpp View File

@@ -39,10 +39,10 @@ public:

op->params["in_features"] = weight.size(1);
op->params["out_features"] = weight.size(0);
op->params["bias"] = mod.hasattr("bias");
op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor();

op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
if (mod.hasattr("bias") && mod.attr("bias").isTensor())
{
op->attrs["bias"] = mod.attr("bias").toTensor();
}


+ 41
- 0
tools/pnnx/src/pass_level2/Tensor_fill.cpp View File

@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_fill : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 value
aten::fill op_0 2 1 input value out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.fill";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_fill, 20)

} // namespace pnnx

+ 43
- 0
tools/pnnx/src/pass_level2/Tensor_index_put.cpp View File

@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_index_put : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 indices
pnnx.Input input_2 0 1 values
prim::Constant op_0 0 1 accumulate value=%accumulate
aten::index_put op_1 4 1 input indices values accumulate out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.index_put";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_index_put, 20)

} // namespace pnnx

+ 1
- 1
tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp View File

@@ -26,7 +26,7 @@ public:
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mask
pnnx.Input input_2 0 1 value
aten::masked_fill op_1 3 1 input mask value out
aten::masked_fill op_0 3 1 input mask value out
pnnx.Output output 1 0 out
)PNNXIR";
}


+ 89
- 0
tools/pnnx/src/pass_level2/Tensor_to.cpp View File

@@ -0,0 +1,89 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_to : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 dtype value=%dtype
prim::Constant op_1 0 1 non_blocking value=*
prim::Constant op_2 0 1 copy value=%copy
prim::Constant op_3 0 1 memory_format value=%memory_format
aten::to op_4 5 1 input dtype non_blocking copy memory_format out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.to";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";

op->params["copy"] = captured_params.at("copy");

if (captured_params.at("memory_format").i == 0)
op->params["memory_format"] = "torch.contiguous_format";
if (captured_params.at("memory_format").i == 1)
op->params["memory_format"] = "torch.preserve_format";
if (captured_params.at("memory_format").i == 2)
op->params["memory_format"] = "torch.channels_last";
}
};

class Tensor_to_1 : public Tensor_to
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 device value=*
prim::Constant op_1 0 1 dtype value=%dtype
prim::Constant op_2 0 1 non_blocking value=*
prim::Constant op_3 0 1 copy value=%copy
prim::Constant op_4 0 1 memory_format value=%memory_format
aten::to op_5 6 1 input device dtype non_blocking copy memory_format out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20)

} // namespace pnnx

+ 41
- 0
tools/pnnx/src/pass_level2/Tensor_type_as.cpp View File

@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_type_as : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 other
aten::type_as op_0 2 1 input other out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.type_as";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_type_as, 20)

} // namespace pnnx

+ 8
- 1
tools/pnnx/src/pass_level2/torch_einsum.cpp View File

@@ -47,7 +47,7 @@ public:
5 4
pnnx.Input input_0 0 1 equation
pnnx.Input input_1 0 1 operands
prim::Constant op_0 0 1 path value=None
pnnx.Input input_2 0 1 path
aten::einsum op_1 3 1 equation operands path out
pnnx.Output output 1 0 out
)PNNXIR";
@@ -57,6 +57,13 @@ pnnx.Output output 1 0 out
{
return "torch.einsum";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
// drop path input
op->inputs[2]->remove_consumer(op);
op->inputs.resize(2);
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_einsum_1, 20)


+ 44
- 0
tools/pnnx/src/pass_level2/torch_topk.cpp View File

@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_topk : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 7
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 k
pnnx.Input input_2 0 1 dim
pnnx.Input input_3 0 1 largest
pnnx.Input input_4 0 1 sorted
aten::topk op_0 5 2 input k dim largest sorted values indices
pnnx.Output output 2 0 values indices
)PNNXIR";
}

const char* type_str() const
{
return "torch.topk";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_topk, 20)

} // namespace pnnx

+ 60
- 3
tools/pnnx/src/pass_level3/fuse_expression.cpp View File

@@ -100,6 +100,7 @@ static bool operand_maybe_tensor(const Operand* operand)
if (op->type == "aten::atan2"
|| op->type == "aten::div"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::mul"
|| op->type == "aten::pow"
|| op->type == "aten::remainder")
@@ -363,7 +364,35 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
expr += ")";
}
else if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
else if (op->type == "Tensor.to")
{
bool noop_type_cast = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type);
if (noop_type_cast)
{
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
}
else
{
auto it = std::find(inputs.begin(), inputs.end(), operand);
if (it == inputs.end())
{
// tensor
char tmp[32];
sprintf(tmp, "@%d", (int)inputs.size());
expr += tmp;

inputs.push_back(operand);
}
else
{
// tensor
char tmp[32];
sprintf(tmp, "@%d", (int)(it - inputs.begin()));
expr += tmp;
}
}
}
else if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
{
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
}
@@ -402,8 +431,8 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
expr += ")";
}
else if (op->type == "aten::atan2"
|| op->type == "aten::div"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::mul"
|| op->type == "aten::pow"
|| op->type == "aten::remainder")
@@ -484,6 +513,27 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
expr += ")";
}
else if (op->type == "aten::div")
{
std::string rounding_mode;
if (op->inputs.size() == 3)
fuse_expression(graph, op->inputs[2], rounding_mode, inputs, foldable_constants, zip);

if (rounding_mode == "trunc")
{
expr += "floor_divide";
}
else
{
expr += "div";
}

expr += "(";
fuse_expression(graph, op->inputs[0], expr, inputs, foldable_constants, zip);
expr += ",";
fuse_expression(graph, op->inputs[1], expr, inputs, foldable_constants, zip);
expr += ")";
}
else
{
auto it = std::find(inputs.begin(), inputs.end(), operand);
@@ -542,7 +592,13 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
{
need_fuse = true;
}
if (op->type == "aten::to" || op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
if (op->type == "Tensor.to")
{
// fuse noop type cast only
bool noop_to = (op->outputs[0]->type != -1) && (op->inputs[0]->type == op->outputs[0]->type);
need_fuse = noop_to;
}
if (op->type == "aten::detach" || op->type == "aten::ScalarImplicit")
{
need_fuse = true;
}
@@ -562,6 +618,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::floor_divide"
|| op->type == "aten::fmod"
|| op->type == "aten::log"
|| op->type == "aten::log10"
|| op->type == "aten::mul"


+ 2
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -27,6 +27,7 @@
#include "pass_level5/eliminate_noop_slice.h"
#include "pass_level5/eliminate_noop_view_reshape.h"
#include "pass_level5/eliminate_reshape_shape_expression.h"
#include "pass_level5/eliminate_type_as.h"
#include "pass_level5/eval_expression.h"
#include "pass_level5/fuse_adjacent_reshape.h"
#include "pass_level5/fuse_channel_shuffle.h"
@@ -112,6 +113,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
eliminate_noop_cat(g);

eliminate_dropout(g);
eliminate_type_as(g);

eliminate_noop_upsample(g);



+ 84
- 0
tools/pnnx/src/pass_level5/eliminate_type_as.cpp View File

@@ -0,0 +1,84 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "eliminate_type_as.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void eliminate_type_as(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "Tensor.type_as")
continue;

if (op->inputs[0]->type == 0 || op->outputs[0]->type == 0)
continue;

if (op->inputs[0]->type != op->outputs[0]->type)
continue;

// delete noop-like type_as
matched = true;

for (auto& x : op->inputs)
{
x->remove_consumer(op);
}

Operand* type_as_out = op->outputs[0];

for (auto& x : type_as_out->consumers)
{
for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j] == type_as_out)
x->inputs[j] = op->inputs[0];
}

op->inputs[0]->consumers.push_back(x);
}

op->inputs[0]->name = type_as_out->name;

type_as_out->producer = 0;
type_as_out->consumers.clear();

graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), type_as_out));
delete type_as_out;

op->inputs.clear();
op->outputs.clear();

graph.ops.erase(graph.ops.begin() + i);
delete op;

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/eliminate_type_as.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void eliminate_type_as(Graph& graph);

} // namespace pnnx

+ 6
- 0
tools/pnnx/src/pass_level5/eval_expression.cpp View File

@@ -342,6 +342,7 @@ static std::string eval_expression(const Operator* op)
|| t == "mul"
|| t == "div"
|| t == "floor_divide"
|| t == "fmod"
|| t == "pow"
|| t == "remainder")
{
@@ -380,6 +381,11 @@ static std::string eval_expression(const Operator* op)
float r = af / bf;
exprstack.push(std::to_string(r));
}
if (t == "fmod")
{
float r = fmod(af, bf);
exprstack.push(std::to_string(r));
}
if (t == "floor_divide")
{
int r = (int)af / (int)bf;


+ 6
- 0
tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp View File

@@ -37,6 +37,12 @@ void fuse_select_to_unbind(Graph& graph)
if (input_rank == 0)
continue;

if (input_rank == 1)
{
// skip select scalar
continue;
}

int dim = op->params.at("dim").i;
const int select_dimsize = op_in->shape[dim];



+ 2
- 0
tools/pnnx/src/pass_ncnn.cpp View File

@@ -22,6 +22,7 @@
#include "pass_ncnn/convert_torch_chunk.h"
#include "pass_ncnn/convert_torch_einsum.h"
#include "pass_ncnn/convert_torch_split.h"
#include "pass_ncnn/convert_torch_stack.h"
#include "pass_ncnn/convert_torch_tensor_split.h"
#include "pass_ncnn/convert_torch_unbind.h"
#include "pass_ncnn/convert_Tensor_select.h"
@@ -96,6 +97,7 @@ void pass_ncnn(Graph& g)

ncnn::convert_torch_cat(g);
ncnn::convert_torch_chunk(g);
ncnn::convert_torch_stack(g);
ncnn::convert_torch_split(g);
ncnn::convert_torch_unbind(g);
ncnn::convert_torch_tensor_split(g);


+ 91
- 0
tools/pnnx/src/pass_ncnn/convert_torch_stack.cpp View File

@@ -0,0 +1,91 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "convert_torch_stack.h"

namespace pnnx {

namespace ncnn {

void convert_torch_stack(Graph& graph)
{
int op_index = 0;

while (1)
{
bool matched = false;

for (Operator* op : graph.ops)
{
if (op->type != "torch.stack")
continue;

matched = true;

op->type = "Concat";
op->name = std::string("stack_") + std::to_string(op_index++);

const int batch_index = op->inputs[0]->params["__batch_index"].i;

int axis = op->params.at("dim").i;
if (axis == batch_index)
{
fprintf(stderr, "stack along batch axis %d is not supported\n", batch_index);
continue;
}

if (axis < 0)
{
int input_rank = op->inputs[0]->shape.size();
axis = input_rank + axis;
}

if (axis > batch_index)
axis -= 1;

op->params["0"] = axis;

op->params.erase("dim");

// reshape for output, expand the stack dim
{
Operand* out = op->outputs[0];

Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op);

Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in");

reshape->inputs.push_back(reshape_in);
reshape->outputs.push_back(out);

op->outputs[0] = reshape_in;

out->producer = reshape;
reshape_in->producer = op;
reshape_in->consumers.push_back(reshape);

reshape->params["shape"] = out->shape;
}

break;
}

if (!matched)
break;
}
}

} // namespace ncnn

} // namespace pnnx

+ 25
- 0
tools/pnnx/src/pass_ncnn/convert_torch_stack.h View File

@@ -0,0 +1,25 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

void convert_torch_stack(Graph& graph);

} // namespace ncnn

} // namespace pnnx

+ 7
- 1
tools/pnnx/src/pass_ncnn/expand_expression.cpp View File

@@ -178,7 +178,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
op_unary->inputs.push_back(op_unary_in);
op_unary->outputs.push_back(op_unary_out);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || /*t == "floor_divide" || */ t == "pow" || t == "atan2")
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "fmod" || t == "remainder" || t == "pow" || t == "atan2")
{
std::string a = exprstack.top();
exprstack.pop();
@@ -190,10 +190,16 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx

Operator* op_binary = graph.new_operator_before("BinaryOp", t + "_" + std::to_string(pnnx_expr_index++), op);

// default todo type mark :[
op_binary->params["0"] = -1;

if (t == "add") op_binary->params["0"] = 0;
if (t == "sub") op_binary->params["0"] = 1;
if (t == "mul") op_binary->params["0"] = 2;
if (t == "div") op_binary->params["0"] = 3;
if (t == "floor_divide") fprintf(stderr, "BinaryOp floor_divide not supported yet\n"); // TODO
if (t == "fmod") fprintf(stderr, "BinaryOp fmod not supported yet\n"); // TODO
if (t == "remainder") fprintf(stderr, "BinaryOp remainder not supported yet\n"); // TODO
if (t == "pow") op_binary->params["0"] = 6;
if (t == "atan2") op_binary->params["0"] = 10;



+ 14
- 7
tools/pnnx/src/pass_ncnn/insert_reshape_linear.cpp View File

@@ -94,19 +94,26 @@ void insert_reshape_linear(Graph& graph)
reshape_h *= linear_in->shape[j];
}

std::vector<int> reshape0_shape;
std::vector<int> reshape0_out_shape;
std::vector<int> reshape1_in_shape;
if (batch_index == 0 && batch_index != 233)
{
reshape0_shape = {1, reshape_h, linear_in->shape[input_rank - 1]};
reshape0_out_shape = {1, reshape_h, linear_in->shape[input_rank - 1]};
reshape1_in_shape = {1, reshape_h, linear_out->shape[input_rank - 1]};
}
else
{
reshape0_shape = {reshape_h, linear_in->shape[input_rank - 1]};
reshape0_out_shape = {reshape_h, linear_in->shape[input_rank - 1]};
reshape1_in_shape = {reshape_h, linear_out->shape[input_rank - 1]};
}
std::vector<int> reshape1_shape = linear_out->shape;

reshape0->params["shape"] = reshape0_shape;
reshape1->params["shape"] = reshape1_shape;
std::vector<int> reshape1_out_shape = linear_out->shape;

reshape0->params["shape"] = reshape0_out_shape;
reshape1->params["shape"] = reshape1_out_shape;
reshape0_out->type = linear_in->type;
reshape0_out->shape = reshape0_out_shape;
reshape1_in->type = linear_out->type;
reshape1_in->shape = reshape1_in_shape;

break;
}


+ 192
- 0
tools/pnnx/src/pass_ncnn/nn_Linear.cpp View File

@@ -18,6 +18,152 @@ namespace pnnx {

namespace ncnn {

class nn_Linear_0 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input #input=(1,%m,%in_features)f32
nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Gemm";
}

const char* name_str() const
{
return "gemm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["2"] = 0;
op->params["3"] = 1;
op->params["4"] = 0;
op->params["5"] = 1;
op->params["6"] = 1;
op->params["7"] = captured_params.at("m");
op->params["8"] = captured_params.at("out_features");
op->params["9"] = captured_params.at("in_features");
op->params["10"] = captured_params.at("bias").b ? 4 : -1;

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = captured_attrs.at("op_0.weight");
if (captured_params.at("bias").b)
{
op->attrs["2"] = Attribute();
op->attrs["2"].data = {0, 0, 0, 0};
op->attrs["3"] = captured_attrs.at("op_0.bias");
}
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_0, 19)

class nn_Linear_01 : public nn_Linear_0
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input #input=(%m,%in_features)f32
nn.Linear op_0 1 1 input out in_features=%in_features out_features=%out_features bias=%bias
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
const int m = captured_params.at("m").i;

if (m == 1)
return false;

return true;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_01, 19)

class nn_Linear_10 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input #input=(1,%m,%in_features)f32
pnnx.Input input_1 0 1 bias
nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Gemm";
}

const char* name_str() const
{
return "gemm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["2"] = 0;
op->params["3"] = 1;
op->params["4"] = 0;
op->params["5"] = 1;
op->params["6"] = 0;
op->params["7"] = captured_params.at("m");
op->params["8"] = captured_params.at("out_features");
op->params["9"] = captured_params.at("in_features");
op->params["10"] = 4;

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = captured_attrs.at("op_0.weight");
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_10, 19)

class nn_Linear_11 : public nn_Linear_10
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input #input=(%m,%in_features)f32
pnnx.Input input_1 0 1 bias
nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
const int m = captured_params.at("m").i;

if (m == 1)
return false;

return true;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_11, 19)

class nn_Linear : public GraphRewriterPass
{
public:
@@ -57,6 +203,52 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear, 20)

class nn_Linear_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 bias
nn.Linear op_0 2 1 input bias out in_features=%in_features out_features=%out_features bias=False
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 bias
InnerProduct linear 1 1 input a
BinaryOp bias 2 1 a bias out 0=0
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(ops, captured_params, captured_attrs);

const int batch_index = ops.at("linear")->inputs[0]->params["__batch_index"].i;

ops.at("linear")->params["0"] = captured_params.at("out_features");
ops.at("linear")->params["1"] = 0;
ops.at("linear")->params["2"] = captured_attrs.at("op_0.weight").elemcount();

ops.at("linear")->attrs["0"] = Attribute();
ops.at("linear")->attrs["0"].data = {0, 0, 0, 0};
ops.at("linear")->attrs["1"] = captured_attrs.at("op_0.weight");

ops.at("linear")->outputs[0]->params["__batch_index"] = batch_index;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Linear_1, 20)

} // namespace ncnn

} // namespace pnnx

+ 6
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -162,7 +162,9 @@ pnnx_add_test(nn_ZeroPad2d)

pnnx_add_test(Tensor_contiguous)
pnnx_add_test(Tensor_expand)
pnnx_add_test(Tensor_fill)
pnnx_add_test(Tensor_index)
pnnx_add_test(Tensor_index_put)
pnnx_add_test(Tensor_masked_fill)
pnnx_add_test(Tensor_new_empty)
pnnx_add_test(Tensor_new_full)
@@ -173,6 +175,8 @@ pnnx_add_test(Tensor_reshape)
pnnx_add_test(Tensor_select)
pnnx_add_test(Tensor_slice)
pnnx_add_test(Tensor_slice_copy)
pnnx_add_test(Tensor_to)
pnnx_add_test(Tensor_type_as)
pnnx_add_test(Tensor_view)

pnnx_add_test(torch_addmm)
@@ -221,6 +225,7 @@ pnnx_add_test(torch_squeeze)
pnnx_add_test(torch_stack)
pnnx_add_test(torch_std)
pnnx_add_test(torch_tensor_split)
pnnx_add_test(torch_topk)
pnnx_add_test(torch_transpose)
pnnx_add_test(torch_unbind)
pnnx_add_test(torch_unsqueeze)
@@ -295,6 +300,7 @@ pnnx_add_test(pnnx_eliminate_noop_cat)
pnnx_add_test(pnnx_eliminate_noop_expand)
pnnx_add_test(pnnx_eliminate_noop_math)
pnnx_add_test(pnnx_eliminate_noop_upsample)
pnnx_add_test(pnnx_expression)
pnnx_add_test(pnnx_fold_constant)
pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d)


+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -154,6 +154,7 @@ pnnx_ncnn_add_test(torch_permute)
pnnx_ncnn_add_test(torch_prod)
pnnx_ncnn_add_test(torch_sum)
pnnx_ncnn_add_test(torch_squeeze)
pnnx_ncnn_add_test(torch_stack)
pnnx_ncnn_add_test(torch_tensor_split)
pnnx_ncnn_add_test(torch_transpose)
pnnx_ncnn_add_test(torch_unbind)


+ 60
- 0
tools/pnnx/tests/ncnn/test_torch_stack.py View File

@@ -0,0 +1,60 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
out0 = torch.stack((x, y), dim=0)
out1 = torch.stack((z, w), dim=2)
out0.relu_()
out1.relu_()
return out0, out1

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 16)
y = torch.rand(3, 16)
z = torch.rand(5, 9, 3)
w = torch.rand(5, 9, 3)

a0, a1 = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_torch_stack.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_stack.pt inputshape=[3,16],[3,16],[5,9,3],[5,9,3]")

# ncnn inference
import test_torch_stack_ncnn
b0, b1 = test_torch_stack_ncnn.test_inference()

return torch.equal(a0, b0) and torch.equal(a1, b1)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 57
- 0
tools/pnnx/tests/test_Tensor_fill.py View File

@@ -0,0 +1,57 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x[:2,:].fill_(z[0])
y[:1,:].fill_(0.22)
return x + y.fill_(7)

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(6, 16)
y = torch.rand(6, 16)
z = torch.rand(1)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_Tensor_fill.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_fill.pt inputshape=[6,16],[6,16],[1]")

# pnnx inference
import test_Tensor_fill_pnnx
b = test_Tensor_fill_pnnx.test_inference()

return torch.equal(a, b)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 63
- 0
tools/pnnx/tests/test_Tensor_index_put.py View File

@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z, w):
x = x.clone()
z = z.clone()
x = x.index_put(indices=[torch.tensor([10,2])], values=y, accumulate=False)
z.index_put_(indices=[torch.tensor([1,0,0]), torch.tensor([3,2,1])], values=w, accumulate=True)
return x, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(12)
y = torch.rand(2)
z = torch.rand(6,9)
w = torch.rand(3)

a = net(x, y, z, w)

# export torchscript
mod = torch.jit.trace(net, (x, y, z, w))
mod.save("test_Tensor_index_put.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_index_put.pt inputshape=[12],[2],[6,9],[3]")

# pnnx inference
import test_Tensor_index_put_pnnx
b = test_Tensor_index_put_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 63
- 0
tools/pnnx/tests/test_Tensor_to.py View File

@@ -0,0 +1,63 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
x = x * 10
y = y * 13
y = y.to(dtype=x.dtype, memory_format=torch.contiguous_format)
x = x.to(device='cpu', dtype=torch.int, copy=True)
x = x + 1
y = y - 2
return x, y

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 16)
y = torch.randint(10, (1, 13), dtype=torch.int)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_Tensor_to.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_to.pt inputshape=[3,16],[1,13]i32")

# pnnx inference
import test_Tensor_to_pnnx
b = test_Tensor_to_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 65
- 0
tools/pnnx/tests/test_Tensor_type_as.py View File

@@ -0,0 +1,65 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = x * 100
z = z * 200
x = x.type_as(y)
x = F.relu(x)
x = x.type_as(z)
z = F.relu(z)
z = z.type_as(x)
return x, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 16)
y = torch.randint(10, (1, 13), dtype=torch.int)
z = torch.rand(8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_Tensor_type_as.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_type_as.pt inputshape=[3,16],[1,13]i32,[8,5,9,10]")

# pnnx inference
import test_Tensor_type_as_pnnx
b = test_Tensor_type_as_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 75
- 0
tools/pnnx/tests/test_pnnx_expression.py View File

@@ -0,0 +1,75 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.w0 = nn.Parameter(torch.rand(12, 15))
self.w1 = nn.Parameter(torch.rand(12, 15))
self.w2 = nn.Parameter(torch.rand(12, 15))
self.w3 = nn.Parameter(torch.rand(12, 15))
self.w4 = nn.Parameter(torch.rand(12, 15))
self.w5 = nn.Parameter(torch.rand(12, 15))

def forward(self, x):
x0 = x * 10
x = x + self.w0 + x0
x = x - self.w1 + x0.float()
x = x * self.w2 + x0
x = x / self.w3 + x0
x = x // self.w4 + x0
if version.parse(torch.__version__) >= version.parse('2.0'):
x = x % self.w5 + x0
else:
x = torch.fmod(x, self.w5) + x0
y = x.int()
return x, y & 3, y | 3, y ^ 3, y << 3, y >> 3

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(12, 15)

a = net(x)

# export torchscript
mod = torch.jit.trace(net, x)
mod.save("test_pnnx_expression.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_pnnx_expression.pt inputshape=[12,15]")

# pnnx inference
import test_pnnx_expression_pnnx
b = test_pnnx_expression_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 4
- 1
tools/pnnx/tests/test_torch_einsum.py View File

@@ -148,7 +148,10 @@ def test():
b = test_torch_einsum_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
# allclose may auto broadcast compare
if a0.shape != b0.shape:
return False
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True



+ 61
- 0
tools/pnnx/tests/test_torch_topk.py View File

@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x, _ = torch.topk(x, 4)
y, _ = torch.topk(y, k=1, dim=2, largest=False)
z, indices = torch.topk(z, k=3, dim=-1, sorted=False)
return x, y, z, indices

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 3, 16)
y = torch.rand(1, 5, 9, 11)
z = torch.rand(14, 8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_topk.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_topk.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")

# pnnx inference
import test_torch_topk_pnnx
b = test_torch_topk_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save