| @@ -559,7 +559,7 @@ TORCH_LIBRARY(upfirdn2d_op, m) { | |||
| |F.dropout3d | | | |||
| |F.elu | :heavy_check_mark: | :heavy_check_mark: | | |||
| |F.elu_ | :heavy_check_mark: | :heavy_check_mark: | | |||
| |F.embedding | | | |||
| |F.embedding | :heavy_check_mark: | :heavy_check_mark: | | |||
| |F.embedding_bag | | | |||
| |F.feature_alpha_dropout | | | |||
| |F.fold | | | |||
| @@ -111,6 +111,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/F_conv3d.cpp | |||
| pass_level2/F_conv_transpose123d.cpp | |||
| pass_level2/F_elu.cpp | |||
| pass_level2/F_embedding.cpp | |||
| pass_level2/F_gelu.cpp | |||
| pass_level2/F_grid_sample.cpp | |||
| pass_level2/F_group_norm.cpp | |||
| @@ -246,6 +247,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/F_conv2d.cpp | |||
| pass_ncnn/F_conv3d.cpp | |||
| pass_ncnn/F_elu.cpp | |||
| pass_ncnn/F_embedding.cpp | |||
| pass_ncnn/F_gelu.cpp | |||
| pass_ncnn/F_group_norm.cpp | |||
| pass_ncnn/F_hardsigmoid.cpp | |||
| @@ -27,6 +27,19 @@ | |||
| namespace pnnx { | |||
| static bool type_is_integer(int type) | |||
| { | |||
| if (type == 1) return false; | |||
| if (type == 2) return false; | |||
| if (type == 3) return false; | |||
| if (type == 4) return true; | |||
| if (type == 5) return true; | |||
| if (type == 6) return true; | |||
| if (type == 7) return true; | |||
| if (type == 8) return true; | |||
| return false; | |||
| } | |||
| static const char* type_to_string(int type) | |||
| { | |||
| if (type == 1) return "f32"; | |||
| @@ -53,6 +66,19 @@ static const char* type_to_numpy_string(int type) | |||
| return "null"; | |||
| } | |||
| static const char* type_to_dtype_string(int type) | |||
| { | |||
| if (type == 1) return "torch.float"; | |||
| if (type == 2) return "torch.double"; | |||
| if (type == 3) return "torch.half"; | |||
| if (type == 4) return "torch.int"; | |||
| if (type == 5) return "torch.long"; | |||
| if (type == 6) return "torch.short"; | |||
| if (type == 7) return "torch.int8"; | |||
| if (type == 8) return "torch.uint8"; | |||
| return "null"; | |||
| } | |||
| static size_t type_to_elemsize(int type) | |||
| { | |||
| if (type == 1) return 4; | |||
| @@ -1701,15 +1727,26 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| const Operand* r = op->outputs[0]; | |||
| std::string input_name = std::string("v_") + sanitize_identifier(r->name); | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| if (type_is_integer(r->type)) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size()) | |||
| fprintf(pyfp, ", "); | |||
| fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size() || r->shape.size() == 1) | |||
| fprintf(pyfp, ", "); | |||
| } | |||
| fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d, ", r->shape[i]); | |||
| } | |||
| fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| fprintf(pyfp, ")\n"); | |||
| input_names.push_back(input_name); | |||
| } | |||
| @@ -1755,15 +1792,26 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| const Operand* r = op->outputs[0]; | |||
| std::string input_name = std::string("v_") + sanitize_identifier(r->name); | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| if (type_is_integer(r->type)) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size()) | |||
| fprintf(pyfp, ", "); | |||
| fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size() || r->shape.size() == 1) | |||
| fprintf(pyfp, ", "); | |||
| } | |||
| fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d, ", r->shape[i]); | |||
| } | |||
| fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| fprintf(pyfp, ")\n"); | |||
| input_names.push_back(input_name); | |||
| } | |||
| @@ -2018,15 +2066,26 @@ int Graph::ncnn(const std::string& parampath, const std::string& binpath, const | |||
| if (!r) | |||
| break; | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| if (type_is_integer(r->type)) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size()) | |||
| fprintf(pyfp, ", "); | |||
| fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d", r->shape[i]); | |||
| if (i + 1 != r->shape.size() || r->shape.size() == 1) | |||
| fprintf(pyfp, ", "); | |||
| } | |||
| fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); | |||
| for (size_t i = 0; i < r->shape.size(); i++) | |||
| { | |||
| fprintf(pyfp, "%d, ", r->shape[i]); | |||
| } | |||
| fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); | |||
| } | |||
| fprintf(pyfp, ")\n"); | |||
| } | |||
| fprintf(pyfp, " out = []\n"); | |||
| @@ -40,86 +40,112 @@ static std::string get_basename(const std::string& path) | |||
| return path.substr(0, path.find_last_of('.')); | |||
| } | |||
| static std::vector<std::string> parse_comma_string_array_list(char* s) | |||
| static void parse_string_list(char* s, std::vector<std::string>& list) | |||
| { | |||
| std::vector<std::string> as; | |||
| list.clear(); | |||
| char* pch = strtok(s, ","); | |||
| while (pch != NULL) | |||
| { | |||
| as.push_back(std::string(pch)); | |||
| list.push_back(std::string(pch)); | |||
| pch = strtok(NULL, ","); | |||
| } | |||
| } | |||
| return as; | |||
| static void print_string_list(const std::vector<std::string>& list) | |||
| { | |||
| for (size_t i = 0; i < list.size(); i++) | |||
| { | |||
| fprintf(stderr, "%s", list[i].c_str()); | |||
| if (i + 1 != list.size()) | |||
| fprintf(stderr, ","); | |||
| } | |||
| } | |||
| static std::vector<std::vector<int64_t> > parse_comma_int_array_list(char* s) | |||
| static void parse_shape_list(char* s, std::vector<std::vector<int64_t> >& shapes, std::vector<std::string>& types) | |||
| { | |||
| std::vector<std::vector<int64_t> > aai; | |||
| shapes.clear(); | |||
| types.clear(); | |||
| char* pch = strtok(s, "[]"); | |||
| while (pch != NULL) | |||
| { | |||
| // assign user data type | |||
| if (!types.empty() && (pch[0] == 'f' || pch[0] == 'i' || pch[0] == 'u')) | |||
| { | |||
| char type[32]; | |||
| int nscan = sscanf(pch, "%31[^,]", type); | |||
| if (nscan == 1) | |||
| { | |||
| types[types.size() - 1] = std::string(type); | |||
| } | |||
| } | |||
| // parse a,b,c | |||
| int v; | |||
| int nconsumed = 0; | |||
| int nscan = sscanf(pch, "%d%n", &v, &nconsumed); | |||
| if (nscan == 1) | |||
| { | |||
| // ok we get array | |||
| // ok we get shape | |||
| pch += nconsumed; | |||
| std::vector<int64_t> ai; | |||
| ai.push_back(v); | |||
| std::vector<int64_t> s; | |||
| s.push_back(v); | |||
| nscan = sscanf(pch, ",%d%n", &v, &nconsumed); | |||
| while (nscan == 1) | |||
| { | |||
| pch += nconsumed; | |||
| ai.push_back(v); | |||
| s.push_back(v); | |||
| nscan = sscanf(pch, ",%d%n", &v, &nconsumed); | |||
| } | |||
| // array end | |||
| aai.push_back(ai); | |||
| // shape end | |||
| shapes.push_back(s); | |||
| types.push_back("f32"); | |||
| } | |||
| pch = strtok(NULL, "[]"); | |||
| } | |||
| return aai; | |||
| } | |||
| static void print_int64_array_list(const std::vector<std::vector<int64_t> >& list) | |||
| static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types) | |||
| { | |||
| for (size_t i = 0; i < list.size(); i++) | |||
| for (size_t i = 0; i < shapes.size(); i++) | |||
| { | |||
| const std::vector<int64_t>& array = list[i]; | |||
| const std::vector<int64_t>& s = shapes[i]; | |||
| const std::string& t = types[i]; | |||
| fprintf(stderr, "["); | |||
| for (size_t j = 0; j < array.size(); j++) | |||
| for (size_t j = 0; j < s.size(); j++) | |||
| { | |||
| fprintf(stderr, "%ld", array[j]); | |||
| if (j != array.size() - 1) | |||
| fprintf(stderr, "%ld", s[j]); | |||
| if (j != s.size() - 1) | |||
| fprintf(stderr, ","); | |||
| } | |||
| fprintf(stderr, "]"); | |||
| if (i != list.size() - 1) | |||
| fprintf(stderr, "%s", t.c_str()); | |||
| if (i != shapes.size() - 1) | |||
| fprintf(stderr, ","); | |||
| } | |||
| } | |||
| static void print_string_list(const std::vector<std::string>& list) | |||
| static c10::ScalarType input_type_to_c10_ScalarType(const std::string& t) | |||
| { | |||
| for (size_t i = 0; i < list.size(); i++) | |||
| { | |||
| fprintf(stderr, "%s", list[i].c_str()); | |||
| if (i + 1 != list.size()) | |||
| fprintf(stderr, ","); | |||
| } | |||
| if (t == "f32") return torch::kFloat32; | |||
| if (t == "f16") return torch::kFloat16; | |||
| if (t == "f64") return torch::kFloat64; | |||
| if (t == "i32") return torch::kInt32; | |||
| if (t == "i16") return torch::kInt16; | |||
| if (t == "i64") return torch::kInt64; | |||
| if (t == "i8") return torch::kInt8; | |||
| if (t == "u8") return torch::kUInt8; | |||
| fprintf(stderr, "unsupported type %s fallback to f32\n", t.c_str()); | |||
| return torch::kFloat32; | |||
| } | |||
| static void show_usage() | |||
| @@ -142,7 +168,7 @@ static void show_usage() | |||
| #endif | |||
| fprintf(stderr, " moduleop=models.common.Focus,models.yolo.Detect,...\n"); | |||
| fprintf(stderr, "Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]\n"); | |||
| fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect\n"); | |||
| fprintf(stderr, " pnnx yolov5s.pt inputshape=[1,3,640,640]f32 inputshape2=[1,3,320,320]f32 device=gpu moduleop=models.common.Focus,models.yolo.Detect\n"); | |||
| } | |||
| int main(int argc, char** argv) | |||
| @@ -175,7 +201,9 @@ int main(int argc, char** argv) | |||
| int optlevel = 2; | |||
| std::string device = "cpu"; | |||
| std::vector<std::vector<int64_t> > input_shapes; | |||
| std::vector<std::string> input_types; | |||
| std::vector<std::vector<int64_t> > input_shapes2; | |||
| std::vector<std::string> input_types2; | |||
| std::vector<std::string> customop_modules; | |||
| std::vector<std::string> module_operators; | |||
| @@ -213,13 +241,13 @@ int main(int argc, char** argv) | |||
| if (strcmp(key, "device") == 0) | |||
| device = value; | |||
| if (strcmp(key, "inputshape") == 0) | |||
| input_shapes = parse_comma_int_array_list(value); | |||
| parse_shape_list(value, input_shapes, input_types); | |||
| if (strcmp(key, "inputshape2") == 0) | |||
| input_shapes2 = parse_comma_int_array_list(value); | |||
| parse_shape_list(value, input_shapes2, input_types2); | |||
| if (strcmp(key, "customop") == 0) | |||
| customop_modules = parse_comma_string_array_list(value); | |||
| parse_string_list(value, customop_modules); | |||
| if (strcmp(key, "moduleop") == 0) | |||
| module_operators = parse_comma_string_array_list(value); | |||
| parse_string_list(value, module_operators); | |||
| } | |||
| // print options | |||
| @@ -233,10 +261,10 @@ int main(int argc, char** argv) | |||
| fprintf(stderr, "optlevel = %d\n", optlevel); | |||
| fprintf(stderr, "device = %s\n", device.c_str()); | |||
| fprintf(stderr, "inputshape = "); | |||
| print_int64_array_list(input_shapes); | |||
| print_shape_list(input_shapes, input_types); | |||
| fprintf(stderr, "\n"); | |||
| fprintf(stderr, "inputshape2 = "); | |||
| print_int64_array_list(input_shapes2); | |||
| print_shape_list(input_shapes2, input_types2); | |||
| fprintf(stderr, "\n"); | |||
| fprintf(stderr, "customop = "); | |||
| print_string_list(customop_modules); | |||
| @@ -268,9 +296,12 @@ int main(int argc, char** argv) | |||
| } | |||
| std::vector<at::Tensor> input_tensors; | |||
| for (auto shape : input_shapes) | |||
| for (size_t i = 0; i < input_shapes.size(); i++) | |||
| { | |||
| at::Tensor t = torch::ones(shape); | |||
| const std::vector<int64_t>& shape = input_shapes[i]; | |||
| const std::string& type = input_types[i]; | |||
| at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); | |||
| if (device == "gpu") | |||
| t = t.cuda(); | |||
| @@ -278,9 +309,12 @@ int main(int argc, char** argv) | |||
| } | |||
| std::vector<at::Tensor> input_tensors2; | |||
| for (auto shape : input_shapes2) | |||
| for (size_t i = 0; i < input_shapes2.size(); i++) | |||
| { | |||
| at::Tensor t = torch::ones(shape); | |||
| const std::vector<int64_t>& shape = input_shapes2[i]; | |||
| const std::string& type = input_types2[i]; | |||
| at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type)); | |||
| if (device == "gpu") | |||
| t = t.cuda(); | |||
| @@ -23,8 +23,8 @@ public: | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_1 0 1 theta | |||
| pnnx.Input input_2 0 1 size | |||
| pnnx.Input input_0 0 1 theta | |||
| pnnx.Input input_1 0 1 size | |||
| prim::Constant op_0 0 1 align_corners value=%align_corners | |||
| aten::affine_grid_generator op_1 3 1 theta size align_corners out | |||
| pnnx.Output output 1 0 out | |||
| @@ -0,0 +1,44 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class F_embedding : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 weight | |||
| prim::Constant op_0 0 1 padding_idx value=* | |||
| prim::Constant op_1 0 1 scale_grad_by_freq value=%scale_grad_by_freq | |||
| prim::Constant op_2 0 1 sparse value=%sparse | |||
| aten::embedding op_3 5 1 weight input padding_idx scale_grad_by_freq sparse out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "F.embedding"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding, 10) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,69 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class F_embedding : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input 0 1 input | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Embed"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "embed"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| } | |||
| op->params["0"] = weight.shape[1]; | |||
| op->params["1"] = weight.shape[0]; | |||
| op->params["2"] = 0; | |||
| op->params["3"] = (int)(weight.data.size() / sizeof(float)); | |||
| op->attrs["0"] = Attribute(); | |||
| op->attrs["0"].data = {0, 0, 0, 0}; | |||
| op->attrs["1"] = weight; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_embedding, 20) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -24,6 +24,7 @@ pnnx_add_test(F_conv_transpose1d) | |||
| pnnx_add_test(F_conv_transpose2d) | |||
| pnnx_add_test(F_conv_transpose3d) | |||
| pnnx_add_test(F_elu) | |||
| pnnx_add_test(F_embedding) | |||
| pnnx_add_test(F_gelu) | |||
| pnnx_add_test(F_grid_sample) | |||
| pnnx_add_test(F_group_norm) | |||
| @@ -91,6 +92,7 @@ pnnx_add_test(nn_ConvTranspose1d) | |||
| pnnx_add_test(nn_ConvTranspose2d) | |||
| pnnx_add_test(nn_ConvTranspose3d) | |||
| pnnx_add_test(nn_ELU) | |||
| pnnx_add_test(nn_Embedding) | |||
| pnnx_add_test(nn_GELU) | |||
| pnnx_add_test(nn_GroupNorm) | |||
| pnnx_add_test(nn_GRU) | |||
| @@ -22,6 +22,7 @@ pnnx_ncnn_add_test(F_conv1d) | |||
| pnnx_ncnn_add_test(F_conv2d) | |||
| pnnx_ncnn_add_test(F_conv3d) | |||
| pnnx_ncnn_add_test(F_elu) | |||
| pnnx_ncnn_add_test(F_embedding) | |||
| pnnx_ncnn_add_test(F_gelu) | |||
| pnnx_ncnn_add_test(F_group_norm) | |||
| pnnx_ncnn_add_test(F_hardsigmoid) | |||
| @@ -70,6 +71,7 @@ pnnx_ncnn_add_test(nn_Conv2d) | |||
| pnnx_ncnn_add_test(nn_Conv3d) | |||
| pnnx_ncnn_add_test(nn_ConvTranspose2d) | |||
| pnnx_ncnn_add_test(nn_ELU) | |||
| pnnx_ncnn_add_test(nn_Embedding) | |||
| pnnx_ncnn_add_test(nn_GELU) | |||
| pnnx_ncnn_add_test(nn_GroupNorm) | |||
| pnnx_ncnn_add_test(nn_GRU) | |||
| @@ -0,0 +1,56 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| self.w1 = nn.Parameter(torch.rand(10, 128)) | |||
| def forward(self, y): | |||
| y = F.embedding(y, self.w1) | |||
| return y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| y = torch.randint(10, (1, 11), dtype=torch.int) | |||
| a = net(y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (y)) | |||
| mod.save("test_F_embedding.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_F_embedding.pt inputshape=[1,11]i32") | |||
| # ncnn inference | |||
| import test_F_embedding_ncnn | |||
| b = test_F_embedding_ncnn.test_inference() | |||
| return torch.allclose(a, b, 1e-4, 1e-4) | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,56 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| self.embed_0 = nn.Embedding(embedding_dim=128, num_embeddings=10) | |||
| def forward(self, x): | |||
| x = self.embed_0(x) | |||
| return x | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.randint(10, (13,), dtype=torch.int) | |||
| a = net(x) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, x) | |||
| mod.save("test_nn_Embedding.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_nn_Embedding.pt inputshape=[13]i32") | |||
| # ncnn inference | |||
| import test_nn_Embedding_ncnn | |||
| b = test_nn_Embedding_ncnn.test_inference() | |||
| return torch.allclose(a, b, 1e-4, 1e-4) | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,59 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| self.w1 = nn.Parameter(torch.rand(10, 128)) | |||
| def forward(self, x, w0, y): | |||
| x = F.embedding(x, w0) | |||
| y = F.embedding(y, self.w1) | |||
| return x, y | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.randint(10, (1, 13), dtype=torch.int) | |||
| w0 = torch.rand(10, 128) | |||
| y = torch.randint(10, (1, 11), dtype=torch.int) | |||
| a0, a1 = net(x, w0, y) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, w0, y)) | |||
| mod.save("test_F_embedding.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_F_embedding.pt inputshape=[1,13]i32,[10,128],[1,11]i32") | |||
| # pnnx inference | |||
| import test_F_embedding_pnnx | |||
| b0, b1 = test_F_embedding_pnnx.test_inference() | |||
| return torch.equal(a0, b0) and torch.equal(a1, b1) | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,56 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| self.embed_0 = nn.Embedding(embedding_dim=128, num_embeddings=10) | |||
| def forward(self, x): | |||
| x = self.embed_0(x) | |||
| return x | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.randint(10, (1, 13), dtype=torch.int) | |||
| a = net(x) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, x) | |||
| mod.save("test_nn_Embedding.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_nn_Embedding.pt inputshape=[1,13]i32") | |||
| # pnnx inference | |||
| import test_nn_Embedding_pnnx | |||
| b = test_nn_Embedding_pnnx.test_inference() | |||
| return torch.equal(a, b) | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||