From e631c1adfacdce4ea447588e459c8e081b8a01e5 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 1 Aug 2020 23:56:31 -0500 Subject: [PATCH] IInitializer for Keras. #355 --- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 ++ src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 21 ++++++++++++---- src/TensorFlowNET.Core/Keras/Layers/Dense.cs | 13 +++++----- .../Keras/Utils/base_layer_utils.cs | 6 ++++- .../Operations/Initializers/Constant.cs | 23 ++++++----------- .../Operations/Initializers/GlorotUniform.cs | 13 ---------- .../Operations/Initializers/IInitializer.cs | 3 +-- .../Initializers/InitializerArgs.cs | 13 ++++++++++ .../Operations/Initializers/Ones.cs | 13 +++------- .../Operations/Initializers/RandomNormal.cs | 19 +++----------- .../Operations/Initializers/RandomUniform.cs | 23 ++++++----------- .../Initializers/TruncatedNormal.cs | 17 +++---------- .../Initializers/VarianceScaling.cs | 25 ++++++------------- .../Operations/Initializers/Zeros.cs | 13 +++------- .../Operations/gen_math_ops.cs | 2 +- .../Operations/gen_random_ops.cs | 15 ++++++++++- .../Operations/random_ops.cs | 3 ++- .../Variables/_VariableStore.cs | 6 ++++- .../Keras/{EmbeddingTest.cs => LayersTest.cs} | 6 ++--- 19 files changed, 106 insertions(+), 130 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs rename test/TensorFlowNET.UnitTest/Keras/{EmbeddingTest.cs => LayersTest.cs} (87%) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 6c404276..7fbdf229 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -147,6 +147,7 @@ namespace Tensorflow /// public Graph as_default() { + tf.Context.graph_mode(); return ops.set_default_graph(this); } @@ -490,6 +491,7 @@ namespace Tensorflow protected override void DisposeManagedResources() { + tf.Context.eager_mode(); ops.default_graph_stack.remove(this); } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 7ddb79c7..fd83ae7e 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -18,9 +18,11 @@ using System; using System.Collections.Generic; using System.Linq; using System.Threading; +using Tensorflow.Contexts; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; +using Tensorflow.Operations.Activation; using Tensorflow.Train; using static Tensorflow.Binding; @@ -46,7 +48,7 @@ namespace Tensorflow.Keras.Engine protected bool built; public bool Trainable => args.Trainable; public TF_DataType DType => args.DType; - + /// /// A stateful layer is a layer whose updates are run during inference too, /// for instance stateful RNNs. @@ -110,8 +112,11 @@ namespace Tensorflow.Keras.Engine /// /// /// - public Tensor Apply(Tensor[] inputs, bool is_training = false) + public Tensor[] Apply(Tensor[] inputs, bool is_training = false) { + var input = inputs[0]; + Tensor[] outputs = null; + callContext = callContext ?? new ThreadLocal() { Value = new CallContext() @@ -120,7 +125,7 @@ namespace Tensorflow.Keras.Engine using var ctxManager = CallContext.enter(); string nameScope = ""; - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { nameScope = name; } @@ -129,15 +134,21 @@ namespace Tensorflow.Keras.Engine throw new NotImplementedException(""); } + using var graph = tf.keras.backend.get_graph().as_default(); + tf_with(ops.name_scope(nameScope), scope => { if (!built) MaybeBuild(inputs); - call(inputs, is_training: is_training); + outputs = call(inputs, is_training: is_training); + + (input, outputs) = _set_connectivity_metadata_(input, outputs); + _handle_activity_regularization(inputs[0], outputs); + _set_mask_metadata(inputs[0], outputs, null); }); - throw new NotImplementedException(""); + return outputs; } [Obsolete("User Apply()")] diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs index 90109c1e..c6485427 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs @@ -30,8 +30,9 @@ namespace Tensorflow.Keras.Layers public class Dense : Layer { DenseArgs args; - protected IVariableV1 kernel; - protected IVariableV1 bias; + IVariableV1 kernel; + IVariableV1 bias; + Activation activation => args.Activation; public Dense(DenseArgs args) : base(args) @@ -74,15 +75,15 @@ namespace Tensorflow.Keras.Layers } else { - outputs = gen_math_ops.mat_mul(inputs[0], kernel.Handle); + outputs = gen_math_ops.mat_mul(inputs[0], kernel.AsTensor()); } if (args.UseBias) outputs = tf.nn.bias_add(outputs, bias); - //if (args.Activation != null) - //outputs = args.Activation.Activate(outputs); + if (args.Activation != null) + outputs = activation(outputs); - return new[] { outputs, outputs }; + return new[] { outputs }; } } } diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index 18325a49..e8d16820 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -36,7 +36,11 @@ namespace Tensorflow.Keras.Utils ops.init_scope(); - Func init_val = () => args.Initializer.call(args.Shape, dtype: args.DType); + Func init_val = () => args.Initializer.Apply(new InitializerArgs + { + Shape = args.Shape, + DType = args.DType + }); var variable_dtype = args.DType.as_base_dtype(); var v = tf.Variable(init_val, diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs index 708d9db6..cf230978 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -29,27 +29,18 @@ namespace Tensorflow.Operations.Initializers _verify_shape = verify_shape; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { - if (dtype == TF_DataType.DtInvalid) - dtype = this.dtype; + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; - if (!verify_shape.HasValue) - verify_shape = _verify_shape; + if (!args.VerifyShape.HasValue) + args.VerifyShape = _verify_shape; - return constant_op._constant_impl(value, dtype, shape, + return constant_op._constant_impl(value, args.DType, args.Shape, name: "Const", - verify_shape: verify_shape.Value, + verify_shape: args.VerifyShape.Value, allow_broadcast: false); } - - public object get_config() - { - return new - { - value, - dtype = dtype.name() - }; - } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index 5d38aa7f..8e59370a 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -30,18 +30,5 @@ namespace Tensorflow.Operations.Initializers { } - -#pragma warning disable CS0114 // Member hides inherited member; missing override keyword - public object get_config() -#pragma warning restore CS0114 // Member hides inherited member; missing override keyword - { - return new - { - scale = _scale, - mode = _mode, - seed = _seed, - dtype = _dtype - }; - } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 0ac0865f..50d4d503 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -18,7 +18,6 @@ namespace Tensorflow { public interface IInitializer { - Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null); - object get_config(); + Tensor Apply(InitializerArgs args); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs new file mode 100644 index 00000000..561664bc --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class InitializerArgs + { + public TensorShape Shape { get; set; } + public TF_DataType DType { get; set; } + public bool? VerifyShape { get; set; } = null; + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 83e5b57d..02d3c93b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -25,17 +25,12 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { - if (dtype == TF_DataType.DtInvalid) - dtype = this.dtype; + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; - return array_ops.ones(shape.dims, dtype); - } - - public object get_config() - { - return new { dtype = dtype.name() }; + return array_ops.ones(args.Shape, dtype); } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index a3e2063f..31473912 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -38,22 +38,11 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { - if (dtype == TF_DataType.DtInvalid) - dtype = this.dtype; - return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed); - } - - public object get_config() - { - return new - { - mean, - stddev, - seed, - dtype - }; + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; + return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed); } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index bd082214..c2e9889b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -27,32 +27,23 @@ namespace Tensorflow.Operations.Initializers #pragma warning disable CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0 private float maxval; #pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0 -#pragma warning disable CS0649 // Field 'RandomUniform.dtype' is never assigned to, and will always have its default value private TF_DataType dtype; -#pragma warning restore CS0649 // Field 'RandomUniform.dtype' is never assigned to, and will always have its default value - public RandomUniform() + public RandomUniform(TF_DataType dtype = TF_DataType.DtInvalid) { - + this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { - return random_ops.random_uniform(shape, + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; + + return random_ops.random_uniform(args.Shape, minval: minval, maxval: maxval, dtype: dtype, seed: seed); } - - public object get_config() - { - return new { - minval, - maxval, - seed, - dtype - }; - } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 7d635f0c..e656f7ea 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -34,20 +34,11 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { - return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed); - } - - public object get_config() - { - return new - { - mean = mean, - stddev = stddev, - seed = seed, - dtype = dtype.name() - }; + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; + return random_ops.truncated_normal(args.Shape, mean, stddev, dtype : dtype, seed: seed); } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index 41b6689c..12d6cb68 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -53,10 +53,13 @@ namespace Tensorflow.Operations.Initializers _uniform = uniform; } - public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { + if (args.DType == TF_DataType.DtInvalid) + args.DType = this._dtype; + float n = 0; - var (fan_in, fan_out) = _compute_fans(shape); + var (fan_in, fan_out) = _compute_fans(args.Shape); if (_mode == "FAN_IN") n = fan_in; else if (_mode == "FAN_OUT") @@ -67,13 +70,12 @@ namespace Tensorflow.Operations.Initializers if(_uniform) { var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n)); - return random_ops.random_uniform(shape, -limit, limit, - dtype, seed: _seed); + return random_ops.random_uniform(args.Shape, -limit, limit, args.DType); } else { var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n)); - return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype, + return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType, seed: _seed); } } @@ -98,18 +100,5 @@ namespace Tensorflow.Operations.Initializers return (fan_in, fan_out); } } - - public virtual object get_config() - { - return new - { - scale = _scale, - mode = _mode, - distribution = _distribution, - seed = _seed, - uniform = _uniform, - dtype = _dtype - }; - } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index bea9cf71..67e5d424 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -25,17 +25,12 @@ namespace Tensorflow.Operations.Initializers this.dtype = dtype; } - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null) + public Tensor Apply(InitializerArgs args) { - if (dtype == TF_DataType.DtInvalid) - dtype = this.dtype; + if (args.DType == TF_DataType.DtInvalid) + args.DType = this.dtype; - return array_ops.zeros(shape, dtype); - } - - public object get_config() - { - return new { dtype = dtype.name() }; + return array_ops.zeros(args.Shape, dtype); } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 6ec7e261..d88dca8c 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -955,7 +955,7 @@ namespace Tensorflow /// public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null) { - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "MatMul", name, diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index af98802f..f3442be8 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -38,7 +38,7 @@ namespace Tensorflow if (!seed2.HasValue) seed2 = 0; - if (tf.Context.executing_eagerly()) + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "RandomStandardNormal", name, @@ -98,6 +98,19 @@ namespace Tensorflow if (!seed2.HasValue) seed2 = 0; + if (tf.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "RandomUniform", name, + null, + shape, + "seed", seed, + "seed2", seed2, + "dtype", dtype); + + return results[0]; + } + var _op = tf.OpDefLib._apply_op_helper("RandomUniform", name: name, args: new { shape, dtype, seed, seed2}); diff --git a/src/TensorFlowNET.Core/Operations/random_ops.cs b/src/TensorFlowNET.Core/Operations/random_ops.cs index e92d5619..64530396 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.cs @@ -72,10 +72,11 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope => { name = scope; + var (seed1, seed2) = random_seed.get_seed(seed); var tensorShape = tensor_util.shape_tensor(shape); var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min"); var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max"); - var rnd = gen_random_ops.random_uniform(tensorShape, dtype); + var rnd = gen_random_ops.random_uniform(tensorShape, dtype, seed: seed1, seed2: seed2); return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name); }); } diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 291ad99b..fb76188b 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -162,7 +162,11 @@ namespace Tensorflow } else { - Func init_val = () => initializer.call(shape, dtype); + Func init_val = () => initializer.Apply(new InitializerArgs + { + Shape = shape, + DType = dtype + }); var variable_dtype = dtype.as_base_dtype(); v = variable_scope.default_variable_creator(init_val, diff --git a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs similarity index 87% rename from test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs rename to test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index 58c7845f..0bae85bb 100644 --- a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -11,10 +11,10 @@ using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Keras { /// - /// https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/Embedding + /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers /// - [TestClass, Ignore] - public class EmbeddingTest : GraphModeTestBase + [TestClass] + public class LayersTest : GraphModeTestBase { [TestMethod] public void Embedding()