You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Tensor_slice.cpp 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include "pass_ncnn.h"
  15. #include <limits.h>
  16. namespace pnnx {
  17. namespace ncnn {
  18. class Tensor_slice : public GraphRewriterPass
  19. {
  20. public:
  21. const char* match_pattern_graph() const
  22. {
  23. return R"PNNXIR(7767517
  24. 3 2
  25. pnnx.Input input 0 1 input
  26. Tensor.slice op_0 1 1 input out dims=%dims starts=%starts ends=%ends steps=%steps
  27. pnnx.Output output 1 0 out
  28. )PNNXIR";
  29. }
  30. const char* type_str() const
  31. {
  32. return "Crop";
  33. }
  34. const char* name_str() const
  35. {
  36. return "slice";
  37. }
  38. void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
  39. {
  40. std::vector<int> axes = captured_params.at("dims").ai;
  41. const std::vector<int>& starts = captured_params.at("starts").ai;
  42. std::vector<int> ends = captured_params.at("ends").ai;
  43. const std::vector<int>& steps = captured_params.at("steps").ai;
  44. int axes_rank = axes.size();
  45. for (int i = 0; i < axes_rank; i++)
  46. {
  47. if (steps[i] != 1)
  48. {
  49. fprintf(stderr, "slice with step %d is not supported\n", steps[i]);
  50. return;
  51. }
  52. }
  53. const int batch_index = op->inputs[0]->params["__batch_index"].i;
  54. {
  55. int input_rank = op->inputs[0]->shape.size();
  56. if (batch_index >= 0 && batch_index < input_rank)
  57. input_rank -= 1;
  58. if (input_rank > 4)
  59. {
  60. fprintf(stderr, "slice %d-rank tensor with %d-rank axes is not possible!\n", input_rank, axes_rank);
  61. return;
  62. }
  63. }
  64. for (int i = 0; i < axes_rank; i++)
  65. {
  66. if (axes[i] == batch_index && (starts[i] != 0 || ends[i] != INT_MAX))
  67. {
  68. fprintf(stderr, "slice along batch axis is not supported\n");
  69. return;
  70. }
  71. if (axes[i] < 0)
  72. {
  73. int input_rank = op->inputs[0]->shape.size();
  74. axes[i] = input_rank + axes[i];
  75. }
  76. if (axes[i] > batch_index)
  77. axes[i] -= 1;
  78. if (ends[i] == INT_MAX)
  79. ends[i] = -233;
  80. }
  81. op->params["9"] = starts;
  82. op->params["10"] = ends;
  83. op->params["11"] = axes;
  84. }
  85. };
  // Registers Tensor_slice with the pnnx->ncnn graph rewriter machinery.
  // NOTE(review): the second argument (20) is presumably the pass priority/order — confirm against the macro definition.
  86. REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_slice, 20)
  87. } // namespace ncnn
  88. } // namespace pnnx