Browse Source

Fix tf.reverse.

tags/v0.110.4-Transformer-Model
Haiping Chen 2 years ago
parent
commit
03472997e4
3 changed files with 35 additions and 11 deletions
  1. +9
    -6
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +13
    -5
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +13
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

+ 9
- 6
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -162,14 +162,17 @@ namespace Tensorflow
/// Reverses specific dimensions of a tensor. /// Reverses specific dimensions of a tensor.
/// </summary> /// </summary>
/// <param name="tensor"></param> /// <param name="tensor"></param>
/// <param name="axis"></param>
/// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
public Tensor reverse(Tensor tensor, int[] axis, string name = null)
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name);

public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
=> gen_array_ops.reverse(tensor, axis, name: name);
public Tensor reverse(Tensor tensor, Axis axis, string name = null)
{
if (axis.IsScalar)
{
axis = new Axis(axis.axis);
}
return array_ops.reverse(tensor, axis, name: name);
}


/// <summary> /// <summary>
/// Returns the rank of a tensor. /// Returns the rank of a tensor.


+ 13
- 5
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -413,6 +413,16 @@ namespace Tensorflow
return gen_array_ops.reshape(tensor, dims, name: name); return gen_array_ops.reshape(tensor, dims, name: name);
} }


public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
=> tf.Context.ExecuteOp("ReverseV2", name, new ExecuteOpArgs(tensor, axis)
{
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
Tidx = op.get_attr<TF_DataType>("Tidx")
}
});

private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
{ {
return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope =>
@@ -658,11 +668,9 @@ namespace Tensorflow
} }
}); });


public static Tensor tile(Tensor input, object[] multiples, string name = null)
/*public static Tensor tile(Tensor input, Shape multiples, string name = null)
{ {
Shape dims = shape_utils.from_object_array(multiples);

return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims)
return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)
{ {
GetGradientAttrs = (op) => new GetGradientAttrs = (op) => new
{ {
@@ -670,7 +678,7 @@ namespace Tensorflow
Tmultiples = op.get_attr<TF_DataType>("Tmultiples") Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
} }
}); });
}
}*/


public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{ {


+ 13
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.NumPy; using Tensorflow.NumPy;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Linq;


namespace TensorFlowNET.UnitTest.ManagedAPI namespace TensorFlowNET.UnitTest.ManagedAPI
{ {
@@ -92,5 +93,17 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreEqual(ta.read(1).numpy(), 20f); Assert.AreEqual(ta.read(1).numpy(), 20f);
Assert.AreEqual(ta.read(2).numpy(), 30f); Assert.AreEqual(ta.read(2).numpy(), 30f);
} }

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/reverse
/// </summary>
[TestMethod]
public void ReverseArray()
{
var a = tf.random.normal((2, 3));
var b = tf.reverse(a, -1);
Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>()));
Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>()));
}
} }
} }

Loading…
Cancel
Save