You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

nn_Conv2d.cpp 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "pass_level1.h"
  15. // #include "../pass_level3/fuse_expression.h"
  16. #include "../utils.h"
  17. namespace pnnx {
  18. class Conv2d : public FuseModulePass
  19. {
  20. public:
  21. const char* match_type_str() const
  22. {
  23. return "__torch__.torch.nn.modules.conv.Conv2d";
  24. }
  25. const char* type_str() const
  26. {
  27. return "nn.Conv2d";
  28. }
  29. void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
  30. {
  31. // {
  32. // pnnx::Graph pnnx_graph;
  33. //
  34. // pnnx_graph.load(mod, graph);
  35. //
  36. // pnnx::fuse_expression(pnnx_graph);
  37. //
  38. // pnnx_graph.save("tmp.param", "tmp.bin");
  39. // }
  40. const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
  41. const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
  42. const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
  43. const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d");
  44. const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d");
  45. if (convolution_mode)
  46. {
  47. convolution = convolution_mode;
  48. }
  49. const auto& weight = mod.attr("weight").toTensor();
  50. op->params["groups"] = convolution->namedInput("groups");
  51. op->params["in_channels"] = weight.size(1) * op->params["groups"].i;
  52. op->params["out_channels"] = weight.size(0);
  53. op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3)};
  54. op->params["stride"] = convolution->namedInput("stride");
  55. if (pad)
  56. {
  57. op->params["padding_mode"] = pad->namedInput("mode");
  58. op->params["padding"] = pad->namedInput("pad");
  59. std::vector<int>& padding = op->params["padding"].ai;
  60. if (padding.size() == 4)
  61. {
  62. // Conv2d only accepts tuple of two integers
  63. if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3])
  64. {
  65. padding.resize(2);
  66. }
  67. else if (padding[0] == padding[2] && padding[1] == padding[3] && padding[0] != padding[1])
  68. {
  69. padding.resize(0);
  70. op->params["padding"].s = "same";
  71. }
  72. }
  73. }
  74. else if (reflection_pad2d)
  75. {
  76. op->params["padding_mode"] = "reflect";
  77. op->params["padding"] = reflection_pad2d->namedInput("padding");
  78. std::vector<int>& padding = op->params["padding"].ai;
  79. if (padding.size() == 4)
  80. {
  81. // Conv2d only accepts tuple of two integers
  82. if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3])
  83. {
  84. padding.resize(2);
  85. }
  86. else if (padding[0] == padding[2] && padding[1] == padding[3] && padding[0] != padding[1])
  87. {
  88. padding.resize(0);
  89. op->params["padding"].s = "same";
  90. }
  91. }
  92. }
  93. else if (replication_pad2d)
  94. {
  95. op->params["padding_mode"] = "replicate";
  96. op->params["padding"] = replication_pad2d->namedInput("padding");
  97. std::vector<int>& padding = op->params["padding"].ai;
  98. if (padding.size() == 4)
  99. {
  100. // Conv2d only accepts tuple of two integers
  101. if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3])
  102. {
  103. padding.resize(2);
  104. }
  105. else if (padding[0] == padding[2] && padding[1] == padding[3] && padding[0] != padding[1])
  106. {
  107. padding.resize(0);
  108. op->params["padding"].s = "same";
  109. }
  110. }
  111. }
  112. else
  113. {
  114. op->params["padding_mode"] = "zeros";
  115. op->params["padding"] = convolution->namedInput("padding");
  116. }
  117. op->params["dilation"] = convolution->namedInput("dilation");
  118. op->params["bias"] = mod.hasattr("bias");
  119. op->attrs["weight"] = weight;
  120. if (mod.hasattr("bias"))
  121. {
  122. op->attrs["bias"] = mod.attr("bias").toTensor();
  123. }
  124. }
  125. };
  126. REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(Conv2d)
  127. } // namespace pnnx