Browse Source

pnnx fix some undefined dtype (#5382)

tags/20240410
nihui GitHub 2 years ago
parent
commit
02ba6766bd
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
14 changed files with 266 additions and 168 deletions
  1. +19
    -12
      tools/pnnx/src/pass_level2/Tensor_new_empty.cpp
  2. +19
    -12
      tools/pnnx/src/pass_level2/Tensor_new_ones.cpp
  3. +19
    -12
      tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp
  4. +19
    -12
      tools/pnnx/src/pass_level2/Tensor_to.cpp
  5. +19
    -12
      tools/pnnx/src/pass_level2/torch_empty.cpp
  6. +19
    -12
      tools/pnnx/src/pass_level2/torch_empty_like.cpp
  7. +19
    -12
      tools/pnnx/src/pass_level2/torch_full.cpp
  8. +19
    -12
      tools/pnnx/src/pass_level2/torch_full_like.cpp
  9. +19
    -12
      tools/pnnx/src/pass_level2/torch_ones.cpp
  10. +19
    -12
      tools/pnnx/src/pass_level2/torch_ones_like.cpp
  11. +19
    -12
      tools/pnnx/src/pass_level2/torch_randn.cpp
  12. +19
    -12
      tools/pnnx/src/pass_level2/torch_randn_like.cpp
  13. +19
    -12
      tools/pnnx/src/pass_level2/torch_zeros.cpp
  14. +19
    -12
      tools/pnnx/src/pass_level2/torch_zeros_like.cpp

+ 19
- 12
tools/pnnx/src/pass_level2/Tensor_new_empty.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/Tensor_new_ones.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/Tensor_to.cpp View File

@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}

op->params["copy"] = captured_params.at("copy");



+ 19
- 12
tools/pnnx/src/pass_level2/torch_empty.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_empty_like.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_full.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_full_like.cpp View File

@@ -42,18 +42,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_ones.cpp View File

@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_ones_like.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_randn.cpp View File

@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_randn_like.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_zeros.cpp View File

@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



+ 19
- 12
tools/pnnx/src/pass_level2/torch_zeros_like.cpp View File

@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};



Loading…
Cancel
Save