| @@ -99,13 +99,12 @@ namespace Tensorflow | |||||
| /// <param name="input"></param> | /// <param name="input"></param> | ||||
| /// <param name="axis"></param> | /// <param name="axis"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <param name="dim"></param> | |||||
| /// <returns> | /// <returns> | ||||
| /// A `Tensor` with the same data as `input`, but its shape has an additional | /// A `Tensor` with the same data as `input`, but its shape has an additional | ||||
| /// dimension of size 1 added. | /// dimension of size 1 added. | ||||
| /// </returns> | /// </returns> | ||||
| public Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) | |||||
| => array_ops.expand_dims(input, axis, name, dim); | |||||
| public Tensor expand_dims(Tensor input, int axis = -1, string name = null) | |||||
| => array_ops.expand_dims(input, axis, name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a tensor filled with a scalar value. | /// Creates a tensor filled with a scalar value. | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow.NumPy | |||||
| public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); | public static NDArray dstack(params NDArray[] tup) => throw new NotImplementedException(""); | ||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray expand_dims(NDArray a, Axis? axis = null) => throw new NotImplementedException(""); | |||||
| public static NDArray expand_dims(NDArray a, Axis? axis = null) => new NDArray(array_ops.expand_dims(a, axis: axis ?? -1)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape); | public static NDArray reshape(NDArray x1, Shape newshape) => x1.reshape(newshape); | ||||
| @@ -300,10 +300,7 @@ namespace Tensorflow | |||||
| return result; | return result; | ||||
| } | } | ||||
| public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) | |||||
| => expand_dims_v2(input, axis, name); | |||||
| private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) | |||||
| public static Tensor expand_dims(Tensor input, int axis = -1, string name = null) | |||||
| => gen_array_ops.expand_dims(input, axis, name); | => gen_array_ops.expand_dims(input, axis, name); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -0,0 +1,28 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow; | |||||
| using Tensorflow.NumPy; | |||||
| namespace TensorFlowNET.UnitTest.NumPy | |||||
| { | |||||
| /// <summary> | |||||
| /// https://numpy.org/doc/stable/reference/routines.array-manipulation.html | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class ManipulationTest : EagerModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void expand_dims() | |||||
| { | |||||
| var x = np.array(new[] { 1, 2 }); | |||||
| var y = np.expand_dims(x, axis: 0); | |||||
| Assert.AreEqual(y.shape, (1, 2)); | |||||
| y = np.expand_dims(x, axis: 1); | |||||
| Assert.AreEqual(y.shape, (2, 1)); | |||||
| } | |||||
| } | |||||
| } | |||||