Browse Source

Adjusted implementation of Gather to correspond to Python implementation.

Added minval, maxval, and seed arguments to the RandomUniformInitializer
Added unit tests for Gather and a simple Embedding test.
tags/v0.40-tf2.4-tstring
Niklas Gustafsson Esther Hu 4 years ago
parent
commit
37ca023036
5 changed files with 80 additions and 10 deletions
  1. +4
    -7
      src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
  2. +32
    -0
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
  4. +16
    -2
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
  5. +27
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs

+ 4
- 7
src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs View File

@@ -18,20 +18,17 @@ namespace Tensorflow.Operations.Initializers
{ {
public class RandomUniform : IInitializer public class RandomUniform : IInitializer
{ {
#pragma warning disable CS0649 // Field 'RandomUniform.seed' is never assigned to, and will always have its default value
private int? seed; private int? seed;
#pragma warning restore CS0649 // Field 'RandomUniform.seed' is never assigned to, and will always have its default value
#pragma warning disable CS0649 // Field 'RandomUniform.minval' is never assigned to, and will always have its default value 0
private float minval; private float minval;
#pragma warning restore CS0649 // Field 'RandomUniform.minval' is never assigned to, and will always have its default value 0
#pragma warning disable CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0
private float maxval; private float maxval;
#pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0
private TF_DataType dtype; private TF_DataType dtype;


/// <summary>
/// Creates the initializer and stores the supplied configuration in the
/// instance fields (dtype, minval, maxval, seed).
/// </summary>
/// <param name="dtype">Stored in the dtype field; defaults to TF_DataType.TF_FLOAT.</param>
/// <param name="minval">Stored in the minval field; defaults to -0.05f. Presumably the lower bound used by Apply - confirm, Apply's body is not shown here.</param>
/// <param name="maxval">Stored in the maxval field; defaults to 0.05f. Presumably the upper bound used by Apply - confirm, Apply's body is not shown here.</param>
/// <param name="seed">Stored in the seed field; defaults to null.</param>
public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT)
public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null)
{ {
this.dtype = dtype; this.dtype = dtype;
this.minval = minval;
this.maxval = maxval;
this.seed = seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 32
- 0
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -119,11 +119,43 @@ namespace Tensorflow


/// <summary>
/// Runs the "GatherV2" op on <paramref name="params"/> using
/// <paramref name="indices"/> along <paramref name="axis"/>.
/// </summary>
/// <remarks>
/// When the context is executing eagerly, the fast-path runner is tried first;
/// if it throws, the call is retried through gather_v2_eager_fallback.
/// Otherwise the op is added through the op-definition library.
/// </remarks>
/// <param name="params">Source values handed to the op.</param>
/// <param name="indices">Index values handed to the op.</param>
/// <param name="axis">Axis argument handed to the op.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first output of the executed or created op.</returns>
public static Tensor gather_v2<T1, T2>(T1 @params, T2 indices, int axis, string name = null)
{
    if (tf.Context.executing_eagerly())
    {
        try
        {
            // "batch_dims" is an attribute name and must be followed by its value.
            // The eager fallback pairs it with 0, so the same value is used here.
            var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("GatherV2", name, @params, indices, axis, "batch_dims", 0)
            {
                ctx = tf.Context,
                device_name = tf.Context.DeviceName
            });
            return results[0];
        }
        catch (Exception)
        {
            // Deliberate best-effort: any fast-path failure is retried on the slow path.
            return gather_v2_eager_fallback(@params, indices, axis, name, tf.Context);
        }
    }

    var _op = tf.OpDefLib._apply_op_helper("GatherV2", name: name, new { @params, indices, axis });

    return _op.outputs[0];
}


/// <summary>
/// Slow-path eager execution of "GatherV2": converts the three inputs through
/// the runner, executes the op with batch_dims fixed to 0 and, when the runner
/// asks for it, records the gradient.
/// </summary>
/// <returns>The first tensor returned by the runner.</returns>
private static Tensor gather_v2_eager_fallback(object @params, object indices, int axis, string name, Context ctx)
{
    // Each input is matched separately; indices and axis default to int32.
    var (paramsType, paramsInputs) = tf.Runner.ArgsToMatchingEager(ctx, args: new[] { @params });
    var (indicesType, indicesInputs) = tf.Runner.ArgsToMatchingEager(ctx, default_dtype: tf.int32, args: new[] { indices });
    var (axisType, axisInputs) = tf.Runner.ArgsToMatchingEager(ctx, default_dtype: tf.int32, args: new object[] { axis });

    // Inputs are flattened in op order: params, indices, axis.
    var flatInputs = paramsInputs.concat(indicesInputs).concat(axisInputs);

    // Attributes are laid out as name/value pairs.
    var attrs = new object[]
    {
        "batch_dims", 0,
        "Tparams", paramsType,
        "Tindices", indicesType,
        "Taxis", axisType
    };

    var outputs = tf.Runner.Execute(ctx, "GatherV2", 1, flatInputs, attrs, name: name);
    if (tf.Runner.MustRecordGradient())
    {
        tf.Runner.RecordGradient("GatherV2", flatInputs, attrs, outputs);
    }

    return outputs[0];
}


public static Tensor pad(Tensor input, Tensor paddings, string name = null) public static Tensor pad(Tensor input, Tensor paddings, string name = null)
{ {
if (tf.Context.executing_eagerly()) if (tf.Context.executing_eagerly())


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Core/Embedding.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.Keras.Layers
if (args.BatchInputShape == null) if (args.BatchInputShape == null)
args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();


embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer;
embeddings_initializer = args.EmbeddingsInitializer ?? tf.random_uniform_initializer;
SupportsMasking = mask_zero; SupportsMasking = mask_zero;
} }




+ 16
- 2
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp; using NumSharp;
using Tensorflow; using Tensorflow;
using Tensorflow.Operations.Initializers;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;


@@ -60,13 +61,26 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(model.Layers.Count, 8); Assert.AreEqual(model.Layers.Count, 8);
var result = model.predict(tf.constant(np.arange(24).astype(np.float32)[np.newaxis, Slice.All])); var result = model.predict(tf.constant(np.arange(24).astype(np.float32)[np.newaxis, Slice.All]));
Assert.AreEqual(result.shape, new TensorShape(1, 24)); Assert.AreEqual(result.shape, new TensorShape(1, 24));
model.fit(np.arange(24).astype(np.float32)[np.newaxis, Slice.All], np.arange(24).astype(np.float32)[np.newaxis, Slice.All], verbose: 0);
model.fit(np.arange(24).astype(np.float32)[np.newaxis, Slice.All], np.arange(24).astype(np.float32)[np.newaxis, Slice.All], verbose: 0);
} }


/// <summary> /// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary> /// </summary>
[TestMethod, Ignore]
[TestMethod]
public void Embedding_Simple()
{
    // Arrange: an embedding layer built with (256, 12, input_length: 4) and a 3x4 float input holding 0..11.
    var layer = keras.layers.Embedding(256, 12, input_length: 4);
    var inputs = np.arange(12).reshape(3, 4).astype(np.float32);

    // Act
    var embedded = layer.Apply(inputs);

    // Assert: the input shape (3, 4) gains a trailing dimension of 12.
    var expectedShape = new TensorShape(3, 4, 12);
    Assert.AreEqual(expectedShape, embedded.shape);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary>
[TestMethod]
[Ignore]
public void Embedding() public void Embedding()
{ {
var model = keras.Sequential(); var model = keras.Sequential();


+ 27
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs View File

@@ -0,0 +1,27 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.ManagedAPI
{
// Tests for array operations; derives from EagerModeTestBase.
[TestClass]
public class ArrayOpsTest : EagerModeTestBase
{
/// <summary>
/// Checks array_ops.gather on a 3x4 float tensor holding 0..11 with indices {0, 2}:
/// the result must have shape (2, 4) and contain source rows 0 and 2.
/// Python counterpart: https://www.tensorflow.org/api_docs/python/tf/gather
/// Related layer docs (original link): https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary>
[TestMethod]
public void Gather()
{
// Source rows are [0..3], [4..7], [8..11].
var input_array = tf.constant(np.arange(12).reshape(3, 4).astype(np.float32));
var indices = tf.constant(np.array(new int[] { 0, 2 }));

var result = array_ops.gather(input_array, indices);
Assert.AreEqual(new TensorShape(2, 4), result.shape);
// Spot checks: first two values of source row 0, then the last value of source row 2.
// NOTE(review): MSTest's Assert.AreEqual takes (expected, actual); these calls pass the
// actual value first. Confirm how the indexed numpy value compares with a float before swapping.
Assert.AreEqual(result.numpy()[0,0], 0.0f);
Assert.AreEqual(result.numpy()[0,1], 1.0f);
Assert.AreEqual(result.numpy()[1,3], 11.0f);
}
}
}

Loading…
Cancel
Save