You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

torch_transpose.cpp 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "pass_ncnn.h"
  15. namespace pnnx {
  16. namespace ncnn {
  17. class torch_transpose : public GraphRewriterPass
  18. {
  19. public:
  20. const char* match_pattern_graph() const
  21. {
  22. return R"PNNXIR(7767517
  23. 3 2
  24. pnnx.Input input 0 1 input
  25. torch.transpose op_0 1 1 input out dim0=%dim0 dim1=%dim1
  26. pnnx.Output output 1 0 out
  27. )PNNXIR";
  28. }
  29. const char* type_str() const
  30. {
  31. return "Permute";
  32. }
  33. const char* name_str() const
  34. {
  35. return "transpose";
  36. }
  37. void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
  38. {
  39. op->params["0"] = 0;
  40. const int batch_index = op->inputs[0]->params["__batch_index"].i;
  41. int dim0 = captured_params.at("dim0").i;
  42. int dim1 = captured_params.at("dim1").i;
  43. int input_rank = op->inputs[0]->shape.size();
  44. if (dim0 < 0)
  45. {
  46. dim0 = input_rank + dim0;
  47. }
  48. if (dim1 < 0)
  49. {
  50. dim1 = input_rank + dim1;
  51. }
  52. if (dim0 == batch_index || dim1 == batch_index)
  53. {
  54. fprintf(stderr, "permute across batch dim is not supported yet!\n");
  55. return;
  56. }
  57. if (batch_index >= 0 && batch_index < input_rank)
  58. input_rank -= 1;
  59. if (input_rank > 4)
  60. {
  61. fprintf(stderr, "permute %d-rank tensor is not supported yet!\n", input_rank);
  62. return;
  63. }
  64. if (dim0 > batch_index)
  65. dim0 -= 1;
  66. if (dim1 > batch_index)
  67. dim1 -= 1;
  68. if (input_rank == 1)
  69. {
  70. // noop
  71. op->type = "Noop";
  72. }
  73. if (input_rank == 2)
  74. {
  75. if (dim0 == 0 && dim1 == 1) op->params["0"] = 1;
  76. if (dim0 == 1 && dim1 == 0) op->params["0"] = 1;
  77. }
  78. if (input_rank == 3)
  79. {
  80. if (dim0 == 0 && dim1 == 1) op->params["0"] = 2;
  81. if (dim0 == 1 && dim1 == 0) op->params["0"] = 2;
  82. if (dim0 == 0 && dim1 == 2) op->params["0"] = 5;
  83. if (dim0 == 2 && dim1 == 0) op->params["0"] = 5;
  84. if (dim0 == 1 && dim1 == 2) op->params["0"] = 1;
  85. if (dim0 == 2 && dim1 == 1) op->params["0"] = 1;
  86. }
  87. if (input_rank == 4)
  88. {
  89. if (dim0 == 0 && dim1 == 1) op->params["0"] = 6;
  90. if (dim0 == 1 && dim1 == 0) op->params["0"] = 6;
  91. if (dim0 == 0 && dim1 == 2) op->params["0"] = 14;
  92. if (dim0 == 2 && dim1 == 0) op->params["0"] = 14;
  93. if (dim0 == 0 && dim1 == 3) op->params["0"] = 21;
  94. if (dim0 == 3 && dim1 == 0) op->params["0"] = 21;
  95. if (dim0 == 1 && dim1 == 2) op->params["0"] = 2;
  96. if (dim0 == 2 && dim1 == 1) op->params["0"] = 2;
  97. if (dim0 == 1 && dim1 == 3) op->params["0"] = 5;
  98. if (dim0 == 3 && dim1 == 1) op->params["0"] = 5;
  99. if (dim0 == 2 && dim1 == 3) op->params["0"] = 1;
  100. if (dim0 == 3 && dim1 == 2) op->params["0"] = 1;
  101. }
  102. }
  103. };
  // Invokes the graph-rewriter-pass registration macro for torch_transpose.
  // NOTE(review): the second argument (20) is presumably the pass priority — confirm against the macro definition.
  104. REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_transpose, 20)
  105. } // namespace ncnn
  106. } // namespace pnnx