diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
index ce7203de..3e6791e3 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
@@ -7,5 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition
     public class MergeArgs : LayerArgs
     {
         public Tensors Inputs { get; set; }
+        public int Axis { get; set; }
     }
 }
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 68ad21c2..87f16380 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -407,7 +407,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
 
             var ret = tensor.TensorShape.unknown_shape(shape.dims[0]);
             var value = constant_value(tensor);
-            if (value != null)
+            if (!(value is null))
             {
                 int[] d_ = { };
                 foreach (int d in value)
@@ -418,7 +418,6 @@ would not be rank 1.", tensor.op.get_attr("axis")));
                         d_[d_.Length] = -1; // None
                 }
                 ret = ret.merge_with(new TensorShape(d_));
-
             }
             return ret;
         }
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 39557173..a55791d2 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -226,5 +226,25 @@ namespace Tensorflow.Keras
             x.set_shape(output_shape);
             return x;
         }
+
+        /// <summary>
+        /// Concatenates a list of tensors alongside the specified axis.
+        /// </summary>
+        /// <param name="tensors">list of tensors to concatenate.</param>
+        /// <param name="axis">concatenation axis; a negative value is resolved against the rank of the first tensor.</param>
+        /// <returns>The result of <c>array_ops.concat</c> on <paramref name="tensors"/> along the resolved axis.</returns>
+        public Tensor concatenate(Tensors tensors, int axis = -1)
+        {
+            if (axis < 0)
+            {
+                var rank = tensors[0].NDims;
+                if (rank > 0) // a non-positive rank cannot serve as the modulus (% 0 throws DivideByZeroException)
+                    axis = (axis % rank + rank) % rank; // C# % keeps the dividend's sign; wrap into [0, rank)
+                else
+                    axis = 0;
+            }
+
+            return array_ops.concat(tensors, axis);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 958ef07e..22fba034 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -177,13 +177,13 @@ namespace Tensorflow.Keras.Engine
 
             tf.init_scope();
             tf.Context.eager_mode();
-            build(inputs.shape);
+            build(inputs);
             tf.Context.restore_mode();
 
             built = true;
         }
 
-        protected virtual void build(TensorShape input_shape)
+        protected virtual void build(Tensors inputs)
         {
             built = true;
         }
diff --git a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
index 18bd5c55..bbbe495c 100644
--- a/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
@@ -52,8 +52,9 @@ namespace Tensorflow.Keras.Layers
             axis = args.Axis.dims;
         }
 
-        protected override void build(TensorShape input_shape)
+        protected override void build(Tensors inputs)
         {
+            TensorShape input_shape = inputs.shape;
             var ndims = input_shape.ndim;
             foreach (var (idx, x) in enumerate(axis))
                 if (x < 0)
diff --git a/src/TensorFlowNET.Keras/Layers/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolutional.cs
index a7eb9aa6..7814f9c0 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolutional.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolutional.cs
@@ -56,8 +56,9 @@ namespace Tensorflow.Keras.Layers
             _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
         }
 
-        protected override void build(TensorShape input_shape)
+        protected override void build(Tensors inputs)
         {
+            TensorShape input_shape = inputs.shape;
             int channel_axis = data_format == "channels_first" ? 1 : -1;
             int input_channel = channel_axis < 0 ?
                 input_shape.dims[input_shape.ndim + channel_axis] :
diff --git a/src/TensorFlowNET.Keras/Layers/Dense.cs b/src/TensorFlowNET.Keras/Layers/Dense.cs
index a01f3df7..7f992c5e 100644
--- a/src/TensorFlowNET.Keras/Layers/Dense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Dense.cs
@@ -41,8 +41,9 @@ namespace Tensorflow.Keras.Layers
             this.inputSpec = new InputSpec(min_ndim: 2);
         }
 
-        protected override void build(TensorShape input_shape)
+        protected override void build(Tensors inputs)
         {
+            TensorShape input_shape = inputs.shape;
             var last_dim = input_shape.dims.Last();
             var axes = new Dictionary();
             axes[-1] = last_dim;
diff --git a/src/TensorFlowNET.Keras/Layers/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Embedding.cs
index 9962ff25..36bbd152 100644
--- a/src/TensorFlowNET.Keras/Layers/Embedding.cs
+++ b/src/TensorFlowNET.Keras/Layers/Embedding.cs
@@ -52,7 +52,7 @@ namespace Tensorflow.Keras.Layers
             SupportsMasking = mask_zero;
         }
 
-        protected override void build(TensorShape input_shape)
+        protected override void build(Tensors inputs)
         {
             tf.Context.eager_mode();
             embeddings = add_weight(shape: (input_dim, output_dim),
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
new file mode 100644
index 00000000..beaabd48
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
@@ -0,0 +1,22 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+    public partial class LayersApi
+    {
+        /// <summary>
+        /// Layer that concatenates a list of inputs.
+        /// </summary>
+        /// <param name="axis">Axis along which to concatenate.</param>
+        /// <returns>A new <see cref="Concatenate"/> layer whose <c>MergeArgs.Axis</c> is set to <paramref name="axis"/>.</returns>
+        public Concatenate Concatenate(int axis = -1)
+            => new Concatenate(new MergeArgs
+            {
+                Axis = axis
+            });
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merge.cs
index bfed03ad..c0fa3f36 100644
--- a/src/TensorFlowNET.Keras/Layers/Merge.cs
+++ b/src/TensorFlowNET.Keras/Layers/Merge.cs
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Layers
 
         }
 
-        protected override void build(TensorShape input_shape)
+        protected override void build(Tensors inputs)
         {
             // output_shape = input_shape.dims[1^];
         }
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers
             return _merge_function(inputs);
         }
 
-        Tensors _merge_function(Tensors inputs)
+        protected virtual Tensors _merge_function(Tensors inputs)
         {
             var output = inputs[0];
             foreach (var i in range(1, inputs.Length))
diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
new file mode 100644
index 00000000..a4309949
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
@@ -0,0 +1,47 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Layer that concatenates a list of inputs.
+    /// </summary>
+    public class Concatenate : Merge
+    {
+        MergeArgs args;
+        int axis => args.Axis;
+
+        public Concatenate(MergeArgs args) : base(args)
+        {
+            this.args = args;
+        }
+
+        protected override void build(Tensors inputs)
+        {
+            /*var shape_set = new HashSet();
+            var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray();
+            for (var i = 0; i < reduced_inputs_shapes.Length; i++)
+            {
+                int seq = -1;
+                TensorShape shape = reduced_inputs_shapes[i].Where(x =>
+                {
+                    seq++;
+                    return seq != i;
+                }).ToArray();
+                shape_set.Add(shape);
+            }*/
+        }
+
+        protected override Tensors _merge_function(Tensors inputs)
+        {
+            return keras.backend.concatenate(inputs, axis: axis);
+        }
+    }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs b/test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs
new file mode 100644
index 00000000..5dad1390
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs
@@ -0,0 +1,20 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class LayersMergingTest : EagerModeTestBase
+    {
+        [TestMethod]
+        public void Concatenate()
+        {
+            var x = np.arange(20).reshape(2, 2, 5);
+            var y = np.arange(20, 30).reshape(2, 1, 5);
+            var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
+            Assert.AreEqual((2, 3, 5), z.shape);
+        }
+    }
+}