diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs index 94085605..5e257417 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -30,6 +30,15 @@ namespace Tensorflow.NumPy [AutoNumPy] public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); + [AutoNumPy] + public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis)); + + [AutoNumPy] + public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis)); + + [AutoNumPy] + public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis)); + [AutoNumPy] public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination)); } diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs index 36e44e48..9bc2fa76 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Collections.Generic; using Tensorflow.NumPy; using static Tensorflow.KerasApi; @@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers public class LayersMergingTest : EagerModeTestBase { [TestMethod] - public void Concatenate() + [DataRow(1, 4, 1, 5)] + [DataRow(2, 2, 2, 5)] + [DataRow(3, 2, 1, 10)] + public void Concatenate(int axis, int shapeA, int shapeB, int shapeC) { - var x = np.arange(20).reshape((2, 2, 5)); - var y = np.arange(20, 30).reshape((2, 1, 5)); - var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y)); - Assert.AreEqual((2, 3, 5), z.shape); + var x = np.arange(10).reshape((1, 2, 1, 5)); + var y = np.arange(10, 20).reshape((1, 2, 1, 5)); + var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y)); + Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape); } + } } diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs index cb570fc0..53a67cbf 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs @@ -1,10 +1,13 @@ using Microsoft.VisualStudio.TestPlatform.Utilities; using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json.Linq; using System.Linq; +using System.Xml.Linq; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.NumPy; +using static HDF.PInvoke.H5Z; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -124,4 +127,44 @@ public class ModelLoadTest var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; model.summary(); } + + + + [TestMethod] + public void CreateConcatenateModelSaveAndLoad() + { + // a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded. + var input_layer = tf.keras.layers.Input((8, 8, 5)); + + var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer); + conv1.Name = "conv1"; + + var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer); + conv2.Name = "conv2"; + + var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2)); + concat1.Name = "concat1"; + + var model = tf.keras.Model(input_layer, concat1); + model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); + + model.save(@"Assets/concat_axis3_model"); + + + var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT); + + var tensors1 = model.predict(tensorInput); + + Assert.AreEqual((1, 8, 8, 4), tensors1.shape); + + model = null; + keras.backend.clear_session(); + + var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model"); + + var tensors2 = model2.predict(tensorInput); + + Assert.AreEqual(tensors1.shape, tensors2.shape); + } + }