You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

reshape_pack4.comp 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #version 450
  15. layout (constant_id = 0) const int ndim = 0;
  16. layout (local_size_x_id = 233) in;
  17. layout (local_size_y_id = 234) in;
  18. layout (local_size_z_id = 235) in;
  19. layout (binding = 0) readonly buffer bottom_blob { vec4 bottom_blob_data[]; };
  20. layout (binding = 1) writeonly buffer top_blob { vec4 top_blob_data[]; };
  21. layout (push_constant) uniform parameter
  22. {
  23. int dims;
  24. int w;
  25. int h;
  26. int c;
  27. int cstep;
  28. int outdims;
  29. int outw;
  30. int outh;
  31. int outc;
  32. int outcstep;
  33. } p;
  34. void main()
  35. {
  36. int gx = int(gl_GlobalInvocationID.x);
  37. int gy = int(gl_GlobalInvocationID.y);
  38. int gz = int(gl_GlobalInvocationID.z);
  39. if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
  40. return;
  41. ivec4 i4;
  42. if (ndim == 1) i4 = gx * 4 + ivec4(0, 1, 2, 3);
  43. if (ndim == 2) i4 = (gy * 4) * p.outw + gx + ivec4(0, 1, 2, 3) * p.outw;
  44. if (ndim == 3) i4 = (gz * 4) * p.outh * p.outw + gy * p.outw + gx + ivec4(0, 1, 2, 3) * p.outh * p.outw;
  45. ivec4 v_offset;
  46. ivec4 lane4;
  47. if (p.dims == 1)
  48. {
  49. v_offset = i4 / 4;
  50. lane4 = i4 % 4;
  51. }
  52. else if (p.dims == 2)
  53. {
  54. ivec4 y4 = i4 / p.w;
  55. ivec4 x4 = i4 % p.w;
  56. v_offset = (y4 / 4) * p.w + x4;
  57. lane4 = y4 % 4;
  58. }
  59. else // if (p.dims == 3)
  60. {
  61. int size = p.w * p.h;
  62. ivec4 z4 = i4 / size;
  63. ivec4 y4 = i4 % size / p.w;
  64. ivec4 x4 = i4 % size % p.w;
  65. v_offset = (z4 / 4) * p.cstep + y4 * p.w + x4;
  66. lane4 = z4 % 4;
  67. }
  68. vec4 v;
  69. vec4 v4;
  70. v4 = bottom_blob_data[v_offset.r];
  71. if (lane4.r == 0) v.r = v4.r;
  72. else if (lane4.r == 1) v.r = v4.g;
  73. else if (lane4.r == 2) v.r = v4.b;
  74. else /* if (lane4.r == 3) */ v.r = v4.a;
  75. v4 = bottom_blob_data[v_offset.g];
  76. if (lane4.g == 0) v.g = v4.r;
  77. else if (lane4.g == 1) v.g = v4.g;
  78. else if (lane4.g == 2) v.g = v4.b;
  79. else /* if (lane4.g == 3) */ v.g = v4.a;
  80. v4 = bottom_blob_data[v_offset.b];
  81. if (lane4.b == 0) v.b = v4.r;
  82. else if (lane4.b == 1) v.b = v4.g;
  83. else if (lane4.b == 2) v.b = v4.b;
  84. else /* if (lane4.b == 3) */ v.b = v4.a;
  85. v4 = bottom_blob_data[v_offset.a];
  86. if (lane4.a == 0) v.a = v4.r;
  87. else if (lane4.a == 1) v.a = v4.g;
  88. else if (lane4.a == 2) v.a = v4.b;
  89. else /* if (lane4.a == 3) */ v.a = v4.a;
  90. if (ndim == 1) top_blob_data[gx] = v;
  91. if (ndim == 2) top_blob_data[gy * p.outw + gx] = v;
  92. if (ndim == 3) top_blob_data[gz * p.outcstep + gy * p.outw + gx] = v;
  93. }