You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Cropping2D.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. using Tensorflow.Keras.ArgsDefinition.Reshaping;
  2. using Tensorflow.Keras.Engine;
  3. using Tensorflow.Keras.Saving;
  4. namespace Tensorflow.Keras.Layers.Reshaping
  5. {
  6. /// <summary>
  7. /// Crop the input along axis 1 and 2.
  8. /// <para> For example: </para>
  9. /// <para> shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) </para>
  10. /// </summary>
  11. public class Cropping2D : Layer
  12. {
  13. Cropping2DArgs args;
  14. public Cropping2D(Cropping2DArgs args) : base(args)
  15. {
  16. this.args = args;
  17. }
  18. public override void build(KerasShapesWrapper input_shape)
  19. {
  20. built = true;
  21. _buildInputShape = input_shape;
  22. }
  23. protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
  24. {
  25. Tensor output = inputs;
  26. if (output.rank != 4)
  27. {
  28. // throw an ValueError exception
  29. throw new ValueError("Expected dim=4, found dim=" + output.rank);
  30. }
  31. if (args.cropping.shape == new Shape(1))
  32. {
  33. int crop = args.cropping[0];
  34. if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
  35. {
  36. output = output[new Slice(),
  37. new Slice(crop, (int)output.shape[1] - crop),
  38. new Slice(crop, (int)output.shape[2] - crop),
  39. new Slice()];
  40. }
  41. else
  42. {
  43. output = output[new Slice(),
  44. new Slice(),
  45. new Slice(crop, (int)output.shape[2] - crop),
  46. new Slice(crop, (int)output.shape[3] - crop)];
  47. }
  48. }
  49. // a tuple of 2 integers
  50. else if (args.cropping.shape == new Shape(2))
  51. {
  52. int crop_1 = args.cropping[0];
  53. int crop_2 = args.cropping[1];
  54. if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
  55. {
  56. output = output[new Slice(),
  57. new Slice(crop_1, (int)output.shape[1] - crop_1),
  58. new Slice(crop_2, (int)output.shape[2] - crop_2),
  59. new Slice()];
  60. }
  61. else
  62. {
  63. output = output[new Slice(),
  64. new Slice(),
  65. new Slice(crop_1, (int)output.shape[2] - crop_1),
  66. new Slice(crop_2, (int)output.shape[3] - crop_2)];
  67. }
  68. }
  69. else if (args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2)
  70. {
  71. int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1];
  72. int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1];
  73. if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
  74. {
  75. output = output[new Slice(),
  76. new Slice(x_start, (int)output.shape[1] - x_end),
  77. new Slice(y_start, (int)output.shape[2] - y_end),
  78. new Slice()];
  79. }
  80. else
  81. {
  82. output = output[new Slice(),
  83. new Slice(),
  84. new Slice(x_start, (int)output.shape[2] - x_end),
  85. new Slice(y_start, (int)output.shape[3] - y_end)
  86. ];
  87. }
  88. }
  89. return output;
  90. }
  91. public override Shape ComputeOutputShape(Shape input_shape)
  92. {
  93. if (args.cropping.shape == new Shape(1))
  94. {
  95. int crop = args.cropping[0];
  96. if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
  97. {
  98. return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3]);
  99. }
  100. else
  101. {
  102. return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2);
  103. }
  104. }
  105. // a tuple of 2 integers
  106. else if (args.cropping.shape == new Shape(2))
  107. {
  108. int crop_1 = args.cropping[0], crop_2 = args.cropping[1];
  109. if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
  110. {
  111. return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1 * 2, (int)input_shape[2] - crop_2 * 2, (int)input_shape[3]);
  112. }
  113. else
  114. {
  115. return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_1 * 2, (int)input_shape[3] - crop_2 * 2);
  116. }
  117. }
  118. else if (args.cropping.shape == new Shape(2, 2))
  119. {
  120. int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1];
  121. int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1];
  122. if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
  123. {
  124. return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1_start - crop_1_end,
  125. (int)input_shape[2] - crop_2_start - crop_2_end, (int)input_shape[3]);
  126. }
  127. else
  128. {
  129. return new Shape((int)input_shape[0], (int)input_shape[1],
  130. (int)input_shape[2] - crop_1_start - crop_1_end, (int)input_shape[3] - crop_2_start - crop_2_end);
  131. }
  132. }
  133. else
  134. {
  135. throw new ValueError();
  136. }
  137. }
  138. }
  139. }