diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index 5a06b136..be6cf12a 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -71,6 +71,18 @@ namespace Tensorflow return new Operation[] {}; } } + + internal static Tensor fix_image_flip_shape(Tensor image, Tensor result) + { + TensorShape image_shape = image.get_shape(); + if (image_shape == tensor_shape.unknown_shape()) + { + result.set_shape(new { null, null, null }); + } else { + result.set_shape(image_shape); + } + return result; + } public static Tensor random_flip_up_down(Tensor image, int seed = 0) => _random_flip(image: image, @@ -122,6 +134,32 @@ namespace Tensorflow } } + public static Tensor flip_left_right(Tensor image) + => _flip(image, 1, "flip_left_right"); + + public static Tensor flip_up_down(Tensor image) + => _flip(image, 1, "flip_up_down"); + + internal static Tensor _flip(Tensor image, int flip_index, string scope_name) + { + return tf_with(ops.name_scope(null, scope_name, new { image }), delegate + { + image = ops.convert_to_tensor(image, name: "image"); + image = _AssertAtLeast3DImage(image); + TensorShape shape = image.get_shape(); + if ( shape.NDims == 3 || shape.NDims == null ) + { + return fix_image_flip_shape(image, array_ops.reverse(image, new { flip_index })); + } else if ( shape.NDims == 4 ) + { + return array_ops.reverse(image, new { flip_index + 1 }); + } else + { + throw new ValueError("\'image\' must have either 3 or 4 dimensions."); + } + }); + } + public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, string name = null, bool expand_animations = true) {