You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Cropping2D.cs 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. using Tensorflow.Keras.ArgsDefinition.Reshaping;
  2. using Tensorflow.Keras.Engine;
  3. using Tensorflow.Keras.Saving;
  4. using Tensorflow.Common.Types;
  namespace Tensorflow.Keras.Layers.Reshaping
  {
      /// <summary>
      /// Crops a rank-4 input along its two spatial axes: axes 1 and 2 when
      /// <c>data_format</c> is <c>channels_last</c>, axes 2 and 3 otherwise.
      /// <para><c>cropping</c> is interpreted according to its shape:</para>
      /// <list type="bullet">
      /// <item><description>(1): the same amount is removed from every spatial edge;</description></item>
      /// <item><description>(2): (symmetric height crop, symmetric width crop);</description></item>
      /// <item><description>(2, 2): ((top, bottom), (left, right)).</description></item>
      /// </list>
      /// <para> For example (channels_last): </para>
      /// <para> shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) </para>
      /// </summary>
      public class Cropping2D : Layer
      {
          readonly Cropping2DArgs args;

          public Cropping2D(Cropping2DArgs args) : base(args)
          {
              this.args = args;
          }

          /// <summary>
          /// The layer owns no weights; building only records the input shape and marks the layer built.
          /// </summary>
          public override void build(KerasShapesWrapper input_shape)
          {
              built = true;
              _buildInputShape = input_shape;
          }

          /// <summary>
          /// Removes the configured number of rows/columns from the spatial axes of <paramref name="inputs"/>.
          /// </summary>
          /// <exception cref="ValueError">
          /// The input is not rank 4, <c>cropping</c> has an unsupported shape or a negative entry,
          /// or the crop along a statically known spatial axis is not smaller than that axis.
          /// </exception>
          protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
          {
              Tensor output = inputs;
              if (output.rank != 4)
              {
                  throw new ValueError("Expected dim=4, found dim=" + output.rank);
              }

              var (top, bottom, left, right) = GetCropping();
              int h = HeightAxis;
              CheckFits(output.shape[h], top, bottom, h);
              CheckFits(output.shape[h + 1], left, right, h + 1);

              // The slices never read the axis length, so they stay correct when a spatial
              // dimension is unknown (-1) for a symbolic input.
              Slice rows = CropSlice(top, bottom);
              Slice cols = CropSlice(left, right);
              return h == 1
                  ? output[new Slice(), rows, cols, new Slice()]
                  : output[new Slice(), new Slice(), rows, cols];
          }

          /// <summary>
          /// Returns <paramref name="input_shape"/> with both spatial axes reduced by their crops.
          /// Unknown (-1) spatial dimensions stay unknown.
          /// </summary>
          /// <exception cref="ValueError">
          /// <c>cropping</c> has an unsupported shape or a negative entry, or the crop along a
          /// known spatial axis is not smaller than that axis.
          /// </exception>
          public override Shape ComputeOutputShape(Shape input_shape)
          {
              var (top, bottom, left, right) = GetCropping();
              int h = HeightAxis;
              CheckFits(input_shape[h], top, bottom, h);
              CheckFits(input_shape[h + 1], left, right, h + 1);

              int[] dims = new int[]
              {
                  (int)input_shape[0], (int)input_shape[1], (int)input_shape[2], (int)input_shape[3]
              };
              dims[h] = CroppedDim(input_shape[h], top, bottom);
              dims[h + 1] = CroppedDim(input_shape[h + 1], left, right);
              return new Shape(dims);
          }

          /// <summary>Index of the first (height) spatial axis: 1 for channels_last, 2 otherwise.</summary>
          int HeightAxis => args.data_format == Cropping2DArgs.DataFormat.channels_last ? 1 : 2;

          /// <summary>
          /// Normalizes <c>args.cropping</c> to explicit (top, bottom, left, right) crop amounts.
          /// Shared by <see cref="Call"/> and <see cref="ComputeOutputShape"/> so both accept and
          /// reject exactly the same configurations.
          /// </summary>
          /// <exception cref="ValueError">Unsupported <c>cropping</c> shape or a negative crop amount.</exception>
          (int top, int bottom, int left, int right) GetCropping()
          {
              int top, bottom, left, right;
              if (args.cropping.shape == new Shape(1))
              {
                  int crop = args.cropping[0];
                  top = crop;
                  bottom = crop;
                  left = crop;
                  right = crop;
              }
              // a tuple of 2 integers: symmetric height crop, symmetric width crop
              else if (args.cropping.shape == new Shape(2))
              {
                  int crop_1 = args.cropping[0];
                  int crop_2 = args.cropping[1];
                  top = crop_1;
                  bottom = crop_1;
                  left = crop_2;
                  right = crop_2;
              }
              // a 2x2 matrix: ((top, bottom), (left, right))
              else if (args.cropping.shape == new Shape(2, 2))
              {
                  top = args.cropping[0, 0];
                  bottom = args.cropping[0, 1];
                  left = args.cropping[1, 0];
                  right = args.cropping[1, 1];
              }
              else
              {
                  throw new ValueError("`cropping` must have shape (1), (2) or (2, 2), found " + args.cropping.shape);
              }

              if (top < 0 || bottom < 0 || left < 0 || right < 0)
              {
                  throw new ValueError($"`cropping` values must be non-negative, found ({top}, {bottom}, {left}, {right})");
              }
              return (top, bottom, left, right);
          }

          /// <summary>
          /// Builds a slice that drops <paramref name="start"/> leading and <paramref name="end"/> trailing
          /// elements. A negative (from-the-end) stop is used, or no stop at all when <paramref name="end"/>
          /// is 0, so the slice does not depend on the axis length.
          /// </summary>
          static Slice CropSlice(int start, int end)
              => new Slice(start, end == 0 ? (int?)null : -end);

          /// <summary>
          /// Size of an axis of length <paramref name="dim"/> after cropping; -1 (unknown) is preserved.
          /// </summary>
          static int CroppedDim(long dim, int start, int end)
              => dim < 0 ? -1 : (int)dim - start - end;

          /// <summary>
          /// Rejects crops that would consume an entire statically known axis.
          /// Unknown (-1) dimensions cannot be checked here and are skipped.
          /// </summary>
          /// <exception cref="ValueError">The crop is not smaller than the known axis length.</exception>
          static void CheckFits(long dim, int start, int end, int axis)
          {
              if (dim >= 0 && start + end >= dim)
              {
                  throw new ValueError($"Cropping ({start}, {end}) must be smaller than the size {dim} of axis {axis}");
              }
          }
      }
  }