Browse Source

Add keras.Concatenate.

tags/v0.30
Oceania2018 5 years ago
parent
commit
b58c24136c
12 changed files with 122 additions and 10 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs
  2. +1
    -2
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  3. +20
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  4. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  5. +2
    -1
      src/TensorFlowNET.Keras/Layers/BatchNormalization.cs
  6. +2
    -1
      src/TensorFlowNET.Keras/Layers/Convolutional.cs
  7. +2
    -1
      src/TensorFlowNET.Keras/Layers/Dense.cs
  8. +1
    -1
      src/TensorFlowNET.Keras/Layers/Embedding.cs
  9. +22
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
  10. +2
    -2
      src/TensorFlowNET.Keras/Layers/Merge.cs
  11. +47
    -0
      src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
  12. +20
    -0
      test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs

+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs View File

@@ -7,5 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public class MergeArgs : LayerArgs public class MergeArgs : LayerArgs
{ {
public Tensors Inputs { get; set; } public Tensors Inputs { get; set; }
public int Axis { get; set; }
} }
} }

+ 1
- 2
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -407,7 +407,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));


var ret = tensor.TensorShape.unknown_shape(shape.dims[0]); var ret = tensor.TensorShape.unknown_shape(shape.dims[0]);
var value = constant_value(tensor); var value = constant_value(tensor);
if (value != null)
if (!(value is null))
{ {
int[] d_ = { }; int[] d_ = { };
foreach (int d in value) foreach (int d in value)
@@ -418,7 +418,6 @@ would not be rank 1.", tensor.op.get_attr("axis")));
d_[d_.Length] = -1; // None d_[d_.Length] = -1; // None
} }
ret = ret.merge_with(new TensorShape(d_)); ret = ret.merge_with(new TensorShape(d_));

} }
return ret; return ret;
} }


+ 20
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -226,5 +226,25 @@ namespace Tensorflow.Keras
x.set_shape(output_shape); x.set_shape(output_shape);
return x; return x;
} }

/// <summary>
/// Concatenates a list of tensors alongside the specified axis.
/// </summary>
/// <param name="tensors">list of tensors to concatenate.</param>
/// <param name="axis">concatenation axis.</param>
/// <returns></returns>
public Tensor concatenate(Tensors tensors, int axis = -1)
{
if(axis < 0)
{
var rank = tensors[0].NDims;
if (rank > -1)
axis %= rank;
else
axis = 0;
}

return array_ops.concat(tensors, axis);
}
} }
} }

+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -177,13 +177,13 @@ namespace Tensorflow.Keras.Engine
tf.init_scope(); tf.init_scope();


tf.Context.eager_mode(); tf.Context.eager_mode();
build(inputs.shape);
build(inputs);
tf.Context.restore_mode(); tf.Context.restore_mode();


built = true; built = true;
} }


protected virtual void build(TensorShape input_shape)
protected virtual void build(Tensors inputs)
{ {
built = true; built = true;
} }


+ 2
- 1
src/TensorFlowNET.Keras/Layers/BatchNormalization.cs View File

@@ -52,8 +52,9 @@ namespace Tensorflow.Keras.Layers
axis = args.Axis.dims; axis = args.Axis.dims;
} }


protected override void build(TensorShape input_shape)
protected override void build(Tensors inputs)
{ {
TensorShape input_shape = inputs.shape;
var ndims = input_shape.ndim; var ndims = input_shape.ndim;
foreach (var (idx, x) in enumerate(axis)) foreach (var (idx, x) in enumerate(axis))
if (x < 0) if (x < 0)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Convolutional.cs View File

@@ -56,8 +56,9 @@ namespace Tensorflow.Keras.Layers
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
} }


protected override void build(TensorShape input_shape)
protected override void build(Tensors inputs)
{ {
TensorShape input_shape = inputs.shape;
int channel_axis = data_format == "channels_first" ? 1 : -1; int channel_axis = data_format == "channels_first" ? 1 : -1;
int input_channel = channel_axis < 0 ? int input_channel = channel_axis < 0 ?
input_shape.dims[input_shape.ndim + channel_axis] : input_shape.dims[input_shape.ndim + channel_axis] :


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Dense.cs View File

@@ -41,8 +41,9 @@ namespace Tensorflow.Keras.Layers
this.inputSpec = new InputSpec(min_ndim: 2); this.inputSpec = new InputSpec(min_ndim: 2);
} }


protected override void build(TensorShape input_shape)
protected override void build(Tensors inputs)
{ {
TensorShape input_shape = inputs.shape;
var last_dim = input_shape.dims.Last(); var last_dim = input_shape.dims.Last();
var axes = new Dictionary<int, int>(); var axes = new Dictionary<int, int>();
axes[-1] = last_dim; axes[-1] = last_dim;


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Embedding.cs View File

@@ -52,7 +52,7 @@ namespace Tensorflow.Keras.Layers
SupportsMasking = mask_zero; SupportsMasking = mask_zero;
} }


protected override void build(TensorShape input_shape)
protected override void build(Tensors inputs)
{ {
tf.Context.eager_mode(); tf.Context.eager_mode();
embeddings = add_weight(shape: (input_dim, output_dim), embeddings = add_weight(shape: (input_dim, output_dim),


+ 22
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs View File

@@ -0,0 +1,22 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
public partial class LayersApi
{
/// <summary>
/// Layer that concatenates a list of inputs.
/// </summary>
/// <param name="axis">Axis along which to concatenate.</param>
/// <returns></returns>
public Concatenate Concatenate(int axis = -1)
=> new Concatenate(new MergeArgs
{
Axis = axis
});
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Layers/Merge.cs View File

@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Layers


} }


protected override void build(TensorShape input_shape)
protected override void build(Tensors inputs)
{ {
// output_shape = input_shape.dims[1^]; // output_shape = input_shape.dims[1^];
} }
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers
return _merge_function(inputs); return _merge_function(inputs);
} }


Tensors _merge_function(Tensors inputs)
protected virtual Tensors _merge_function(Tensors inputs)
{ {
var output = inputs[0]; var output = inputs[0];
foreach (var i in range(1, inputs.Length)) foreach (var i in range(1, inputs.Length))


+ 47
- 0
src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs View File

@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Layer that concatenates a list of inputs.
/// </summary>
public class Concatenate : Merge
{
MergeArgs args;
int axis => args.Axis;

public Concatenate(MergeArgs args) : base(args)
{
this.args = args;
}

protected override void build(Tensors inputs)
{
/*var shape_set = new HashSet<TensorShape>();
var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray();
for (var i = 0; i < reduced_inputs_shapes.Length; i++)
{
int seq = -1;
TensorShape shape = reduced_inputs_shapes[i].Where(x =>
{
seq++;
return seq != i;
}).ToArray();
shape_set.Add(shape);
}*/
}

protected override Tensors _merge_function(Tensors inputs)
{
return keras.backend.concatenate(inputs, axis: axis);
}
}
}

+ 20
- 0
test/TensorFlowNET.UnitTest/Keras/Layers.Merging.Test.cs View File

@@ -0,0 +1,20 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.UnitTest.Keras
{
[TestClass]
public class LayersMergingTest : EagerModeTestBase
{
[TestMethod]
public void Concatenate()
{
var x = np.arange(20).reshape(2, 2, 5);
var y = np.arange(20, 30).reshape(2, 1, 5);
var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
Assert.AreEqual((2, 3, 5), z.shape);
}
}
}

Loading…
Cancel
Save