| @@ -162,14 +162,17 @@ namespace Tensorflow | |||||
| /// Reverses specific dimensions of a tensor. | /// Reverses specific dimensions of a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| /// <param name="axis"></param> | |||||
| /// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param> | |||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
| => gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); | |||||
| public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
| => gen_array_ops.reverse(tensor, axis, name: name); | |||||
| public Tensor reverse(Tensor tensor, Axis axis, string name = null) | |||||
| { | |||||
| if (axis.IsScalar) | |||||
| { | |||||
| axis = new Axis(axis.axis); | |||||
| } | |||||
| return array_ops.reverse(tensor, axis, name: name); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the rank of a tensor. | /// Returns the rank of a tensor. | ||||
| @@ -413,6 +413,16 @@ namespace Tensorflow | |||||
| return gen_array_ops.reshape(tensor, dims, name: name); | return gen_array_ops.reshape(tensor, dims, name: name); | ||||
| } | } | ||||
| public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
| => tf.Context.ExecuteOp("ReverseV2", name, new ExecuteOpArgs(tensor, axis) | |||||
| { | |||||
| GetGradientAttrs = (op) => new | |||||
| { | |||||
| T = op.get_attr<TF_DataType>("T"), | |||||
| Tidx = op.get_attr<TF_DataType>("Tidx") | |||||
| } | |||||
| }); | |||||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => | return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => | ||||
| @@ -658,11 +668,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| }); | }); | ||||
| public static Tensor tile(Tensor input, object[] multiples, string name = null) | |||||
| /*public static Tensor tile(Tensor input, Shape multiples, string name = null) | |||||
| { | { | ||||
| Shape dims = shape_utils.from_object_array(multiples); | |||||
| return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims) | |||||
| return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples) | |||||
| { | { | ||||
| GetGradientAttrs = (op) => new | GetGradientAttrs = (op) => new | ||||
| { | { | ||||
| @@ -670,7 +678,7 @@ namespace Tensorflow | |||||
| Tmultiples = op.get_attr<TF_DataType>("Tmultiples") | Tmultiples = op.get_attr<TF_DataType>("Tmultiples") | ||||
| } | } | ||||
| }); | }); | ||||
| } | |||||
| }*/ | |||||
| public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
| { | { | ||||
| @@ -2,6 +2,7 @@ | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Linq; | |||||
| namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
| { | { | ||||
| @@ -92,5 +93,17 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
| Assert.AreEqual(ta.read(1).numpy(), 20f); | Assert.AreEqual(ta.read(1).numpy(), 20f); | ||||
| Assert.AreEqual(ta.read(2).numpy(), 30f); | Assert.AreEqual(ta.read(2).numpy(), 30f); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/reverse | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void ReverseArray() | |||||
| { | |||||
| var a = tf.random.normal((2, 3)); | |||||
| var b = tf.reverse(a, -1); | |||||
| Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>())); | |||||
| Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>())); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||