diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs index f534df6a..63899e04 100644 --- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs +++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs @@ -24,6 +24,65 @@ namespace Tensorflow { public class image_ops_impl { + internal static Tensor _AssertAtLeast3DImage(Tensor image) + => control_flow_ops.with_dependencies( + _CheckAtLeast3DImage(image, require_static: false), image); + + internal static Array _CheckAtLeast3DImage(Tensor image, bool require_static) + { + throw new NotImplementedException(""); + } + + public static Tensor random_flip_up_down(Tensor image, int seed = 0) + => _random_flip(image: image, + flip_index: 0, + seed: seed, + scope_name: "random_flip_up_down"); + + public static Tensor random_flip_left_right(Tensor image, int seed = 0) + => _random_flip(image: image, + flip_index: 1, + seed: seed, + scope_name: "random_flip_left_right"); + + internal static Tensor _random_flip(Tensor image, int flipindex, int seed, + string scope_name) + { + using ( scope = ops.name_scope(null, scope_name, image)) + { + image = ops.convert_to_tensor(image, name: "image"); + image = AssertAtLeast3DImage(image); + var shape = image.get_shape(); + if ( shape.NDims == 3 || shape.NDims == null ) + { + var uniform_random = random_ops.random_uniform(new Tensor [], 0, 1.0, seed: seed); + var mirror_cond = math_ops.less(uniform_random, .5); + var result = control_flow_ops.cond( + pred: mirror_cond, + true_fn: array_ops.reverse(image, flipindex as int[]), + false_fn: image, + name: scope + ); + return fix_image_flip_shape(image, result); + } else if ( shape.NDims == 4 ) + { + var batch_size = array_ops.shape(image)[0]; + var uniform_random = random_ops.random_uniform(batch_size, + 0, + 1.0, + seed: seed); + var flips = math_ops.round( + array_ops.reshape(uniform_random, shape: new Tensor [batch_size, 1, 1, 1])); + flips = math_ops.cast(flips, image.dtype); + var flipped_input = array_ops.reverse(image, flip_index + 1 as int[]); + return flips * flipped_input + (1 - flips) * image; + } else + { + throw new ValueError("'\'image\' must have either 3 or 4 dimensions."); + } + } + } + public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, string name = null, bool expand_animations = true) {