diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
index ce7203de..3e6791e3 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
@@ -7,5 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public class MergeArgs : LayerArgs
{
public Tensors Inputs { get; set; }
+ public int Axis { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 68ad21c2..87f16380 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -407,7 +407,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
var ret = tensor.TensorShape.unknown_shape(shape.dims[0]);
var value = constant_value(tensor);
- if (value != null)
+ if (!(value is null))
{
int[] d_ = { };
foreach (int d in value)
@@ -418,7 +418,6 @@ would not be rank 1.", tensor.op.get_attr("axis")));
d_[d_.Length] = -1; // None
}
ret = ret.merge_with(new TensorShape(d_));
-
}
return ret;
}
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 39557173..a55791d2 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -226,5 +226,25 @@ namespace Tensorflow.Keras
x.set_shape(output_shape);
return x;
}
+
+ ///
+ /// Concatenates a list of tensors alongside the specified axis.
+ ///
+ /// list of tensors to concatenate.
+ /// concatenation axis.
+ ///
+ public Tensor concatenate(Tensors tensors, int axis = -1)
+ {
+ if(axis < 0)
+ {
+ var rank = tensors[0].NDims;
+ if (rank > -1)
+ axis %= rank;
+ else
+ axis = 0;
+ }
+
+ return array_ops.concat(tensors, axis);
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 958ef07e..22fba034 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -177,13 +177,13 @@ namespace Tensorflow.Keras.Engine
tf.init_scope();
tf.Context.eager_mode();
- build(inputs.shape);
+ build(inputs);
tf.Context.restore_mode();
built = true;
}
- protected virtual void build(TensorShape input_shape)
+ protected virtual void build(Tensors inputs)
{
built = true;
}
diff --git a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
index 18bd5c55..bbbe495c 100644
--- a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
@@ -52,8 +52,9 @@ namespace Tensorflow.Keras.Layers
axis = args.Axis.dims;
}
- protected override void build(TensorShape input_shape)
+ protected override void build(Tensors inputs)
{
+ TensorShape input_shape = inputs.shape;
var ndims = input_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
diff --git a/src/TensorFlowNET.Keras/Layers/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolutional.cs
index a7eb9aa6..7814f9c0 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolutional.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolutional.cs
@@ -56,8 +56,9 @@ namespace Tensorflow.Keras.Layers
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
}
- protected override void build(TensorShape input_shape)
+ protected override void build(Tensors inputs)
{
+ TensorShape input_shape = inputs.shape;
int channel_axis = data_format == "channels_first" ? 1 : -1;
int input_channel = channel_axis < 0 ?
input_shape.dims[input_shape.ndim + channel_axis] :
diff --git a/src/TensorFlowNET.Keras/Layers/Dense.cs b/src/TensorFlowNET.Keras/Layers/Dense.cs
index a01f3df7..7f992c5e 100644
--- a/src/TensorFlowNET.Keras/Layers/Dense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Dense.cs
@@ -41,8 +41,9 @@ namespace Tensorflow.Keras.Layers
this.inputSpec = new InputSpec(min_ndim: 2);
}
- protected override void build(TensorShape input_shape)
+ protected override void build(Tensors inputs)
{
+ TensorShape input_shape = inputs.shape;
var last_dim = input_shape.dims.Last();
var axes = new Dictionary();
axes[-1] = last_dim;
diff --git a/src/TensorFlowNET.Keras/Layers/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Embedding.cs
index 9962ff25..36bbd152 100644
--- a/src/TensorFlowNET.Keras/Layers/Embedding.cs
+++ b/src/TensorFlowNET.Keras/Layers/Embedding.cs
@@ -52,7 +52,7 @@ namespace Tensorflow.Keras.Layers
SupportsMasking = mask_zero;
}
- protected override void build(TensorShape input_shape)
+ protected override void build(Tensors inputs)
{
tf.Context.eager_mode();
embeddings = add_weight(shape: (input_dim, output_dim),
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
new file mode 100644
index 00000000..beaabd48
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
@@ -0,0 +1,22 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial class LayersApi
+ {
+ ///
+ /// Layer that concatenates a list of inputs.
+ ///
+ /// Axis along which to concatenate.
+ ///
+ public Concatenate Concatenate(int axis = -1)
+ => new Concatenate(new MergeArgs
+ {
+ Axis = axis
+ });
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merge.cs
index bfed03ad..c0fa3f36 100644
--- a/src/TensorFlowNET.Keras/Layers/Merge.cs
+++ b/src/TensorFlowNET.Keras/Layers/Merge.cs
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Layers
}
- protected override void build(TensorShape input_shape)
+ protected override void build(Tensors inputs)
{
// output_shape = input_shape.dims[1^];
}
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers
return _merge_function(inputs);
}
- Tensors _merge_function(Tensors inputs)
+ protected virtual Tensors _merge_function(Tensors inputs)
{
var output = inputs[0];
foreach (var i in range(1, inputs.Length))
diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
new file mode 100644
index 00000000..a4309949
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
@@ -0,0 +1,47 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Layers
+{
+ ///
+ /// Layer that concatenates a list of inputs.
+ ///
+ public class Concatenate : Merge
+ {
+ MergeArgs args;
+ int axis => args.Axis;
+
+ public Concatenate(MergeArgs args) : base(args)
+ {
+ this.args = args;
+ }
+
+ protected override void build(Tensors inputs)
+ {
+ /*var shape_set = new HashSet();
+ var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray();
+ for (var i = 0; i < reduced_inputs_shapes.Length; i++)
+ {
+ int seq = -1;
+ TensorShape shape = reduced_inputs_shapes[i].Where(x =>
+ {
+ seq++;
+ return seq != i;
+ }).ToArray();
+ shape_set.Add(shape);
+ }*/
+ }
+
+ protected override Tensors _merge_function(Tensors inputs)
+ {
+ return keras.backend.concatenate(inputs, axis: axis);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs b/test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs
new file mode 100644
index 00000000..5dad1390
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs
@@ -0,0 +1,20 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class LayersMergingTest : EagerModeTestBase
+ {
+ [TestMethod]
+ public void Concatenate()
+ {
+ var x = np.arange(20).reshape(2, 2, 5);
+ var y = np.arange(20, 30).reshape(2, 1, 5);
+ var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
+ Assert.AreEqual((2, 3, 5), z.shape);
+ }
+ }
+}