From 02ba6766bd2ecd08600736c033e7346142313466 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 20 Mar 2024 11:42:15 +0800 Subject: [PATCH] pnnx fix some undefined dtype (#5382) --- .../pnnx/src/pass_level2/Tensor_new_empty.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/Tensor_new_ones.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/Tensor_new_zeros.cpp | 31 ++++++++++++------- tools/pnnx/src/pass_level2/Tensor_to.cpp | 31 ++++++++++++------- tools/pnnx/src/pass_level2/torch_empty.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/torch_empty_like.cpp | 31 ++++++++++++------- tools/pnnx/src/pass_level2/torch_full.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/torch_full_like.cpp | 31 ++++++++++++------- tools/pnnx/src/pass_level2/torch_ones.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/torch_ones_like.cpp | 31 ++++++++++++------- tools/pnnx/src/pass_level2/torch_randn.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/torch_randn_like.cpp | 31 ++++++++++++------- tools/pnnx/src/pass_level2/torch_zeros.cpp | 31 ++++++++++++------- .../pnnx/src/pass_level2/torch_zeros_like.cpp | 31 ++++++++++++------- 14 files changed, 266 insertions(+), 168 deletions(-) diff --git a/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp b/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp index 6f0f0cb68..215b17e2c 100644 --- a/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp b/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp index 03df3028b..3fe4c3390 100644 --- a/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp b/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp index c3aa069dc..93963f2a2 100644 --- a/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/Tensor_to.cpp b/tools/pnnx/src/pass_level2/Tensor_to.cpp index 8ab1f1249..6d7cd9e7d 100644 --- a/tools/pnnx/src/pass_level2/Tensor_to.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_to.cpp @@ -40,18 +40,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } op->params["copy"] = captured_params.at("copy"); diff --git a/tools/pnnx/src/pass_level2/torch_empty.cpp b/tools/pnnx/src/pass_level2/torch_empty.cpp index 92244e2e4..3c6a074cb 100644 --- a/tools/pnnx/src/pass_level2/torch_empty.cpp +++ b/tools/pnnx/src/pass_level2/torch_empty.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_empty_like.cpp b/tools/pnnx/src/pass_level2/torch_empty_like.cpp index 13c145c96..baa2f74c0 100644 --- a/tools/pnnx/src/pass_level2/torch_empty_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_empty_like.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_full.cpp b/tools/pnnx/src/pass_level2/torch_full.cpp index 293fad2e9..718a0796a 100644 --- a/tools/pnnx/src/pass_level2/torch_full.cpp +++ b/tools/pnnx/src/pass_level2/torch_full.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_full_like.cpp b/tools/pnnx/src/pass_level2/torch_full_like.cpp index 67f2a6f58..4d58df9c7 100644 --- a/tools/pnnx/src/pass_level2/torch_full_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_full_like.cpp @@ -42,18 +42,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_ones.cpp b/tools/pnnx/src/pass_level2/torch_ones.cpp index d055b3466..888397a97 100644 --- a/tools/pnnx/src/pass_level2/torch_ones.cpp +++ b/tools/pnnx/src/pass_level2/torch_ones.cpp @@ -40,18 +40,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_ones_like.cpp b/tools/pnnx/src/pass_level2/torch_ones_like.cpp index 312ea8ed9..8837b0fdd 100644 --- a/tools/pnnx/src/pass_level2/torch_ones_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_ones_like.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_randn.cpp b/tools/pnnx/src/pass_level2/torch_randn.cpp index 5cbfc33fe..345c4e495 100644 --- a/tools/pnnx/src/pass_level2/torch_randn.cpp +++ b/tools/pnnx/src/pass_level2/torch_randn.cpp @@ -40,18 +40,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_randn_like.cpp b/tools/pnnx/src/pass_level2/torch_randn_like.cpp index da74dec04..68c1dc9dc 100644 --- a/tools/pnnx/src/pass_level2/torch_randn_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_randn_like.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_zeros.cpp b/tools/pnnx/src/pass_level2/torch_zeros.cpp index 8b53d1652..90213fdde 100644 --- a/tools/pnnx/src/pass_level2/torch_zeros.cpp +++ b/tools/pnnx/src/pass_level2/torch_zeros.cpp @@ -40,18 +40,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } }; diff --git a/tools/pnnx/src/pass_level2/torch_zeros_like.cpp b/tools/pnnx/src/pass_level2/torch_zeros_like.cpp index 85a0bd224..5babbbb55 100644 --- a/tools/pnnx/src/pass_level2/torch_zeros_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_zeros_like.cpp @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; - if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; - if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; - if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; - if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; - if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; - if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; - if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; - if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; - if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; - if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; - if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + if (captured_params.at("dtype").type == 0) + { + op->params["dtype"] = Parameter(); + } + else // if (captured_params.at("dtype").type == 2) + { + if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; + if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; + if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; + if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; + if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; + if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; + if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; + if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; + if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; + if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; + if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; + if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; + } } };