diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index ecac37eb..4d9c3da5 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -162,14 +162,17 @@ namespace Tensorflow
/// Reverses specific dimensions of a tensor.
///
///
- ///
+ /// The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).
///
///
- public Tensor reverse(Tensor tensor, int[] axis, string name = null)
- => gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name);
-
- public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
- => gen_array_ops.reverse(tensor, axis, name: name);
+ public Tensor reverse(Tensor tensor, Axis axis, string name = null)
+ {
+ if (axis.IsScalar)
+ {
+ axis = new Axis(axis.axis);
+ }
+ return array_ops.reverse(tensor, axis, name: name);
+ }
///
/// Returns the rank of a tensor.
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 9d4647fa..f80dcd2c 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -413,6 +413,16 @@ namespace Tensorflow
return gen_array_ops.reshape(tensor, dims, name: name);
}
+ public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
+ => tf.Context.ExecuteOp("ReverseV2", name, new ExecuteOpArgs(tensor, axis)
+ {
+ GetGradientAttrs = (op) => new
+ {
+ T = op.get_attr("T"),
+ Tidx = op.get_attr("Tidx")
+ }
+ });
+
private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
@@ -658,11 +668,9 @@ namespace Tensorflow
}
});
- public static Tensor tile(Tensor input, object[] multiples, string name = null)
+ /*public static Tensor tile(Tensor input, Shape multiples, string name = null)
{
- Shape dims = shape_utils.from_object_array(multiples);
-
- return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims)
+ return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)
{
GetGradientAttrs = (op) => new
{
@@ -670,7 +678,7 @@ namespace Tensorflow
Tmultiples = op.get_attr("Tmultiples")
}
});
- }
+ }*/
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
index 72f598e4..675689bb 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
@@ -2,6 +2,7 @@
using Tensorflow.NumPy;
using Tensorflow;
using static Tensorflow.Binding;
+using System.Linq;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
@@ -92,5 +93,17 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreEqual(ta.read(1).numpy(), 20f);
Assert.AreEqual(ta.read(2).numpy(), 30f);
}
+
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/reverse
+ ///
+ [TestMethod]
+ public void ReverseArray()
+ {
+ var a = tf.random.normal((2, 3));
+ var b = tf.reverse(a, -1);
+ Assert.IsTrue(Equal(a[0].ToArray().Reverse().ToArray(), b[0].ToArray()));
+ Assert.IsTrue(Equal(a[1].ToArray().Reverse().ToArray(), b[1].ToArray()));
+ }
}
}