Browse Source

fix tf.transpose #673

tags/keras_v0.3.0
Oceania2018 5 years ago
parent
commit
0ff31a35ec
4 changed files with 77 additions and 16 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +16
    -1
      src/TensorFlowNET.Core/Operations/array_ops.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  4. +59
    -13
      test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -152,7 +152,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <param name="conjugate"></param>
/// <returns></returns>
public Tensor transpose<T1>(T1 a, int[] perm = null, string name = "transpose", bool conjugate = false)
public Tensor transpose<T1>(T1 a, TensorShape perm = null, string name = "transpose", bool conjugate = false)
=> array_ops.transpose(a, perm, name, conjugate);

/// <summary>


+ 16
- 1
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -779,7 +779,22 @@ namespace Tensorflow
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
}

public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_tensor = ops.convert_to_tensor(a);
if(perm == null)
{
var rank = a_tensor.rank;
perm = range(0, rank).OrderByDescending(x => x).ToArray();
}

return gen_array_ops.transpose(a_tensor, perm, name: scope);
});
}

public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -531,7 +531,7 @@ namespace Tensorflow
input, multiples).FirstOrDefault(),
input);

public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null)
public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
{
if (tf.Context.executing_eagerly())
{


+ 59
- 13
test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs View File

@@ -8,27 +8,73 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestClass]
public class TensorOperate
{
[TestMethod, Ignore]
[TestMethod]
public void TransposeTest()
{
// https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2
var x = tf.constant(new int[,] {
var x = tf.constant(new int[,]
{
{ 1, 2, 3 },
{ 4, 5, 6 }
});
var transpose_x = tf.transpose(x);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 4 }, transpose_x[0].numpy().ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 5 }, transpose_x[1].numpy().ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 6 }, transpose_x[2].numpy().ToArray<int>()));
Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy());
Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy());
Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy());

#region constant a
var a = tf.constant(np.array(new[, , ,]
{
{
{
{ 1, 11, 2, 22 }
},
{
{ 3, 33, 4, 44 }
}
},
{
{
{ 5, 55, 6, 66 }
},
{
{ 7, 77, 8, 88 }
}
}
}));

#endregion
var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 });

var a = tf.constant(np.array(new[, , ,] { { { { 1, 11, 2, 22 } }, { { 3, 33, 4, 44 } } },
{ { { 5, 55, 6, 66 } }, { { 7, 77, 8, 88 } } } }));
var b = tf.transpose(a, new[] { 3, 1, 2, 0 });
var transpose_a = tf.constant(np.array(new[, , ,] { { { { 1, 5 } }, { { 3, 7 } } },
{ { { 11, 55 } }, { { 33, 77 } } }, { { { 2, 6 } }, { { 4, 8 } } },
{ { { 22, 66 } }, { { 44, 88 } } } }));
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 4, 2, 1, 2 }, b.shape));
Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>()));
#region constant transpose_a
var expected_transposed_a = tf.constant(np.array(new[, , ,]
{
{
{ { 1, 5 } }, { { 3, 7 } }
},
{
{ { 11, 55 } }, { { 33, 77 } }
},
{
{
{ 2, 6 }
},
{
{ 4, 8 }
}
},
{
{
{ 22, 66 }
},
{
{ 44, 88 }
}
}
}));
#endregion
Assert.AreEqual((4, 2, 1, 2 ), actual_transposed_a.TensorShape);
Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy());
}

[TestMethod]


Loading…
Cancel
Save