You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Cropping3D.cs 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. using Tensorflow.Keras.ArgsDefinition.Reshaping;
  2. using Tensorflow.Keras.Engine;
  3. using Tensorflow.Keras.Saving;
  4. using Tensorflow.Common.Types;
  5. namespace Tensorflow.Keras.Layers.Reshaping
  6. {
  /// <summary>
  /// Cropping layer for rank-5 inputs, similar to Cropping2D but acting on
  /// three axes. The cropped axes are 1..3 when
  /// <c>args.data_format</c> is <c>channels_last</c> and 2..4 otherwise.
  /// </summary>
  /// <remarks>
  /// <c>args.cropping</c> is interpreted by its shape:
  /// (1) one amount removed from both ends of all three axes;
  /// (3) one symmetric amount per axis;
  /// (3, 2) an explicit (begin, end) pair per axis.
  /// </remarks>
  public class Cropping3D : Layer
  {
      Cropping3DArgs args;

      /// <summary>
      /// Creates the layer and keeps the typed arguments for later use.
      /// </summary>
      /// <param name="args">Cropping configuration (cropping amounts and data format).</param>
      public Cropping3D(Cropping3DArgs args) : base(args)
      {
          this.args = args;
      }

      /// <summary>
      /// Marks the layer as built and records the input shape; the layer creates no weights here.
      /// </summary>
      /// <param name="input_shape">Shape wrapper supplied when the layer is built.</param>
      public override void build(KerasShapesWrapper input_shape)
      {
          built = true;
          _buildInputShape = input_shape;
      }

      /// <summary>
      /// Slices the three cropped axes of the input according to <c>args.cropping</c>.
      /// </summary>
      /// <param name="inputs">Input tensor; must have rank 5.</param>
      /// <param name="state">Not used by this layer.</param>
      /// <param name="training">Not used by this layer.</param>
      /// <param name="optional_args">Not used by this layer.</param>
      /// <returns>The cropped tensor.</returns>
      /// <exception cref="ValueError">
      /// The input is not rank 5, or <c>args.cropping</c> does not have shape (1), (3) or (3, 2).
      /// </exception>
      protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
      {
          Tensor output = inputs;
          if (output.rank != 5)
          {
              throw new ValueError("Expected dim=5, found dim=" + output.rank);
          }

          // Throws for an unsupported cropping shape instead of silently returning
          // the input unchanged, matching ComputeOutputShape.
          var crops = ResolveCropping();
          int first = FirstSpatialAxis();

          // Each slice keeps [begin, size - end) along its axis.
          var slice1 = new Slice(crops[0].Begin, (int)output.shape[first] - crops[0].End);
          var slice2 = new Slice(crops[1].Begin, (int)output.shape[first + 1] - crops[1].End);
          var slice3 = new Slice(crops[2].Begin, (int)output.shape[first + 2] - crops[2].End);

          if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
          {
              output = output[new Slice(), slice1, slice2, slice3, new Slice()];
          }
          else
          {
              output = output[new Slice(), new Slice(), slice1, slice2, slice3];
          }

          return output;
      }

      /// <summary>
      /// Computes the shape produced by <see cref="Call"/> for a given input shape.
      /// </summary>
      /// <param name="input_shape">Input shape; indices 0..4 are read.</param>
      /// <returns>
      /// The input shape with each cropped axis reduced by its begin and end amounts.
      /// </returns>
      /// <exception cref="ValueError">
      /// <c>args.cropping</c> does not have shape (1), (3) or (3, 2).
      /// </exception>
      public override Shape ComputeOutputShape(Shape input_shape)
      {
          var crops = ResolveCropping();
          int first = FirstSpatialAxis();

          var dims = new int[5];
          for (int i = 0; i < dims.Length; i++)
          {
              dims[i] = (int)input_shape[i];
          }

          for (int axis = 0; axis < crops.Length; axis++)
          {
              dims[first + axis] -= crops[axis].Begin + crops[axis].End;
          }

          return new Shape(dims[0], dims[1], dims[2], dims[3], dims[4]);
      }

      /// <summary>
      /// Index of the first cropped axis: 1 for channels_last, 2 otherwise.
      /// </summary>
      private int FirstSpatialAxis()
      {
          return args.data_format == Cropping3DArgs.DataFormat.channels_last ? 1 : 2;
      }

      /// <summary>
      /// Expands <c>args.cropping</c> into one (begin, end) pair per cropped axis.
      /// </summary>
      /// <returns>Exactly three pairs, in axis order.</returns>
      /// <exception cref="ValueError">
      /// <c>args.cropping</c> does not have shape (1), (3) or (3, 2).
      /// </exception>
      private (int Begin, int End)[] ResolveCropping()
      {
          // Rank and dimensions are compared explicitly so that Call and
          // ComputeOutputShape accept exactly the same set of cropping shapes.
          var shape = args.cropping.shape;

          if (shape.ndim == 1 && shape[0] == 1)
          {
              int crop = args.cropping[0];
              return new (int Begin, int End)[] { (crop, crop), (crop, crop), (crop, crop) };
          }

          if (shape.ndim == 1 && shape[0] == 3)
          {
              int crop1 = args.cropping[0];
              int crop2 = args.cropping[1];
              int crop3 = args.cropping[2];
              return new (int Begin, int End)[] { (crop1, crop1), (crop2, crop2), (crop3, crop3) };
          }

          if (shape.ndim == 2 && shape[0] == 3 && shape[1] == 2)
          {
              int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
              int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
              int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
              return new (int Begin, int End)[] { (x, x_end), (y, y_end), (z, z_end) };
          }

          throw new ValueError($"Unsupported cropping shape {shape}; expected shape (1), (3) or (3, 2).");
      }
  }
  151. }