diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index ecac37eb..4d9c3da5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -162,14 +162,17 @@ namespace Tensorflow /// Reverses specific dimensions of a tensor. /// /// - /// + /// The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)). /// /// - public Tensor reverse(Tensor tensor, int[] axis, string name = null) - => gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); - - public Tensor reverse(Tensor tensor, Tensor axis, string name = null) - => gen_array_ops.reverse(tensor, axis, name: name); + public Tensor reverse(Tensor tensor, Axis axis, string name = null) + { + if (axis.IsScalar) + { + axis = new Axis(axis.axis); + } + return array_ops.reverse(tensor, axis, name: name); + } /// /// Returns the rank of a tensor. diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 9d4647fa..f80dcd2c 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -413,6 +413,16 @@ namespace Tensorflow return gen_array_ops.reshape(tensor, dims, name: name); } + public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) + => tf.Context.ExecuteOp("ReverseV2", name, new ExecuteOpArgs(tensor, axis) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Tidx = op.get_attr("Tidx") + } + }); + private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) { return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => @@ -658,11 +668,9 @@ namespace Tensorflow } }); - public static Tensor tile(Tensor input, object[] multiples, string name = null) + /*public static Tensor tile(Tensor input, Shape multiples, string name = null) { - Shape dims = shape_utils.from_object_array(multiples); - - return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims) + return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples) { GetGradientAttrs = (op) => new { @@ -670,7 +678,7 @@ namespace Tensorflow Tmultiples = op.get_attr("Tmultiples") } }); - } + }*/ public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) { diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs index 72f598e4..675689bb 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs @@ -2,6 +2,7 @@ using Tensorflow.NumPy; using Tensorflow; using static Tensorflow.Binding; +using System.Linq; namespace TensorFlowNET.UnitTest.ManagedAPI { @@ -92,5 +93,17 @@ namespace TensorFlowNET.UnitTest.ManagedAPI Assert.AreEqual(ta.read(1).numpy(), 20f); Assert.AreEqual(ta.read(2).numpy(), 30f); } + + /// + /// https://www.tensorflow.org/api_docs/python/tf/reverse + /// + [TestMethod] + public void ReverseArray() + { + var a = tf.random.normal((2, 3)); + var b = tf.reverse(a, -1); + Assert.IsTrue(Equal(a[0].ToArray().Reverse().ToArray(), b[0].ToArray())); + Assert.IsTrue(Equal(a[1].ToArray().Reverse().ToArray(), b[1].ToArray())); + } } }