From 0ff31a35ec13229c712b79765ca19edca34d9754 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 13 Dec 2020 21:56:22 -0600 Subject: [PATCH] fix tf.transpose #673 --- src/TensorFlowNET.Core/APIs/tf.array.cs | 2 +- .../Operations/array_ops.cs | 17 ++++- .../Operations/gen_array_ops.cs | 2 +- .../ManagedAPI/TensorOperate.cs | 72 +++++++++++++++---- 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index a87e041e..76ff6e54 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -152,7 +152,7 @@ namespace Tensorflow /// /// /// - public Tensor transpose(T1 a, int[] perm = null, string name = "transpose", bool conjugate = false) + public Tensor transpose(T1 a, TensorShape perm = null, string name = "transpose", bool conjugate = false) => array_ops.transpose(a, perm, name, conjugate); /// diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 72439100..b0d741f1 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -779,7 +779,22 @@ namespace Tensorflow return gen_array_ops.gather_v2(@params, indices, axis, name: name); } - public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) + public static Tensor transpose(T1 a, TensorShape perm, string name = "transpose", bool conjugate = false) + { + return tf_with(ops.name_scope(name, "transpose", new { a }), scope => + { + var a_tensor = ops.convert_to_tensor(a); + if(perm == null) + { + var rank = a_tensor.rank; + perm = range(0, rank).OrderByDescending(x => x).ToArray(); + } + + return gen_array_ops.transpose(a_tensor, perm, name: scope); + }); + } + + public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose", bool conjugate = false) { return tf_with(ops.name_scope(name, "transpose", new { a }), scope => { diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 3087639b..50dde5c3 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -531,7 +531,7 @@ namespace Tensorflow input, multiples).FirstOrDefault(), input); - public static Tensor transpose(T1 x, T2 perm, string name = null) + public static Tensor transpose(Tensor x, T1 perm, string name = null) { if (tf.Context.executing_eagerly()) { diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs index 99becb7f..70ccab89 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs @@ -8,27 +8,73 @@ namespace TensorFlowNET.UnitTest.ManagedAPI [TestClass] public class TensorOperate { - [TestMethod, Ignore] + [TestMethod] public void TransposeTest() { // https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2 - var x = tf.constant(new int[,] { + var x = tf.constant(new int[,] + { { 1, 2, 3 }, { 4, 5, 6 } }); var transpose_x = tf.transpose(x); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 4 }, transpose_x[0].numpy().ToArray())); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 5 }, transpose_x[1].numpy().ToArray())); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 6 }, transpose_x[2].numpy().ToArray())); + Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy()); + Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy()); + Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy()); + + #region constant a + var a = tf.constant(np.array(new[, , ,] + { + { + { + { 1, 11, 2, 22 } + }, + { + { 3, 33, 4, 44 } + } + }, + { + { + { 5, 55, 6, 66 } + }, + { + { 7, 77, 8, 88 } + } + } + })); + + #endregion + var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 }); - var a = tf.constant(np.array(new[, , ,] { { { { 1, 11, 2, 22 } }, { { 3, 33, 4, 44 } } }, - { { { 5, 55, 6, 66 } }, { { 7, 77, 8, 88 } } } })); - var b = tf.transpose(a, new[] { 3, 1, 2, 0 }); - var transpose_a = tf.constant(np.array(new[, , ,] { { { { 1, 5 } }, { { 3, 7 } } }, - { { { 11, 55 } }, { { 33, 77 } } }, { { { 2, 6 } }, { { 4, 8 } } }, - { { { 22, 66 } }, { { 44, 88 } } } })); - Assert.IsTrue(Enumerable.SequenceEqual(new[] { 4, 2, 1, 2 }, b.shape)); - Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray(), b.numpy().ToArray())); + #region constant transpose_a + var expected_transposed_a = tf.constant(np.array(new[, , ,] + { + { + { { 1, 5 } }, { { 3, 7 } } + }, + { + { { 11, 55 } }, { { 33, 77 } } + }, + { + { + { 2, 6 } + }, + { + { 4, 8 } + } + }, + { + { + { 22, 66 } + }, + { + { 44, 88 } + } + } + })); + #endregion + Assert.AreEqual((4, 2, 1, 2 ), actual_transposed_a.TensorShape); + Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy()); } [TestMethod]