diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs index 698e6fcc..091509fd 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -25,5 +25,8 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); + + [AutoNumPy] + public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination)); } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 71ae89bd..638559e6 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -792,6 +792,26 @@ namespace Tensorflow }); } + public static Tensor moveaxis(NDArray array, Axis source, Axis destination) + { + List perm = null; + source = source.axis.Select(x => x < 0 ? array.rank + x : x).ToArray(); + destination = destination.axis.Select(x => x < 0 ? array.rank + x : x).ToArray(); + + if (array.shape.rank > -1) + { + perm = range(0, array.rank).Where(i => !source.axis.Contains(i)).ToList(); + foreach (var (dest, src) in zip(destination.axis, source.axis).OrderBy(x => x.Item1)) + { + perm.Insert(dest, src); + } + } + else + throw new NotImplementedException(""); + + return array_ops.transpose(array, perm.ToArray()); + } + /// /// Computes the shape of a broadcast given symbolic shapes. /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of diff --git a/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs index a7437f66..d9c04be6 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs @@ -24,5 +24,19 @@ namespace TensorFlowNET.UnitTest.NumPy y = np.expand_dims(x, axis: 1); Assert.AreEqual(y.shape, (2, 1)); } + + [TestMethod] + public void moveaxis() + { + var x = np.zeros((3, 4, 5)); + var y = np.moveaxis(x, 0, -1); + Assert.AreEqual(y.shape, (4, 5, 3)); + + y = np.moveaxis(x, (0, 1), (-1, -2)); + Assert.AreEqual(y.shape, (5, 4, 3)); + + y = np.moveaxis(x, (0, 1, 2), (-1, -2, -3)); + Assert.AreEqual(y.shape, (5, 4, 3)); + } } }