You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Cropping3D.cs 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. using Tensorflow.Keras.ArgsDefinition.Reshaping;
  2. using Tensorflow.Keras.Engine;
  3. using Tensorflow.Keras.Saving;
  4. using Tensorflow.Common.Types;
  namespace Tensorflow.Keras.Layers.Reshaping
  {
      /// <summary>
      /// Cropping layer for rank-5 inputs; the 3D counterpart of Cropping2D.
      /// Trims elements from the start and end of each of the three dimensions that lie
      /// between the batch axis and the channel axis (axes 1-3 for <c>channels_last</c>,
      /// axes 2-4 otherwise).
      /// </summary>
      public class Cropping3D : Layer
      {
          /// <summary>Number of dimensions that get cropped.</summary>
          private const int CroppedDims = 3;

          /// <summary>Rank required of the input: batch + three cropped dims + channels.</summary>
          private const int ExpectedRank = 5;

          Cropping3DArgs args;

          public Cropping3D(Cropping3DArgs args) : base(args)
          {
              this.args = args;
          }

          public override void build(KerasShapesWrapper input_shape)
          {
              built = true;
              _buildInputShape = input_shape;
          }

          /// <summary>
          /// Crops <paramref name="inputs"/> according to <c>args.cropping</c> and <c>args.data_format</c>.
          /// </summary>
          /// <exception cref="ValueError">
          /// The input is not rank 5, or <c>args.cropping</c> is not of shape (1), (3) or (3, 2).
          /// </exception>
          protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
          {
              Tensor output = inputs;
              if (output.rank != ExpectedRank)
              {
                  throw new ValueError("Expected dim=5, found dim=" + output.rank);
              }

              var crops = GetCropPairs();
              int firstAxis = FirstCroppedAxis();

              // Batch and channel axes are taken whole; only the three cropped axes get bounds.
              var slices = new Slice[ExpectedRank];
              for (int axis = 0; axis < ExpectedRank; axis++)
              {
                  slices[axis] = new Slice();
              }
              for (int i = 0; i < CroppedDims; i++)
              {
                  int axis = firstAxis + i;
                  slices[axis] = new Slice(crops[i].Start, (int)output.shape[axis] - crops[i].End);
              }
              return output[slices];
          }

          /// <summary>
          /// Computes the shape produced by <see cref="Call"/> for an input of shape <paramref name="input_shape"/>.
          /// </summary>
          /// <exception cref="ValueError">
          /// <paramref name="input_shape"/> is not rank 5, or <c>args.cropping</c> has an unsupported shape.
          /// </exception>
          public override Shape ComputeOutputShape(Shape input_shape)
          {
              if (input_shape.ndim != ExpectedRank)
              {
                  throw new ValueError("Expected dim=5, found dim=" + input_shape.ndim);
              }

              var crops = GetCropPairs();
              int firstAxis = FirstCroppedAxis();

              var dims = new long[ExpectedRank];
              for (int axis = 0; axis < ExpectedRank; axis++)
              {
                  dims[axis] = input_shape[axis];
              }
              for (int i = 0; i < CroppedDims; i++)
              {
                  int axis = firstAxis + i;
                  // Unknown dimensions are encoded as -1; keep them unknown instead of
                  // subtracting the crop amounts from the sentinel.
                  if (dims[axis] >= 0)
                  {
                      dims[axis] -= crops[i].Start + crops[i].End;
                  }
              }
              return new Shape(dims);
          }

          /// <summary>
          /// Index of the first cropped axis: 1 for <c>channels_last</c> (batch, d1, d2, d3, channels),
          /// otherwise 2 (batch, channels, d1, d2, d3).
          /// </summary>
          private int FirstCroppedAxis()
              => args.data_format == Cropping3DArgs.DataFormat.channels_last ? 1 : 2;

          /// <summary>
          /// Normalizes <c>args.cropping</c> into one (start, end) pair per cropped dimension.
          /// Accepted layouts:
          /// shape (1) - the same amount on both sides of every dimension;
          /// shape (3) - one symmetric amount per dimension;
          /// shape (3, 2) - explicit (start, end) amounts per dimension.
          /// </summary>
          /// <exception cref="ValueError">Any other cropping shape.</exception>
          private (int Start, int End)[] GetCropPairs()
          {
              var shape = args.cropping.shape;

              if (shape.ndim == 1 && shape[0] == 1)
              {
                  int crop = args.cropping[0];
                  return new[] { (crop, crop), (crop, crop), (crop, crop) };
              }

              if (shape.ndim == 1 && shape[0] == CroppedDims)
              {
                  int crop1 = args.cropping[0];
                  int crop2 = args.cropping[1];
                  int crop3 = args.cropping[2];
                  return new[] { (crop1, crop1), (crop2, crop2), (crop3, crop3) };
              }

              if (shape.ndim == 2 && shape[0] == CroppedDims && shape[1] == 2)
              {
                  var pairs = new (int Start, int End)[CroppedDims];
                  for (int i = 0; i < CroppedDims; i++)
                  {
                      int start = args.cropping[i, 0];
                      int end = args.cropping[i, 1];
                      pairs[i] = (start, end);
                  }
                  return pairs;
              }

              throw new ValueError("`cropping` must have shape (1), (3) or (3, 2), found " + shape);
          }
      }
  }