You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Cropping2D.cs 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. using Tensorflow.Keras.ArgsDefinition.Reshaping;
  2. using Tensorflow.Keras.Engine;
  namespace Tensorflow.Keras.Layers.Reshaping
  {
      /// <summary>
      /// Crops the two spatial dimensions of a rank-4 input.
      /// <para> With <c>channels_last</c> the cropped axes are 1 and 2; with <c>channels_first</c> they are 2 and 3. </para>
      /// <para> <c>cropping</c> may have shape (1) (same amount removed from every edge),
      /// shape (2) (<c>(rows, cols)</c>, removed symmetrically from both ends of each spatial axis),
      /// or shape (2, 2) (<c>((top, bottom), (left, right))</c>). </para>
      /// <para> For example: </para>
      /// <para> shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) </para>
      /// </summary>
      public class Cropping2D : Layer
      {
          readonly Cropping2DArgs args;

          public Cropping2D(Cropping2DArgs args) : base(args)
          {
              this.args = args;
          }

          public override void build(Shape input_shape)
          {
              built = true;
              _buildInputShape = input_shape;
          }

          /// <summary>
          /// Crops the spatial axes of <paramref name="inputs"/> according to <c>args.cropping</c>
          /// and <c>args.data_format</c>.
          /// </summary>
          /// <exception cref="ValueError">
          /// The input is not rank 4, or <c>args.cropping</c> has an unsupported shape or a negative value.
          /// </exception>
          protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
          {
              Tensor output = inputs;
              if (output.rank != 4)
              {
                  throw new ValueError("Expected dim=4, found dim=" + output.rank);
              }

              var (top, bottom, left, right) = GetCropping();

              // Slice ends are given relative to the end of each axis, so the result does not
              // depend on the static axis size (which may be unknown when building a graph).
              Slice rows = CropSlice(top, bottom);
              Slice cols = CropSlice(left, right);

              return args.data_format == Cropping2DArgs.DataFormat.channels_last
                  ? output[new Slice(), rows, cols, new Slice()]
                  : output[new Slice(), new Slice(), rows, cols];
          }

          /// <summary>
          /// Returns the shape produced by <see cref="Call"/> for an input of shape
          /// <paramref name="input_shape"/>. Spatial dims that are unknown (negative) stay unknown (-1).
          /// </summary>
          /// <exception cref="ValueError">
          /// <c>args.cropping</c> has an unsupported shape or a negative value.
          /// </exception>
          public override Shape ComputeOutputShape(Shape input_shape)
          {
              var (top, bottom, left, right) = GetCropping();
              if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
              {
                  return new Shape((int)input_shape[0],
                      CroppedDim(input_shape[1], top, bottom),
                      CroppedDim(input_shape[2], left, right),
                      (int)input_shape[3]);
              }
              return new Shape((int)input_shape[0],
                  (int)input_shape[1],
                  CroppedDim(input_shape[2], top, bottom),
                  CroppedDim(input_shape[3], left, right));
          }

          /// <summary>
          /// Normalizes <c>args.cropping</c> into explicit crop amounts for the first spatial axis
          /// (<c>Top</c>, <c>Bottom</c>) and the second spatial axis (<c>Left</c>, <c>Right</c>).
          /// Shared by <see cref="Call"/> and <see cref="ComputeOutputShape"/> so both accept and
          /// reject exactly the same configurations.
          /// </summary>
          /// <exception cref="ValueError">Unsupported cropping shape or a negative crop amount.</exception>
          private (int Top, int Bottom, int Left, int Right) GetCropping()
          {
              var cropping = args.cropping;
              int top, bottom, left, right;
              if (cropping.shape == new Shape(1))
              {
                  int crop = cropping[0];
                  top = crop; bottom = crop; left = crop; right = crop;
              }
              else if (cropping.shape == new Shape(2))
              {
                  // (rows, cols): symmetric crop on each spatial axis.
                  int rowCrop = cropping[0];
                  int colCrop = cropping[1];
                  top = rowCrop; bottom = rowCrop; left = colCrop; right = colCrop;
              }
              else if (cropping.shape == new Shape(2, 2))
              {
                  // ((top, bottom), (left, right))
                  top = cropping[0, 0];
                  bottom = cropping[0, 1];
                  left = cropping[1, 0];
                  right = cropping[1, 1];
              }
              else
              {
                  throw new ValueError("`cropping` must have shape (1), (2) or (2, 2), found " + cropping.shape);
              }

              if (top < 0 || bottom < 0 || left < 0 || right < 0)
              {
                  throw new ValueError("`cropping` values must be non-negative, found (" +
                      top + ", " + bottom + ", " + left + ", " + right + ")");
              }
              return (top, bottom, left, right);
          }

          /// <summary>
          /// Builds a slice that drops <paramref name="start"/> leading and <paramref name="end"/>
          /// trailing elements of an axis. When <paramref name="end"/> is 0 the slice runs to the
          /// end of the axis; otherwise a negative stop index is used.
          /// </summary>
          private static Slice CropSlice(int start, int end)
              => new Slice(start, end == 0 ? (int?)null : -end);

          /// <summary>
          /// Length of an axis of size <paramref name="dim"/> after removing
          /// <paramref name="start"/> + <paramref name="end"/> elements. A negative (unknown) size
          /// is reported as -1.
          /// </summary>
          private static int CroppedDim(long dim, int start, int end)
              => dim < 0 ? -1 : (int)dim - start - end;
      }
  }