| @@ -7,5 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| public class MergeArgs : LayerArgs | public class MergeArgs : LayerArgs | ||||
| { | { | ||||
| public Tensors Inputs { get; set; } | public Tensors Inputs { get; set; } | ||||
| public int Axis { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -407,7 +407,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| var ret = tensor.TensorShape.unknown_shape(shape.dims[0]); | var ret = tensor.TensorShape.unknown_shape(shape.dims[0]); | ||||
| var value = constant_value(tensor); | var value = constant_value(tensor); | ||||
| if (value != null) | |||||
| if (!(value is null)) | |||||
| { | { | ||||
| int[] d_ = { }; | int[] d_ = { }; | ||||
| foreach (int d in value) | foreach (int d in value) | ||||
| @@ -418,7 +418,6 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| d_[d_.Length] = -1; // None | d_[d_.Length] = -1; // None | ||||
| } | } | ||||
| ret = ret.merge_with(new TensorShape(d_)); | ret = ret.merge_with(new TensorShape(d_)); | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -226,5 +226,25 @@ namespace Tensorflow.Keras | |||||
| x.set_shape(output_shape); | x.set_shape(output_shape); | ||||
| return x; | return x; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Concatenates a list of tensors alongside the specified axis. | |||||
| /// </summary> | |||||
| /// <param name="tensors">list of tensors to concatenate.</param> | |||||
| /// <param name="axis">concatenation axis.</param> | |||||
| /// <returns></returns> | |||||
| public Tensor concatenate(Tensors tensors, int axis = -1) | |||||
| { | |||||
| if(axis < 0) | |||||
| { | |||||
| var rank = tensors[0].NDims; | |||||
| if (rank > -1) | |||||
| axis %= rank; | |||||
| else | |||||
| axis = 0; | |||||
| } | |||||
| return array_ops.concat(tensors, axis); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -177,13 +177,13 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.init_scope(); | tf.init_scope(); | ||||
| tf.Context.eager_mode(); | tf.Context.eager_mode(); | ||||
| build(inputs.shape); | |||||
| build(inputs); | |||||
| tf.Context.restore_mode(); | tf.Context.restore_mode(); | ||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected virtual void build(TensorShape input_shape) | |||||
| protected virtual void build(Tensors inputs) | |||||
| { | { | ||||
| built = true; | built = true; | ||||
| } | } | ||||
| @@ -52,8 +52,9 @@ namespace Tensorflow.Keras.Layers | |||||
| axis = args.Axis.dims; | axis = args.Axis.dims; | ||||
| } | } | ||||
| protected override void build(TensorShape input_shape) | |||||
| protected override void build(Tensors inputs) | |||||
| { | { | ||||
| TensorShape input_shape = inputs.shape; | |||||
| var ndims = input_shape.ndim; | var ndims = input_shape.ndim; | ||||
| foreach (var (idx, x) in enumerate(axis)) | foreach (var (idx, x) in enumerate(axis)) | ||||
| if (x < 0) | if (x < 0) | ||||
| @@ -56,8 +56,9 @@ namespace Tensorflow.Keras.Layers | |||||
| _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); | _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); | ||||
| } | } | ||||
| protected override void build(TensorShape input_shape) | |||||
| protected override void build(Tensors inputs) | |||||
| { | { | ||||
| TensorShape input_shape = inputs.shape; | |||||
| int channel_axis = data_format == "channels_first" ? 1 : -1; | int channel_axis = data_format == "channels_first" ? 1 : -1; | ||||
| int input_channel = channel_axis < 0 ? | int input_channel = channel_axis < 0 ? | ||||
| input_shape.dims[input_shape.ndim + channel_axis] : | input_shape.dims[input_shape.ndim + channel_axis] : | ||||
| @@ -41,8 +41,9 @@ namespace Tensorflow.Keras.Layers | |||||
| this.inputSpec = new InputSpec(min_ndim: 2); | this.inputSpec = new InputSpec(min_ndim: 2); | ||||
| } | } | ||||
| protected override void build(TensorShape input_shape) | |||||
| protected override void build(Tensors inputs) | |||||
| { | { | ||||
| TensorShape input_shape = inputs.shape; | |||||
| var last_dim = input_shape.dims.Last(); | var last_dim = input_shape.dims.Last(); | ||||
| var axes = new Dictionary<int, int>(); | var axes = new Dictionary<int, int>(); | ||||
| axes[-1] = last_dim; | axes[-1] = last_dim; | ||||
| @@ -52,7 +52,7 @@ namespace Tensorflow.Keras.Layers | |||||
| SupportsMasking = mask_zero; | SupportsMasking = mask_zero; | ||||
| } | } | ||||
| protected override void build(TensorShape input_shape) | |||||
| protected override void build(Tensors inputs) | |||||
| { | { | ||||
| tf.Context.eager_mode(); | tf.Context.eager_mode(); | ||||
| embeddings = add_weight(shape: (input_dim, output_dim), | embeddings = add_weight(shape: (input_dim, output_dim), | ||||
| @@ -0,0 +1,22 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public partial class LayersApi | |||||
| { | |||||
| /// <summary> | |||||
| /// Layer that concatenates a list of inputs. | |||||
| /// </summary> | |||||
| /// <param name="axis">Axis along which to concatenate.</param> | |||||
| /// <returns></returns> | |||||
| public Concatenate Concatenate(int axis = -1) | |||||
| => new Concatenate(new MergeArgs | |||||
| { | |||||
| Axis = axis | |||||
| }); | |||||
| } | |||||
| } | |||||
| @@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Layers | |||||
| } | } | ||||
| protected override void build(TensorShape input_shape) | |||||
| protected override void build(Tensors inputs) | |||||
| { | { | ||||
| // output_shape = input_shape.dims[1^]; | // output_shape = input_shape.dims[1^]; | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers | |||||
| return _merge_function(inputs); | return _merge_function(inputs); | ||||
| } | } | ||||
| Tensors _merge_function(Tensors inputs) | |||||
| protected virtual Tensors _merge_function(Tensors inputs) | |||||
| { | { | ||||
| var output = inputs[0]; | var output = inputs[0]; | ||||
| foreach (var i in range(1, inputs.Length)) | foreach (var i in range(1, inputs.Length)) | ||||
| @@ -0,0 +1,47 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| /// <summary> | |||||
| /// Layer that concatenates a list of inputs. | |||||
| /// </summary> | |||||
| public class Concatenate : Merge | |||||
| { | |||||
| MergeArgs args; | |||||
| int axis => args.Axis; | |||||
| public Concatenate(MergeArgs args) : base(args) | |||||
| { | |||||
| this.args = args; | |||||
| } | |||||
| protected override void build(Tensors inputs) | |||||
| { | |||||
| /*var shape_set = new HashSet<TensorShape>(); | |||||
| var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); | |||||
| for (var i = 0; i < reduced_inputs_shapes.Length; i++) | |||||
| { | |||||
| int seq = -1; | |||||
| TensorShape shape = reduced_inputs_shapes[i].Where(x => | |||||
| { | |||||
| seq++; | |||||
| return seq != i; | |||||
| }).ToArray(); | |||||
| shape_set.Add(shape); | |||||
| }*/ | |||||
| } | |||||
| protected override Tensors _merge_function(Tensors inputs) | |||||
| { | |||||
| return keras.backend.concatenate(inputs, axis: axis); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,20 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace TensorFlowNET.UnitTest.Keras | |||||
| { | |||||
| [TestClass] | |||||
| public class LayersMergingTest : EagerModeTestBase | |||||
| { | |||||
| [TestMethod] | |||||
| public void Concatenate() | |||||
| { | |||||
| var x = np.arange(20).reshape(2, 2, 5); | |||||
| var y = np.arange(20, 30).reshape(2, 1, 5); | |||||
| var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y)); | |||||
| Assert.AreEqual((2, 3, 5), z.shape); | |||||
| } | |||||
| } | |||||
| } | |||||