using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Reshaping
{
    /// <summary>
    /// Cropping layer for 4D input: crops the two spatial axes.
    /// With <c>channels_last</c> data the cropped axes are 1 and 2; with
    /// <c>channels_first</c> they are 2 and 3.
    /// For example (channels_last):
    /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5)
    /// </summary>
    /// <remarks>
    /// <c>args.cropping</c> may be:
    /// a scalar or a 1-element array (same amount on every side of both axes),
    /// a 2-element array (symmetric amount per spatial axis), or
    /// a 2x2 array ((first_start, first_end), (second_start, second_end)).
    /// Any other shape raises <see cref="ValueError"/>.
    /// </remarks>
    public class Cropping2D : Layer
    {
        Cropping2DArgs args;

        public Cropping2D(Cropping2DArgs args) : base(args)
        {
            this.args = args;
        }

        public override void build(KerasShapesWrapper input_shape)
        {
            built = true;
            _buildInputShape = input_shape;
        }

        /// <summary>
        /// Crops the spatial axes of a rank-4 input.
        /// </summary>
        /// <exception cref="ValueError">
        /// The input is not rank 4, or <c>args.cropping</c> has an unsupported shape.
        /// </exception>
        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
        {
            Tensor output = inputs;
            if (output.rank != 4)
            {
                // Only 4D inputs are supported, so raise a ValueError for anything else.
                throw new ValueError("Expected dim=4, found dim=" + output.rank);
            }

            var (firstStart, firstEnd, secondStart, secondEnd) = ResolveCropping();
            var first = CropSlice(firstStart, firstEnd);
            var second = CropSlice(secondStart, secondEnd);

            if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
            {
                return output[new Slice(), first, second, new Slice()];
            }
            return output[new Slice(), new Slice(), first, second];
        }

        /// <summary>
        /// Computes the output shape by subtracting the crop amounts from the two
        /// spatial dimensions. Unknown (negative) dimensions stay unknown (-1).
        /// </summary>
        /// <exception cref="ValueError"><c>args.cropping</c> has an unsupported shape.</exception>
        public override Shape ComputeOutputShape(Shape input_shape)
        {
            var (firstStart, firstEnd, secondStart, secondEnd) = ResolveCropping();
            if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
            {
                return new Shape(
                    input_shape[0],
                    CroppedDim(input_shape[1], firstStart, firstEnd),
                    CroppedDim(input_shape[2], secondStart, secondEnd),
                    input_shape[3]);
            }
            return new Shape(
                input_shape[0],
                input_shape[1],
                CroppedDim(input_shape[2], firstStart, firstEnd),
                CroppedDim(input_shape[3], secondStart, secondEnd));
        }

        /// <summary>
        /// Normalizes <c>args.cropping</c> into explicit (start, end) crop amounts for
        /// the first and second spatial axes. Both <see cref="Call"/> and
        /// <see cref="ComputeOutputShape"/> use it, so they accept exactly the same inputs.
        /// </summary>
        /// <exception cref="ValueError">The cropping array is not a scalar or of shape (1), (2) or (2, 2).</exception>
        private (int firstStart, int firstEnd, int secondStart, int secondEnd) ResolveCropping()
        {
            var cropping = args.cropping;
            var shape = cropping.shape;

            if (shape.ndim == 0)
            {
                // A single integer: same crop on every side.
                int crop = cropping;
                return (crop, crop, crop, crop);
            }
            if (shape.ndim == 1 && shape[0] == 1)
            {
                int crop = cropping[0];
                return (crop, crop, crop, crop);
            }
            if (shape.ndim == 1 && shape[0] == 2)
            {
                // A tuple of 2 integers: symmetric crop per spatial axis.
                int crop1 = cropping[0];
                int crop2 = cropping[1];
                return (crop1, crop1, crop2, crop2);
            }
            if (shape.ndim == 2 && shape[0] == 2 && shape[1] == 2)
            {
                // ((first_start, first_end), (second_start, second_end)).
                int firstStart = cropping[0, 0];
                int firstEnd = cropping[0, 1];
                int secondStart = cropping[1, 0];
                int secondEnd = cropping[1, 1];
                return (firstStart, firstEnd, secondStart, secondEnd);
            }

            throw new ValueError($"Cropping2D expects `cropping` to be a scalar or to have shape (1), (2) or (2, 2), found shape {shape}.");
        }

        /// <summary>
        /// Builds a slice that drops <paramref name="start"/> leading and
        /// <paramref name="end"/> trailing elements. It uses a negative (or open) stop
        /// so the result does not depend on the static axis size, which can be
        /// unknown for symbolic tensors.
        /// </summary>
        private static Slice CropSlice(int start, int end)
            => new Slice(start, end == 0 ? (int?)null : -end);

        /// <summary>
        /// Size of a dimension after cropping. Unknown (negative) sizes stay -1.
        /// </summary>
        private static long CroppedDim(long dim, int start, int end)
            => dim < 0 ? -1 : dim - start - end;
    }
}