From db1fed6e115c409ca1ae5b26c8faf849e2ee1c6d Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 14 Dec 2021 18:58:23 +0800 Subject: [PATCH] pnnx cmdline argument inputshape with type (#3419) --- tools/pnnx/README.md | 2 +- tools/pnnx/src/CMakeLists.txt | 2 + tools/pnnx/src/ir.cpp | 101 +++++++++++++---- tools/pnnx/src/main.cpp | 112 ++++++++++++------- tools/pnnx/src/pass_level2/F_affine_grid.cpp | 4 +- tools/pnnx/src/pass_level2/F_embedding.cpp | 44 ++++++++ tools/pnnx/src/pass_ncnn/F_embedding.cpp | 69 ++++++++++++ tools/pnnx/tests/CMakeLists.txt | 2 + tools/pnnx/tests/ncnn/CMakeLists.txt | 2 + tools/pnnx/tests/ncnn/test_F_embedding.py | 56 ++++++++++ tools/pnnx/tests/ncnn/test_nn_Embedding.py | 56 ++++++++++ tools/pnnx/tests/test_F_embedding.py | 59 ++++++++++ tools/pnnx/tests/test_nn_Embedding.py | 56 ++++++++++ 13 files changed, 502 insertions(+), 63 deletions(-) create mode 100644 tools/pnnx/src/pass_level2/F_embedding.cpp create mode 100644 tools/pnnx/src/pass_ncnn/F_embedding.cpp create mode 100644 tools/pnnx/tests/ncnn/test_F_embedding.py create mode 100644 tools/pnnx/tests/ncnn/test_nn_Embedding.py create mode 100644 tools/pnnx/tests/test_F_embedding.py create mode 100644 tools/pnnx/tests/test_nn_Embedding.py diff --git a/tools/pnnx/README.md b/tools/pnnx/README.md index 5267e349f..0591091a7 100644 --- a/tools/pnnx/README.md +++ b/tools/pnnx/README.md @@ -559,7 +559,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) { |F.dropout3d | | |F.elu | :heavy_check_mark: | :heavy_check_mark: | |F.elu_ | :heavy_check_mark: | :heavy_check_mark: | -|F.embedding | | +|F.embedding | :heavy_check_mark: | :heavy_check_mark: | |F.embedding_bag | | |F.feature_alpha_dropout | | |F.fold | | diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index f6ce3c810..fb33af1fb 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -111,6 +111,7 @@ set(pnnx_pass_level2_SRCS pass_level2/F_conv3d.cpp pass_level2/F_conv_transpose123d.cpp pass_level2/F_elu.cpp + pass_level2/F_embedding.cpp pass_level2/F_gelu.cpp pass_level2/F_grid_sample.cpp pass_level2/F_group_norm.cpp @@ -246,6 +247,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/F_conv2d.cpp pass_ncnn/F_conv3d.cpp pass_ncnn/F_elu.cpp + pass_ncnn/F_embedding.cpp pass_ncnn/F_gelu.cpp pass_ncnn/F_group_norm.cpp pass_ncnn/F_hardsigmoid.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 0af523d09..1a37baf5f 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -27,6 +27,19 @@ namespace pnnx { +static bool type_is_integer(int type) +{ + if (type == 1) return false; + if (type == 2) return false; + if (type == 3) return false; + if (type == 4) return true; + if (type == 5) return true; + if (type == 6) return true; + if (type == 7) return true; + if (type == 8) return true; + return false; +} + static const char* type_to_string(int type) { if (type == 1) return "f32"; @@ -53,6 +66,19 @@ static const char* type_to_numpy_string(int type) return "null"; } +static const char* type_to_dtype_string(int type) +{ + if (type == 1) return "torch.float"; + if (type == 2) return "torch.double"; + if (type == 3) return "torch.half"; + if (type == 4) return "torch.int"; + if (type == 5) return "torch.long"; + if (type == 6) return "torch.short"; + if (type == 7) return "torch.int8"; + if (type == 8) return "torch.uint8"; + return "null"; +} + static size_t type_to_elemsize(int type) { if (type == 1) return 4; @@ -1701,15 +1727,26 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) const Operand* r = op->outputs[0]; std::string input_name = std::string("v_") + sanitize_identifier(r->name); - fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); - - for (size_t i = 0; i < r->shape.size(); i++) + if (type_is_integer(r->type)) { - fprintf(pyfp, "%d", r->shape[i]); - if (i + 1 != r->shape.size()) - fprintf(pyfp, ", "); + fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size() || r->shape.size() == 1) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); + } + else + { + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d, ", r->shape[i]); + } + fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); } - fprintf(pyfp, ")\n"); input_names.push_back(input_name); } @@ -1755,15 +1792,26 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) const Operand* r = op->outputs[0]; std::string input_name = std::string("v_") + sanitize_identifier(r->name); - fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); - - for (size_t i = 0; i < r->shape.size(); i++) + if (type_is_integer(r->type)) { - fprintf(pyfp, "%d", r->shape[i]); - if (i + 1 != r->shape.size()) - fprintf(pyfp, ", "); + fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size() || r->shape.size() == 1) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); + } + else + { + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d, ", r->shape[i]); + } + fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); } - fprintf(pyfp, ")\n"); input_names.push_back(input_name); } @@ -2018,15 +2066,26 @@ int Graph::ncnn(const std::string& parampath, const std::string& binpath, const if (!r) break; - fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); - - for (size_t i = 0; i < r->shape.size(); i++) + if (type_is_integer(r->type)) { - fprintf(pyfp, "%d", r->shape[i]); - if (i + 1 != r->shape.size()) - fprintf(pyfp, ", "); + fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size() || r->shape.size() == 1) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); + } + else + { + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d, ", r->shape[i]); + } + fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); } - fprintf(pyfp, ")\n"); } fprintf(pyfp, " out = []\n"); diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index b87d8ea25..01e1166c1 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -40,86 +40,112 @@ static std::string get_basename(const std::string& path) return path.substr(0, path.find_last_of('.')); } -static std::vector parse_comma_string_array_list(char* s) +static void parse_string_list(char* s, std::vector& list) { - std::vector as; + list.clear(); char* pch = strtok(s, ","); while (pch != NULL) { - as.push_back(std::string(pch)); + list.push_back(std::string(pch)); pch = strtok(NULL, ","); } +} - return as; +static void print_string_list(const std::vector& list) +{ + for (size_t i = 0; i < list.size(); i++) + { + fprintf(stderr, "%s", list[i].c_str()); + if (i + 1 != list.size()) + fprintf(stderr, ","); + } } -static std::vector > parse_comma_int_array_list(char* s) +static void parse_shape_list(char* s, std::vector >& shapes, std::vector& types) { - std::vector > aai; + shapes.clear(); + types.clear(); char* pch = strtok(s, "[]"); while (pch != NULL) { + // assign user data type + if (!types.empty() && (pch[0] == 'f' || pch[0] == 'i' || pch[0] == 'u')) + { + char type[32]; + int nscan = sscanf(pch, "%31[^,]", type); + if (nscan == 1) + { + types[types.size() - 1] = std::string(type); + } + } + // parse a,b,c int v; int nconsumed = 0; int nscan = sscanf(pch, "%d%n", &v, &nconsumed); if (nscan == 1) { - // ok we get array + // ok we get shape pch += nconsumed; - std::vector ai; - ai.push_back(v); + std::vector s; + s.push_back(v); nscan = sscanf(pch, ",%d%n", &v, &nconsumed); while (nscan == 1) { pch += nconsumed; - ai.push_back(v); + s.push_back(v); nscan = sscanf(pch, ",%d%n", &v, &nconsumed); } - // array end - aai.push_back(ai); + // shape end + shapes.push_back(s); + types.push_back("f32"); } pch = strtok(NULL, "[]"); } - - return aai; } -static void print_int64_array_list(const std::vector >& list) +static void print_shape_list(const std::vector >& shapes, const std::vector& types) { - for (size_t i = 0; i < list.size(); i++) + for (size_t i = 0; i < shapes.size(); i++) { - const std::vector& array = list[i]; + const std::vector& s = shapes[i]; + const std::string& t = types[i]; fprintf(stderr, "["); - for (size_t j = 0; j < array.size(); j++) + for (size_t j = 0; j < s.size(); j++) { - fprintf(stderr, "%ld", array[j]); - if (j != array.size() - 1) + fprintf(stderr, "%ld", s[j]); + if (j != s.size() - 1) fprintf(stderr, ","); } fprintf(stderr, "]"); - if (i != list.size() - 1) + fprintf(stderr, "%s", t.c_str()); + if (i != shapes.size() - 1) fprintf(stderr, ","); } } -static void print_string_list(const std::vector& list) +static c10::ScalarType input_type_to_c10_ScalarType(const std::string& t) { - for (size_t i = 0; i < list.size(); i++) - { - fprintf(stderr, "%s", list[i].c_str()); - if (i + 1 != list.size()) - fprintf(stderr, ","); - } + if (t == "f32") return torch::kFloat32; + if (t == "f16") return torch::kFloat16; + if (t == "f64") return torch::kFloat64; + if (t == "i32") return torch::kInt32; + if (t == "i16") return torch::kInt16; + if (t == "i64") return torch::kInt64; + if (t == "i8") return torch::kInt8; + if (t == "u8") return torch::kUInt8; + + fprintf(stderr, "unsupported type %s fallback to f32\n", t.c_str()); + return torch::kFloat32; } static void show_usage() @@ -142,7 +168,7 @@ static void show_usage() #endif fprintf(stderr, " moduleop=models.common.Focus,models.yolo.Detect,...\n"); fprintf(stderr, "Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]\n"); - fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect\n"); + fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640]f32 inputshape2=[1,3,320,320]f32 device=gpu moduleop=models.common.Focus,models.yolo.Detect\n"); } int main(int argc, char** argv) @@ -175,7 +201,9 @@ int main(int argc, char** argv) int optlevel = 2; std::string device = "cpu"; std::vector > input_shapes; + std::vector input_types; std::vector > input_shapes2; + std::vector input_types2; std::vector customop_modules; std::vector module_operators; @@ -213,13 +241,13 @@ int main(int argc, char** argv) if (strcmp(key, "device") == 0) device = value; if (strcmp(key, "inputshape") == 0) - input_shapes = parse_comma_int_array_list(value); + parse_shape_list(value, input_shapes, input_types); if (strcmp(key, "inputshape2") == 0) - input_shapes2 = parse_comma_int_array_list(value); + parse_shape_list(value, input_shapes2, input_types2); if (strcmp(key, "customop") == 0) - customop_modules = parse_comma_string_array_list(value); + parse_string_list(value, customop_modules); if (strcmp(key, "moduleop") == 0) - module_operators = parse_comma_string_array_list(value); + parse_string_list(value, module_operators); } // print options @@ -233,10 +261,10 @@ int main(int argc, char** argv) fprintf(stderr, "optlevel = %d\n", optlevel); fprintf(stderr, "device = %s\n", device.c_str()); fprintf(stderr, "inputshape = "); - print_int64_array_list(input_shapes); + print_shape_list(input_shapes, input_types); fprintf(stderr, "\n"); fprintf(stderr, "inputshape2 = "); - print_int64_array_list(input_shapes2); + print_shape_list(input_shapes2, input_types2); fprintf(stderr, "\n"); fprintf(stderr, "customop = "); print_string_list(customop_modules); @@ -268,9 +296,12 @@ int main(int argc, char** argv) } std::vector input_tensors; - for (auto shape : input_shapes) + for (size_t i = 0; i < input_shapes.size(); i++) { - at::Tensor t = torch::ones(shape); + const std::vector& shape = input_shapes[i]; + const std::string& type = input_types[i]; + + at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); if (device == "gpu") t = t.cuda(); @@ -278,9 +309,12 @@ int main(int argc, char** argv) } std::vector input_tensors2; - for (auto shape : input_shapes2) + for (size_t i = 0; i < input_shapes2.size(); i++) { - at::Tensor t = torch::ones(shape); + const std::vector& shape = input_shapes2[i]; + const std::string& type = input_types2[i]; + + at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); if (device == "gpu") t = t.cuda(); diff --git a/tools/pnnx/src/pass_level2/F_affine_grid.cpp b/tools/pnnx/src/pass_level2/F_affine_grid.cpp index d766da701..44952e0d0 100644 --- a/tools/pnnx/src/pass_level2/F_affine_grid.cpp +++ b/tools/pnnx/src/pass_level2/F_affine_grid.cpp @@ -23,8 +23,8 @@ public: { return R"PNNXIR(7767517 5 4 -pnnx.Input input_1 0 1 theta -pnnx.Input input_2 0 1 size +pnnx.Input input_0 0 1 theta +pnnx.Input input_1 0 1 size prim::Constant op_0 0 1 align_corners value=%align_corners aten::affine_grid_generator op_1 3 1 theta size align_corners out pnnx.Output output 1 0 out diff --git a/tools/pnnx/src/pass_level2/F_embedding.cpp b/tools/pnnx/src/pass_level2/F_embedding.cpp new file mode 100644 index 000000000..2b3b58e04 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_embedding.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_embedding : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +prim::Constant op_0 0 1 padding_idx value=* +prim::Constant op_1 0 1 scale_grad_by_freq value=%scale_grad_by_freq +prim::Constant op_2 0 1 sparse value=%sparse +aten::embedding op_3 5 1 weight input padding_idx scale_grad_by_freq sparse out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.embedding"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding, 10) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/F_embedding.cpp b/tools/pnnx/src/pass_ncnn/F_embedding.cpp new file mode 100644 index 000000000..37fbe30a7 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/F_embedding.cpp @@ -0,0 +1,69 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class F_embedding : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_weight 0 1 weight @qwq +F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Embed"; + } + + const char* name_str() const + { + return "embed"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + Attribute weight; + for (const auto& x : captured_attrs) + { + if (x.first.substr(0, 10) == "op_weight.") + weight = x.second; + } + + op->params["0"] = weight.shape[1]; + op->params["1"] = weight.shape[0]; + op->params["2"] = 0; + op->params["3"] = (int)(weight.data.size() / sizeof(float)); + + op->attrs["0"] = Attribute(); + op->attrs["0"].data = {0, 0, 0, 0}; + op->attrs["1"] = weight; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_embedding, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 237354d0f..9db39be52 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -24,6 +24,7 @@ pnnx_add_test(F_conv_transpose1d) pnnx_add_test(F_conv_transpose2d) pnnx_add_test(F_conv_transpose3d) pnnx_add_test(F_elu) +pnnx_add_test(F_embedding) pnnx_add_test(F_gelu) pnnx_add_test(F_grid_sample) pnnx_add_test(F_group_norm) @@ -91,6 +92,7 @@ pnnx_add_test(nn_ConvTranspose1d) pnnx_add_test(nn_ConvTranspose2d) pnnx_add_test(nn_ConvTranspose3d) pnnx_add_test(nn_ELU) +pnnx_add_test(nn_Embedding) pnnx_add_test(nn_GELU) pnnx_add_test(nn_GroupNorm) pnnx_add_test(nn_GRU) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index a33cc9b91..429f30822 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -22,6 +22,7 @@ pnnx_ncnn_add_test(F_conv1d) pnnx_ncnn_add_test(F_conv2d) pnnx_ncnn_add_test(F_conv3d) pnnx_ncnn_add_test(F_elu) +pnnx_ncnn_add_test(F_embedding) pnnx_ncnn_add_test(F_gelu) pnnx_ncnn_add_test(F_group_norm) pnnx_ncnn_add_test(F_hardsigmoid) @@ -70,6 +71,7 @@ pnnx_ncnn_add_test(nn_Conv2d) pnnx_ncnn_add_test(nn_Conv3d) pnnx_ncnn_add_test(nn_ConvTranspose2d) pnnx_ncnn_add_test(nn_ELU) +pnnx_ncnn_add_test(nn_Embedding) pnnx_ncnn_add_test(nn_GELU) pnnx_ncnn_add_test(nn_GroupNorm) pnnx_ncnn_add_test(nn_GRU) diff --git a/tools/pnnx/tests/ncnn/test_F_embedding.py b/tools/pnnx/tests/ncnn/test_F_embedding.py new file mode 100644 index 000000000..c45d1e5c2 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_embedding.py @@ -0,0 +1,56 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.w1 = nn.Parameter(torch.rand(10, 128)) + + def forward(self, y): + y = F.embedding(y, self.w1) + return y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + y = torch.randint(10, (1, 11), dtype=torch.int) + + a = net(y) + + # export torchscript + mod = torch.jit.trace(net, (y)) + mod.save("test_F_embedding.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_embedding.pt inputshape=[1,11]i32") + + # ncnn inference + import test_F_embedding_ncnn + b = test_F_embedding_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_Embedding.py b/tools/pnnx/tests/ncnn/test_nn_Embedding.py new file mode 100644 index 000000000..c34a0b650 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_Embedding.py @@ -0,0 +1,56 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.embed_0 = nn.Embedding(embedding_dim=128, num_embeddings=10) + + def forward(self, x): + x = self.embed_0(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.randint(10, (13,), dtype=torch.int) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Embedding.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_Embedding.pt inputshape=[13]i32") + + # ncnn inference + import test_nn_Embedding_ncnn + b = test_nn_Embedding_ncnn.test_inference() + + return torch.allclose(a, b, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_F_embedding.py b/tools/pnnx/tests/test_F_embedding.py new file mode 100644 index 000000000..0756bc7f7 --- /dev/null +++ b/tools/pnnx/tests/test_F_embedding.py @@ -0,0 +1,59 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.w1 = nn.Parameter(torch.rand(10, 128)) + + def forward(self, x, w0, y): + x = F.embedding(x, w0) + y = F.embedding(y, self.w1) + return x, y + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.randint(10, (1, 13), dtype=torch.int) + w0 = torch.rand(10, 128) + y = torch.randint(10, (1, 11), dtype=torch.int) + + a0, a1 = net(x, w0, y) + + # export torchscript + mod = torch.jit.trace(net, (x, w0, y)) + mod.save("test_F_embedding.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_embedding.pt inputshape=[1,13]i32,[10,128],[1,11]i32") + + # pnnx inference + import test_F_embedding_pnnx + b0, b1 = test_F_embedding_pnnx.test_inference() + + return torch.equal(a0, b0) and torch.equal(a1, b1) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_nn_Embedding.py b/tools/pnnx/tests/test_nn_Embedding.py new file mode 100644 index 000000000..48a74dfb9 --- /dev/null +++ b/tools/pnnx/tests/test_nn_Embedding.py @@ -0,0 +1,56 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.embed_0 = nn.Embedding(embedding_dim=128, num_embeddings=10) + + def forward(self, x): + x = self.embed_0(x) + return x + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.randint(10, (1, 13), dtype=torch.int) + + a = net(x) + + # export torchscript + mod = torch.jit.trace(net, x) + mod.save("test_nn_Embedding.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_nn_Embedding.pt inputshape=[1,13]i32") + + # pnnx inference + import test_nn_Embedding_pnnx + b = test_nn_Embedding_pnnx.test_inference() + + return torch.equal(a, b) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)