Browse Source

pnnx cmdline argument inputshape with type (#3419)

tags/20220216
nihui GitHub 4 years ago
parent
commit
db1fed6e11
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 502 additions and 63 deletions
  1. +1
    -1
      tools/pnnx/README.md
  2. +2
    -0
      tools/pnnx/src/CMakeLists.txt
  3. +80
    -21
      tools/pnnx/src/ir.cpp
  4. +73
    -39
      tools/pnnx/src/main.cpp
  5. +2
    -2
      tools/pnnx/src/pass_level2/F_affine_grid.cpp
  6. +44
    -0
      tools/pnnx/src/pass_level2/F_embedding.cpp
  7. +69
    -0
      tools/pnnx/src/pass_ncnn/F_embedding.cpp
  8. +2
    -0
      tools/pnnx/tests/CMakeLists.txt
  9. +2
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  10. +56
    -0
      tools/pnnx/tests/ncnn/test_F_embedding.py
  11. +56
    -0
      tools/pnnx/tests/ncnn/test_nn_Embedding.py
  12. +59
    -0
      tools/pnnx/tests/test_F_embedding.py
  13. +56
    -0
      tools/pnnx/tests/test_nn_Embedding.py

+ 1
- 1
tools/pnnx/README.md View File

@@ -559,7 +559,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) {
|F.dropout3d | |
|F.elu | :heavy_check_mark: | :heavy_check_mark: |
|F.elu_ | :heavy_check_mark: | :heavy_check_mark: |
|F.embedding | |
|F.embedding | :heavy_check_mark: | :heavy_check_mark: |
|F.embedding_bag | |
|F.feature_alpha_dropout | |
|F.fold | |


+ 2
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -111,6 +111,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/F_conv3d.cpp
pass_level2/F_conv_transpose123d.cpp
pass_level2/F_elu.cpp
pass_level2/F_embedding.cpp
pass_level2/F_gelu.cpp
pass_level2/F_grid_sample.cpp
pass_level2/F_group_norm.cpp
@@ -246,6 +247,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/F_conv2d.cpp
pass_ncnn/F_conv3d.cpp
pass_ncnn/F_elu.cpp
pass_ncnn/F_embedding.cpp
pass_ncnn/F_gelu.cpp
pass_ncnn/F_group_norm.cpp
pass_ncnn/F_hardsigmoid.cpp


+ 80
- 21
tools/pnnx/src/ir.cpp View File

@@ -27,6 +27,19 @@

namespace pnnx {

static bool type_is_integer(int type)
{
if (type == 1) return false;
if (type == 2) return false;
if (type == 3) return false;
if (type == 4) return true;
if (type == 5) return true;
if (type == 6) return true;
if (type == 7) return true;
if (type == 8) return true;
return false;
}

static const char* type_to_string(int type)
{
if (type == 1) return "f32";
@@ -53,6 +66,19 @@ static const char* type_to_numpy_string(int type)
return "null";
}

static const char* type_to_dtype_string(int type)
{
if (type == 1) return "torch.float";
if (type == 2) return "torch.double";
if (type == 3) return "torch.half";
if (type == 4) return "torch.int";
if (type == 5) return "torch.long";
if (type == 6) return "torch.short";
if (type == 7) return "torch.int8";
if (type == 8) return "torch.uint8";
return "null";
}

static size_t type_to_elemsize(int type)
{
if (type == 1) return 4;
@@ -1701,15 +1727,26 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

const Operand* r = op->outputs[0];
std::string input_name = std::string("v_") + sanitize_identifier(r->name);
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());

for (size_t i = 0; i < r->shape.size(); i++)
if (type_is_integer(r->type))
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size())
fprintf(pyfp, ", ");
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
fprintf(pyfp, ")\n");

input_names.push_back(input_name);
}
@@ -1755,15 +1792,26 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

const Operand* r = op->outputs[0];
std::string input_name = std::string("v_") + sanitize_identifier(r->name);
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());

for (size_t i = 0; i < r->shape.size(); i++)
if (type_is_integer(r->type))
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size())
fprintf(pyfp, ", ");
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
fprintf(pyfp, ")\n");

input_names.push_back(input_name);
}
@@ -2018,15 +2066,26 @@ int Graph::ncnn(const std::string& parampath, const std::string& binpath, const
if (!r)
break;

fprintf(pyfp, " %s = torch.rand(", input_name.c_str());

for (size_t i = 0; i < r->shape.size(); i++)
if (type_is_integer(r->type))
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size())
fprintf(pyfp, ", ");
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
fprintf(pyfp, ")\n");
}

fprintf(pyfp, " out = []\n");


+ 73
- 39
tools/pnnx/src/main.cpp View File

@@ -40,86 +40,112 @@ static std::string get_basename(const std::string& path)
return path.substr(0, path.find_last_of('.'));
}

static std::vector<std::string> parse_comma_string_array_list(char* s)
static void parse_string_list(char* s, std::vector<std::string>& list)
{
std::vector<std::string> as;
list.clear();

char* pch = strtok(s, ",");
while (pch != NULL)
{
as.push_back(std::string(pch));
list.push_back(std::string(pch));

pch = strtok(NULL, ",");
}
}

return as;
static void print_string_list(const std::vector<std::string>& list)
{
for (size_t i = 0; i < list.size(); i++)
{
fprintf(stderr, "%s", list[i].c_str());
if (i + 1 != list.size())
fprintf(stderr, ",");
}
}

static std::vector<std::vector<int64_t> > parse_comma_int_array_list(char* s)
static void parse_shape_list(char* s, std::vector<std::vector<int64_t> >& shapes, std::vector<std::string>& types)
{
std::vector<std::vector<int64_t> > aai;
shapes.clear();
types.clear();

char* pch = strtok(s, "[]");
while (pch != NULL)
{
// assign user data type
if (!types.empty() && (pch[0] == 'f' || pch[0] == 'i' || pch[0] == 'u'))
{
char type[32];
int nscan = sscanf(pch, "%31[^,]", type);
if (nscan == 1)
{
types[types.size() - 1] = std::string(type);
}
}

// parse a,b,c
int v;
int nconsumed = 0;
int nscan = sscanf(pch, "%d%n", &v, &nconsumed);
if (nscan == 1)
{
// ok we get array
// ok we get shape
pch += nconsumed;

std::vector<int64_t> ai;
ai.push_back(v);
std::vector<int64_t> s;
s.push_back(v);

nscan = sscanf(pch, ",%d%n", &v, &nconsumed);
while (nscan == 1)
{
pch += nconsumed;

ai.push_back(v);
s.push_back(v);

nscan = sscanf(pch, ",%d%n", &v, &nconsumed);
}

// array end
aai.push_back(ai);
// shape end
shapes.push_back(s);
types.push_back("f32");
}

pch = strtok(NULL, "[]");
}

return aai;
}

static void print_int64_array_list(const std::vector<std::vector<int64_t> >& list)
static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types)
{
for (size_t i = 0; i < list.size(); i++)
for (size_t i = 0; i < shapes.size(); i++)
{
const std::vector<int64_t>& array = list[i];
const std::vector<int64_t>& s = shapes[i];
const std::string& t = types[i];
fprintf(stderr, "[");
for (size_t j = 0; j < array.size(); j++)
for (size_t j = 0; j < s.size(); j++)
{
fprintf(stderr, "%ld", array[j]);
if (j != array.size() - 1)
fprintf(stderr, "%ld", s[j]);
if (j != s.size() - 1)
fprintf(stderr, ",");
}
fprintf(stderr, "]");
if (i != list.size() - 1)
fprintf(stderr, "%s", t.c_str());
if (i != shapes.size() - 1)
fprintf(stderr, ",");
}
}

static void print_string_list(const std::vector<std::string>& list)
static c10::ScalarType input_type_to_c10_ScalarType(const std::string& t)
{
for (size_t i = 0; i < list.size(); i++)
{
fprintf(stderr, "%s", list[i].c_str());
if (i + 1 != list.size())
fprintf(stderr, ",");
}
if (t == "f32") return torch::kFloat32;
if (t == "f16") return torch::kFloat16;
if (t == "f64") return torch::kFloat64;
if (t == "i32") return torch::kInt32;
if (t == "i16") return torch::kInt16;
if (t == "i64") return torch::kInt64;
if (t == "i8") return torch::kInt8;
if (t == "u8") return torch::kUInt8;

fprintf(stderr, "unsupported type %s fallback to f32\n", t.c_str());
return torch::kFloat32;
}

static void show_usage()
@@ -142,7 +168,7 @@ static void show_usage()
#endif
fprintf(stderr, " moduleop=models.common.Focus,models.yolo.Detect,...\n");
fprintf(stderr, "Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]\n");
fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect\n");
fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640]f32 inputshape2=[1,3,320,320]f32 device=gpu moduleop=models.common.Focus,models.yolo.Detect\n");
}

int main(int argc, char** argv)
@@ -175,7 +201,9 @@ int main(int argc, char** argv)
int optlevel = 2;
std::string device = "cpu";
std::vector<std::vector<int64_t> > input_shapes;
std::vector<std::string> input_types;
std::vector<std::vector<int64_t> > input_shapes2;
std::vector<std::string> input_types2;
std::vector<std::string> customop_modules;
std::vector<std::string> module_operators;

@@ -213,13 +241,13 @@ int main(int argc, char** argv)
if (strcmp(key, "device") == 0)
device = value;
if (strcmp(key, "inputshape") == 0)
input_shapes = parse_comma_int_array_list(value);
parse_shape_list(value, input_shapes, input_types);
if (strcmp(key, "inputshape2") == 0)
input_shapes2 = parse_comma_int_array_list(value);
parse_shape_list(value, input_shapes2, input_types2);
if (strcmp(key, "customop") == 0)
customop_modules = parse_comma_string_array_list(value);
parse_string_list(value, customop_modules);
if (strcmp(key, "moduleop") == 0)
module_operators = parse_comma_string_array_list(value);
parse_string_list(value, module_operators);
}

// print options
@@ -233,10 +261,10 @@ int main(int argc, char** argv)
fprintf(stderr, "optlevel = %d\n", optlevel);
fprintf(stderr, "device = %s\n", device.c_str());
fprintf(stderr, "inputshape = ");
print_int64_array_list(input_shapes);
print_shape_list(input_shapes, input_types);
fprintf(stderr, "\n");
fprintf(stderr, "inputshape2 = ");
print_int64_array_list(input_shapes2);
print_shape_list(input_shapes2, input_types2);
fprintf(stderr, "\n");
fprintf(stderr, "customop = ");
print_string_list(customop_modules);
@@ -268,9 +296,12 @@ int main(int argc, char** argv)
}

std::vector<at::Tensor> input_tensors;
for (auto shape : input_shapes)
for (size_t i = 0; i < input_shapes.size(); i++)
{
at::Tensor t = torch::ones(shape);
const std::vector<int64_t>& shape = input_shapes[i];
const std::string& type = input_types[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();

@@ -278,9 +309,12 @@ int main(int argc, char** argv)
}

std::vector<at::Tensor> input_tensors2;
for (auto shape : input_shapes2)
for (size_t i = 0; i < input_shapes2.size(); i++)
{
at::Tensor t = torch::ones(shape);
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();



+ 2
- 2
tools/pnnx/src/pass_level2/F_affine_grid.cpp View File

@@ -23,8 +23,8 @@ public:
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_1 0 1 theta
pnnx.Input input_2 0 1 size
pnnx.Input input_0 0 1 theta
pnnx.Input input_1 0 1 size
prim::Constant op_0 0 1 align_corners value=%align_corners
aten::affine_grid_generator op_1 3 1 theta size align_corners out
pnnx.Output output 1 0 out


+ 44
- 0
tools/pnnx/src/pass_level2/F_embedding.cpp View File

@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class F_embedding : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 weight
prim::Constant op_0 0 1 padding_idx value=*
prim::Constant op_1 0 1 scale_grad_by_freq value=%scale_grad_by_freq
prim::Constant op_2 0 1 sparse value=%sparse
aten::embedding op_3 5 1 weight input padding_idx scale_grad_by_freq sparse out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.embedding";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding, 10)

} // namespace pnnx

+ 69
- 0
tools/pnnx/src/pass_ncnn/F_embedding.cpp View File

@@ -0,0 +1,69 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class F_embedding : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @qwq
F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Embed";
}

const char* name_str() const
{
return "embed";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
}

op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[0];
op->params["2"] = 0;
op->params["3"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = weight;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_embedding, 20)

} // namespace ncnn

} // namespace pnnx

+ 2
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -24,6 +24,7 @@ pnnx_add_test(F_conv_transpose1d)
pnnx_add_test(F_conv_transpose2d)
pnnx_add_test(F_conv_transpose3d)
pnnx_add_test(F_elu)
pnnx_add_test(F_embedding)
pnnx_add_test(F_gelu)
pnnx_add_test(F_grid_sample)
pnnx_add_test(F_group_norm)
@@ -91,6 +92,7 @@ pnnx_add_test(nn_ConvTranspose1d)
pnnx_add_test(nn_ConvTranspose2d)
pnnx_add_test(nn_ConvTranspose3d)
pnnx_add_test(nn_ELU)
pnnx_add_test(nn_Embedding)
pnnx_add_test(nn_GELU)
pnnx_add_test(nn_GroupNorm)
pnnx_add_test(nn_GRU)


+ 2
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -22,6 +22,7 @@ pnnx_ncnn_add_test(F_conv1d)
pnnx_ncnn_add_test(F_conv2d)
pnnx_ncnn_add_test(F_conv3d)
pnnx_ncnn_add_test(F_elu)
pnnx_ncnn_add_test(F_embedding)
pnnx_ncnn_add_test(F_gelu)
pnnx_ncnn_add_test(F_group_norm)
pnnx_ncnn_add_test(F_hardsigmoid)
@@ -70,6 +71,7 @@ pnnx_ncnn_add_test(nn_Conv2d)
pnnx_ncnn_add_test(nn_Conv3d)
pnnx_ncnn_add_test(nn_ConvTranspose2d)
pnnx_ncnn_add_test(nn_ELU)
pnnx_ncnn_add_test(nn_Embedding)
pnnx_ncnn_add_test(nn_GELU)
pnnx_ncnn_add_test(nn_GroupNorm)
pnnx_ncnn_add_test(nn_GRU)


+ 56
- 0
tools/pnnx/tests/ncnn/test_F_embedding.py View File

@@ -0,0 +1,56 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.w1 = nn.Parameter(torch.rand(10, 128))

def forward(self, y):
y = F.embedding(y, self.w1)
return y

def test():
net = Model()
net.eval()

torch.manual_seed(0)
y = torch.randint(10, (1, 11), dtype=torch.int)

a = net(y)

# export torchscript
mod = torch.jit.trace(net, (y))
mod.save("test_F_embedding.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_embedding.pt inputshape=[1,11]i32")

# ncnn inference
import test_F_embedding_ncnn
b = test_F_embedding_ncnn.test_inference()

return torch.allclose(a, b, 1e-4, 1e-4)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 56
- 0
tools/pnnx/tests/ncnn/test_nn_Embedding.py View File

@@ -0,0 +1,56 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.embed_0 = nn.Embedding(embedding_dim=128, num_embeddings=10)

def forward(self, x):
x = self.embed_0(x)
return x

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.randint(10, (13,), dtype=torch.int)

a = net(x)

# export torchscript
mod = torch.jit.trace(net, x)
mod.save("test_nn_Embedding.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_Embedding.pt inputshape=[13]i32")

# ncnn inference
import test_nn_Embedding_ncnn
b = test_nn_Embedding_ncnn.test_inference()

return torch.allclose(a, b, 1e-4, 1e-4)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 59
- 0
tools/pnnx/tests/test_F_embedding.py View File

@@ -0,0 +1,59 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.w1 = nn.Parameter(torch.rand(10, 128))

def forward(self, x, w0, y):
x = F.embedding(x, w0)
y = F.embedding(y, self.w1)
return x, y

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.randint(10, (1, 13), dtype=torch.int)
w0 = torch.rand(10, 128)
y = torch.randint(10, (1, 11), dtype=torch.int)

a0, a1 = net(x, w0, y)

# export torchscript
mod = torch.jit.trace(net, (x, w0, y))
mod.save("test_F_embedding.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_F_embedding.pt inputshape=[1,13]i32,[10,128],[1,11]i32")

# pnnx inference
import test_F_embedding_pnnx
b0, b1 = test_F_embedding_pnnx.test_inference()

return torch.equal(a0, b0) and torch.equal(a1, b1)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 56
- 0
tools/pnnx/tests/test_nn_Embedding.py View File

@@ -0,0 +1,56 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.embed_0 = nn.Embedding(embedding_dim=128, num_embeddings=10)

def forward(self, x):
x = self.embed_0(x)
return x

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.randint(10, (1, 13), dtype=torch.int)

a = net(x)

# export torchscript
mod = torch.jit.trace(net, x)
mod.save("test_nn_Embedding.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_nn_Embedding.pt inputshape=[1,13]i32")

# pnnx inference
import test_nn_Embedding_pnnx
b = test_nn_Embedding_pnnx.test_inference()

return torch.equal(a, b)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save