Browse Source

IInitializer for Keras. #355

tags/v0.20
Oceania2018 5 years ago
parent
commit
e631c1adfa
19 changed files with 106 additions and 130 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +16
    -5
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  3. +7
    -6
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  4. +5
    -1
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  5. +7
    -16
      src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
  6. +0
    -13
      src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
  7. +1
    -2
      src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
  8. +13
    -0
      src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
  9. +4
    -9
      src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
  10. +4
    -15
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  11. +7
    -16
      src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
  12. +4
    -13
      src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
  13. +7
    -18
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  14. +4
    -9
      src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  16. +14
    -1
      src/TensorFlowNET.Core/Operations/gen_random_ops.cs
  17. +2
    -1
      src/TensorFlowNET.Core/Operations/random_ops.cs
  18. +5
    -1
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  19. +3
    -3
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs

+ 2
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -147,6 +147,7 @@ namespace Tensorflow
/// <returns></returns>
public Graph as_default()
{
    // Put the shared context into graph mode first, then register this
    // instance as the default graph and hand back whatever ops returns.
    tf.Context.graph_mode();

    var defaultGraph = ops.set_default_graph(this);
    return defaultGraph;
}

@@ -490,6 +491,7 @@ namespace Tensorflow

/// <summary>
/// Managed-resource cleanup for this graph: calls tf.Context.eager_mode() and
/// removes this instance from ops.default_graph_stack.
/// </summary>
protected override void DisposeManagedResources()
{
// NOTE(review): eager mode is restored unconditionally, even if another graph
// might still be on ops.default_graph_stack — confirm this is intended when
// graphs are nested.
tf.Context.eager_mode();
ops.default_graph_stack.remove(this);
}



+ 16
- 5
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -18,9 +18,11 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using static Tensorflow.Binding;

@@ -46,7 +48,7 @@ namespace Tensorflow.Keras.Engine
protected bool built;
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;
/// <summary>
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
@@ -110,8 +112,11 @@ namespace Tensorflow.Keras.Engine
/// <param name="input"></param>
/// <param name="is_training"></param>
/// <returns></returns>
public Tensor Apply(Tensor[] inputs, bool is_training = false)
public Tensor[] Apply(Tensor[] inputs, bool is_training = false)
{
var input = inputs[0];
Tensor[] outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
@@ -120,7 +125,7 @@ namespace Tensorflow.Keras.Engine
using var ctxManager = CallContext.enter();

string nameScope = "";
if (tf.Context.executing_eagerly())
if (tf.executing_eagerly())
{
nameScope = name;
}
@@ -129,15 +134,21 @@ namespace Tensorflow.Keras.Engine
throw new NotImplementedException("");
}

using var graph = tf.keras.backend.get_graph().as_default();
tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs);

call(inputs, is_training: is_training);
outputs = call(inputs, is_training: is_training);

(input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null);
});

throw new NotImplementedException("");
return outputs;
}

[Obsolete("User Apply()")]


+ 7
- 6
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -30,8 +30,9 @@ namespace Tensorflow.Keras.Layers
public class Dense : Layer
{
DenseArgs args;
protected IVariableV1 kernel;
protected IVariableV1 bias;
IVariableV1 kernel;
IVariableV1 bias;
Activation activation => args.Activation;

public Dense(DenseArgs args) :
base(args)
@@ -74,15 +75,15 @@ namespace Tensorflow.Keras.Layers
}
else
{
outputs = gen_math_ops.mat_mul(inputs[0], kernel.Handle);
outputs = gen_math_ops.mat_mul(inputs[0], kernel.AsTensor());
}

if (args.UseBias)
outputs = tf.nn.bias_add(outputs, bias);
//if (args.Activation != null)
//outputs = args.Activation.Activate(outputs);
if (args.Activation != null)
outputs = activation(outputs);

return new[] { outputs, outputs };
return new[] { outputs };
}
}
}

+ 5
- 1
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -36,7 +36,11 @@ namespace Tensorflow.Keras.Utils

ops.init_scope();

Func<Tensor> init_val = () => args.Initializer.call(args.Shape, dtype: args.DType);
Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs
{
Shape = args.Shape,
DType = args.DType
});

var variable_dtype = args.DType.as_base_dtype();
var v = tf.Variable(init_val,


+ 7
- 16
src/TensorFlowNET.Core/Operations/Initializers/Constant.cs View File

@@ -29,27 +29,18 @@ namespace Tensorflow.Operations.Initializers
_verify_shape = verify_shape;
}

/// <summary>
/// Builds a constant tensor from this initializer's value.
/// </summary>
/// <param name="args">
/// Shape, data type and shape-verification flag. An unset DType (DtInvalid) or an
/// unset VerifyShape (null) is filled in on <paramref name="args"/> from the values
/// this initializer was constructed with.
/// </param>
/// <returns>The tensor produced by constant_op._constant_impl.</returns>
public Tensor Apply(InitializerArgs args)
{
    // Fall back to the initializer's own settings for anything the caller left unset.
    if (args.DType == TF_DataType.DtInvalid)
    {
        args.DType = this.dtype;
    }

    if (args.VerifyShape == null)
    {
        args.VerifyShape = _verify_shape;
    }

    var verifyShape = args.VerifyShape.Value;
    return constant_op._constant_impl(value, args.DType, args.Shape,
        name: "Const",
        verify_shape: verifyShape,
        allow_broadcast: false);
}

public object get_config()
{
return new
{
value,
dtype = dtype.name()
};
}
}
}

+ 0
- 13
src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs View File

@@ -30,18 +30,5 @@ namespace Tensorflow.Operations.Initializers
{

}

#pragma warning disable CS0114 // Member hides inherited member; missing override keyword
public object get_config()
#pragma warning restore CS0114 // Member hides inherited member; missing override keyword
{
return new
{
scale = _scale,
mode = _mode,
seed = _seed,
dtype = _dtype
};
}
}
}

+ 1
- 2
src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -18,7 +18,6 @@ namespace Tensorflow
{
/// <summary>
/// Contract for objects that produce the initial value tensor of a variable.
/// </summary>
public interface IInitializer
{
/// <summary>
/// Positional form: builds a tensor from a shape, an optional data type and an
/// optional shape-verification flag.
/// </summary>
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null);
/// <summary>Returns an object describing this initializer's configuration.</summary>
object get_config();
/// <summary>
/// Builds a tensor from the options bundled in <paramref name="args"/>.
/// </summary>
Tensor Apply(InitializerArgs args);
}
}

+ 13
- 0
src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Options handed to <see cref="IInitializer.Apply"/> when a tensor is initialized.
/// </summary>
public class InitializerArgs
{
    /// <summary>Shape of the tensor to create.</summary>
    public TensorShape Shape { get; set; }

    /// <summary>Requested data type of the tensor.</summary>
    public TF_DataType DType { get; set; }

    /// <summary>Optional shape-verification flag; null until a value is assigned.</summary>
    public bool? VerifyShape { get; set; }
}
}

+ 4
- 9
src/TensorFlowNET.Core/Operations/Initializers/Ones.cs View File

@@ -25,17 +25,12 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}

/// <summary>
/// Creates a tensor via array_ops.ones with the requested shape and data type.
/// </summary>
/// <param name="args">
/// Shape and data type of the tensor. When DType is DtInvalid it is replaced on
/// <paramref name="args"/> with the data type this initializer was constructed with.
/// </param>
/// <returns>The tensor returned by array_ops.ones.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // Pass the resolved args.DType; passing the field would silently ignore
    // a data type explicitly requested by the caller.
    return array_ops.ones(args.Shape, args.DType);
}
}
}

+ 4
- 15
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -38,22 +38,11 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}

/// <summary>
/// Creates a tensor via random_ops.random_normal using this initializer's
/// mean, stddev and seed.
/// </summary>
/// <param name="args">
/// Shape and data type of the tensor. When DType is DtInvalid it is replaced on
/// <paramref name="args"/> with the data type this initializer was constructed with.
/// </param>
/// <returns>The tensor returned by random_ops.random_normal.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // Pass the resolved args.DType; passing the field would silently ignore
    // a data type explicitly requested by the caller.
    return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed);
}
}
}

+ 7
- 16
src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs View File

@@ -27,32 +27,23 @@ namespace Tensorflow.Operations.Initializers
#pragma warning disable CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0
private float maxval;
#pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0
#pragma warning disable CS0649 // Field 'RandomUniform.dtype' is never assigned to, and will always have its default value
private TF_DataType dtype;
#pragma warning restore CS0649 // Field 'RandomUniform.dtype' is never assigned to, and will always have its default value

public RandomUniform()
public RandomUniform(TF_DataType dtype = TF_DataType.DtInvalid)
{
this.dtype = dtype;
}

/// <summary>
/// Creates a tensor via random_ops.random_uniform using this initializer's
/// minval, maxval and seed.
/// </summary>
/// <param name="args">
/// Shape and data type of the tensor. When DType is DtInvalid it is replaced on
/// <paramref name="args"/> with the data type this initializer was constructed with.
/// </param>
/// <returns>The tensor returned by random_ops.random_uniform.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // NOTE(review): the CS0649 pragma on the maxval field says it is never
    // assigned (always 0) — confirm the sampled range is what is intended.
    // Pass the resolved args.DType; passing the field would silently ignore
    // a data type explicitly requested by the caller.
    return random_ops.random_uniform(args.Shape,
        minval: minval,
        maxval: maxval,
        dtype: args.DType,
        seed: seed);
}

public object get_config()
{
return new {
minval,
maxval,
seed,
dtype
};
}
}
}

+ 4
- 13
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -34,20 +34,11 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}

/// <summary>
/// Creates a tensor via random_ops.truncated_normal using this initializer's
/// mean, stddev and seed.
/// </summary>
/// <param name="args">
/// Shape and data type of the tensor. When DType is DtInvalid it is replaced on
/// <paramref name="args"/> with the data type this initializer was constructed with.
/// </param>
/// <returns>The tensor returned by random_ops.truncated_normal.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // Pass the resolved args.DType; passing the field would silently ignore
    // a data type explicitly requested by the caller.
    return random_ops.truncated_normal(args.Shape, mean, stddev, dtype: args.DType, seed: seed);
}
}
}

+ 7
- 18
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -53,10 +53,13 @@ namespace Tensorflow.Operations.Initializers
_uniform = uniform;
}

public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
public Tensor Apply(InitializerArgs args)
{
if (args.DType == TF_DataType.DtInvalid)
args.DType = this._dtype;

float n = 0;
var (fan_in, fan_out) = _compute_fans(shape);
var (fan_in, fan_out) = _compute_fans(args.Shape);
if (_mode == "FAN_IN")
n = fan_in;
else if (_mode == "FAN_OUT")
@@ -67,13 +70,12 @@ namespace Tensorflow.Operations.Initializers
if (_uniform)
{
    var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
    // Forward the configured seed here as well, matching the truncated-normal
    // branch below; without it the uniform path ignores _seed.
    return random_ops.random_uniform(args.Shape, -limit, limit, args.DType,
        seed: _seed);
}
else
{
    var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
    return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType,
        seed: _seed);
}
}
@@ -98,18 +100,5 @@ namespace Tensorflow.Operations.Initializers
return (fan_in, fan_out);
}
}

public virtual object get_config()
{
return new
{
scale = _scale,
mode = _mode,
distribution = _distribution,
seed = _seed,
uniform = _uniform,
dtype = _dtype
};
}
}
}

+ 4
- 9
src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs View File

@@ -25,17 +25,12 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}

/// <summary>
/// Creates a tensor via array_ops.zeros with the requested shape and data type.
/// </summary>
/// <param name="args">
/// Shape and data type of the tensor. When DType is DtInvalid it is replaced on
/// <paramref name="args"/> with the data type this initializer was constructed with.
/// </param>
/// <returns>The tensor returned by array_ops.zeros.</returns>
public Tensor Apply(InitializerArgs args)
{
    if (args.DType == TF_DataType.DtInvalid)
        args.DType = this.dtype;

    // Pass the resolved args.DType; passing the field would silently ignore
    // a data type explicitly requested by the caller.
    return array_ops.zeros(args.Shape, args.DType);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -955,7 +955,7 @@ namespace Tensorflow
/// <returns></returns>
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null)
{
if (tf.Context.executing_eagerly())
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MatMul", name,


+ 14
- 1
src/TensorFlowNET.Core/Operations/gen_random_ops.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow
if (!seed2.HasValue)
seed2 = 0;

if (tf.Context.executing_eagerly())
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomStandardNormal", name,
@@ -98,6 +98,19 @@ namespace Tensorflow
if (!seed2.HasValue)
seed2 = 0;

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomUniform", name,
null,
shape,
"seed", seed,
"seed2", seed2,
"dtype", dtype);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("RandomUniform",
name: name,
args: new { shape, dtype, seed, seed2});


+ 2
- 1
src/TensorFlowNET.Core/Operations/random_ops.cs View File

@@ -72,10 +72,11 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
{
name = scope;
var (seed1, seed2) = random_seed.get_seed(seed);
var tensorShape = tensor_util.shape_tensor(shape);
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
var rnd = gen_random_ops.random_uniform(tensorShape, dtype);
var rnd = gen_random_ops.random_uniform(tensorShape, dtype, seed: seed1, seed2: seed2);
return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name);
});
}


+ 5
- 1
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -162,7 +162,11 @@ namespace Tensorflow
}
else
{
Func<Tensor> init_val = () => initializer.call(shape, dtype);
Func<Tensor> init_val = () => initializer.Apply(new InitializerArgs
{
Shape = shape,
DType = dtype
});
var variable_dtype = dtype.as_base_dtype();

v = variable_scope.default_variable_creator(init_val,


test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs → test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -11,10 +11,10 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.Keras
{
/// <summary>
/// https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/Embedding
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
/// </summary>
[TestClass, Ignore]
public class EmbeddingTest : GraphModeTestBase
[TestClass]
public class LayersTest : GraphModeTestBase
{
[TestMethod]
public void Embedding()

Loading…
Cancel
Save