diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index be614294..1d2e55a7 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -99,13 +99,12 @@ namespace Tensorflow
///
///
///
- ///
///
/// A `Tensor` with the same data as `input`, but its shape has an additional
/// dimension of size 1 added.
///
- public Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
- => array_ops.expand_dims(input, axis, name, dim);
+ public Tensor expand_dims(Tensor input, int axis = -1, string name = null)
+ => array_ops.expand_dims(input, axis, name);
///
/// Creates a tensor filled with a scalar value.
diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
index 685b0e38..698e6fcc 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
@@ -15,7 +15,7 @@ namespace Tensorflow.NumPy
public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException("");
[AutoNumPy]
- public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException("");
+ public static NDArray expand_dims(NDArray a, Axis? axis = null) => new NDArray(array_ops.expand_dims(a, axis: axis ?? -1));
[AutoNumPy]
public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape);
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index e821dfb0..3dc8cf12 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -300,10 +300,7 @@ namespace Tensorflow
return result;
}
- public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
- => expand_dims_v2(input, axis, name);
-
- private static Tensor expand_dims_v2(Tensor input, int axis, string name = null)
+ public static Tensor expand_dims(Tensor input, int axis = -1, string name = null)
=> gen_array_ops.expand_dims(input, axis, name);
///
diff --git a/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs
new file mode 100644
index 00000000..a7437f66
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/NumPy/Manipulation.Test.cs
@@ -0,0 +1,28 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using Tensorflow.NumPy;
+
+namespace TensorFlowNET.UnitTest.NumPy
+{
+ ///
+ /// https://numpy.org/doc/stable/reference/routines.array-manipulation.html
+ ///
+ [TestClass]
+ public class ManipulationTest : EagerModeTestBase
+ {
+ [TestMethod]
+ public void expand_dims()
+ {
+ var x = np.array(new[] { 1, 2 });
+ var y = np.expand_dims(x, axis: 0);
+ Assert.AreEqual(y.shape, (1, 2));
+
+ y = np.expand_dims(x, axis: 1);
+ Assert.AreEqual(y.shape, (2, 1));
+ }
+ }
+}