From b2b083a5b52982397a51b806fe109d0408fd2374 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Sun, 10 Mar 2019 23:21:14 -0500
Subject: [PATCH] Keras.Layers.BatchNormalization

---
Reviewer notes (git am ignores text between "---" and the first "diff --git"):
 * Keras.Engine.Layer._init_set_name dropped a caller-supplied name, leaving
   _name/_base_name unset; it now stores the name.
 * _make_unique_name hard-coded "conv2d" for every layer type; it now
   snake_cases the runtime type name ("Conv2D" still yields "conv2d").
 * BatchNormalization.build created moving_mean without its initializer and
   without trainable: false; both are now passed.
 * Not addressed here: Layers.Layer.apply accepts `training` but does not
   forward it; moving_variance is never created; renorm_momentum is unused;
   Keras.Engine.Layer.build now throws for layers that do not override it.
 * NOTE(review): "<summary>" tags and the List<RefVariable> type argument were
   missing from the copy under review and are reconstructed - confirm against
   the repository. The "index" hashes are those of the original commit.

 src/TensorFlowNET.Core/APIs/tf.init.cs        |   1 +
 src/TensorFlowNET.Core/APIs/tf.layers.cs      |  17 ++-
 src/TensorFlowNET.Core/Keras/Engine/Layer.cs  |  54 +++++++-
 .../Keras/Layers/BatchNormalization.cs        | 119 ++++++++++++++++++
 src/TensorFlowNET.Core/Layers/Layer.cs        |  39 ++----
 .../Operations/Initializers/Ones.cs           |  32 +++++
 .../Operations/array_ops.py.cs                |  15 +++
 7 files changed, 242 insertions(+), 35 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
 create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/Ones.cs

diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index d2d59e52..fb5421c3 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -8,6 +8,7 @@ namespace Tensorflow
     public static partial class tf
     {
         public static IInitializer zeros_initializer => new Zeros();
+        public static IInitializer ones_initializer => new Ones();
         public static IInitializer glorot_uniform_initializer => new GlorotUniform();
 
         public static variable_scope variable_scope(string name,
diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index 8532a24d..dc883a75 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -83,7 +83,22 @@ namespace Tensorflow
                 bool renorm = false,
                 float renorm_momentum = 0.99f)
             {
-                throw new NotImplementedException("batch_normalization");
+                var layer = new BatchNormalization(
+                    axis: axis,
+                    momentum: momentum,
+                    epsilon: epsilon,
+                    center: center,
+                    scale: scale,
+                    beta_initializer: beta_initializer,
+                    gamma_initializer: gamma_initializer,
+                    moving_mean_initializer: moving_mean_initializer,
+                    moving_variance_initializer: moving_variance_initializer,
+                    renorm: renorm,
+                    renorm_momentum: renorm_momentum,
+                    trainable: trainable,
+                    name: name);
+
+                return layer.apply(inputs, training: training);
             }
         }
     }
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index 7fb1a12d..48955d11 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -18,12 +18,33 @@ namespace Tensorflow.Keras.Engine
         /// the layer's weights.
         /// </summary>
         protected bool built;
-
+        protected bool trainable;
+        protected TF_DataType _dtype;
+        /// <summary>
+        /// A stateful layer is a layer whose updates are run during inference too,
+        /// for instance stateful RNNs.
+        /// </summary>
+        protected bool stateful;
+        /// <summary>
+        /// Provides information about which inputs are compatible with the layer.
+        /// </summary>
+        protected InputSpec input_spec;
+        protected bool supports_masking;
         protected List<RefVariable> _trainable_weights;
+        protected string _name;
+        protected string _base_name;
+        protected bool _compute_previous_mask;
 
-        public Layer()
+        public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
         {
+            this.trainable = trainable;
+            this._dtype = dtype;
+            stateful = false;
+            built = false;
+            this.supports_masking = false;
+            _init_set_name(name);
             _trainable_weights = new List<RefVariable>();
+            _compute_previous_mask = false;
         }
 
         public Tensor __call__(Tensor inputs,
@@ -97,7 +118,7 @@ namespace Tensorflow.Keras.Engine
 
         protected virtual void build(TensorShape input_shape)
         {
-
+            throw new NotImplementedException("Layer.build");
         }
 
         protected virtual RefVariable add_weight(string name,
@@ -119,5 +140,32 @@ namespace Tensorflow.Keras.Engine
 
             return variable;
         }
+
+        /// <summary>
+        /// Sets the layer name: the caller-supplied name when one is given,
+        /// otherwise a generated unique name.
+        /// </summary>
+        protected virtual void _init_set_name(string name)
+        {
+            if (string.IsNullOrEmpty(name))
+                (_name, _base_name) = _make_unique_name();
+            else
+                (_name, _base_name) = (name, name);
+        }
+
+        /// <summary>
+        /// Generates a unique name from the snake_cased runtime type name,
+        /// e.g. BatchNormalization gives "batch_normalization".
+        /// </summary>
+        protected virtual (string, string) _make_unique_name()
+        {
+            // Two-step snake_case conversion; "Conv2D" still maps to "conv2d".
+            string base_name = System.Text.RegularExpressions.Regex.Replace(
+                GetType().Name, "(.)([A-Z][a-z0-9]+)", "${1}_${2}");
+            base_name = System.Text.RegularExpressions.Regex.Replace(
+                base_name, "([a-z])([A-Z])", "${1}_${2}").ToLowerInvariant();
+            string name = base_layer_utils.unique_layer_name(base_name);
+            return (name, base_name);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
new file mode 100644
index 00000000..80d3d655
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -0,0 +1,119 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Layers;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Batch normalization layer. This change only adds the constructor and
+    /// the weight creation in build().
+    /// </summary>
+    public class BatchNormalization : Layer
+    {
+        private bool _USE_V2_BEHAVIOR = true;
+        private float momentum;
+        private float epsilon;
+        private bool center;
+        private bool scale;
+        private bool renorm;
+        private bool fused;
+        private bool _bessels_correction_test_only;
+        private int[] axis;
+        private string _data_format;
+        private IInitializer beta_initializer;
+        private IInitializer gamma_initializer;
+        private IInitializer moving_mean_initializer;
+        private IInitializer moving_variance_initializer;
+        private RefVariable gamma;
+        private RefVariable beta;
+        private RefVariable moving_mean;
+
+        public BatchNormalization(int axis = -1,
+            float momentum = 0.99f,
+            float epsilon = 0.001f,
+            bool center = true,
+            bool scale = true,
+            IInitializer beta_initializer = null,
+            IInitializer gamma_initializer = null,
+            IInitializer moving_mean_initializer = null,
+            IInitializer moving_variance_initializer = null,
+            bool renorm = false,
+            float renorm_momentum = 0.99f,
+            bool trainable = true,
+            string name = null) : base(trainable: trainable,
+                name: name)
+        {
+            this.axis = new int[] { axis };
+            this.momentum = momentum;
+            this.epsilon = epsilon;
+            this.center = center;
+            this.scale = scale;
+            if (beta_initializer == null)
+                beta_initializer = tf.zeros_initializer;
+            if (gamma_initializer == null)
+                gamma_initializer = tf.ones_initializer;
+            if (moving_mean_initializer == null)
+                moving_mean_initializer = tf.zeros_initializer;
+            if (moving_variance_initializer == null)
+                moving_variance_initializer = tf.ones_initializer;
+            this.beta_initializer = beta_initializer;
+            this.gamma_initializer = gamma_initializer;
+            this.moving_mean_initializer = moving_mean_initializer;
+            this.moving_variance_initializer = moving_variance_initializer;
+            this.renorm = renorm;
+            this.fused = true;
+            this.supports_masking = true;
+            this._bessels_correction_test_only = true;
+        }
+
+        /// <summary>
+        /// Resolves negative axes against the input rank and creates the layer weights.
+        /// </summary>
+        protected override void build(TensorShape input_shape)
+        {
+            var ndims = input_shape.NDim;
+            foreach (var (idx, x) in Python.enumerate(axis))
+                if (x < 0)
+                    axis[idx] = ndims + x;
+
+            if (fused)
+                if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
+                    _data_format = "NHWC";
+
+            var param_dtype = _dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : _dtype;
+            var param_shape = new int[] { input_shape.Dimensions[axis[0]] };
+
+            if (scale)
+                gamma = add_weight("gamma",
+                    param_shape,
+                    dtype: param_dtype,
+                    initializer: gamma_initializer,
+                    trainable: true);
+            else
+                throw new NotImplementedException("add_weight gamma");
+
+            if (center)
+                beta = add_weight("beta",
+                    param_shape,
+                    dtype: param_dtype,
+                    initializer: beta_initializer,
+                    trainable: true);
+            else
+                throw new NotImplementedException("add_weight beta");
+
+            if(_scope != null)
+            {
+
+            }
+
+            // The moving mean is a running statistic, so it is created non-trainable.
+            moving_mean = add_weight("moving_mean",
+                param_shape,
+                dtype: param_dtype,
+                initializer: moving_mean_initializer,
+                trainable: false);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index 3132048e..997153be 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -7,39 +7,27 @@ namespace Tensorflow.Layers
 {
     public class Layer : Keras.Engine.Layer
     {
-        protected bool trainable;
-        protected string _name;
-        protected TF_DataType _dtype;
         protected Graph _graph;
-        protected string _base_name;
+
         protected VariableScope _scope;
         protected VariableScope _current_scope;
-        /// <summary>
-        /// A stateful layer is a layer whose updates are run during inference too,
-        /// for instance stateful RNNs.
-        /// </summary>
-        protected bool stateful;
-        /// <summary>
-        /// Provides information about which inputs are compatible with the layer.
-        /// </summary>
-        protected InputSpec input_spec;
-        protected bool supports_masking;
+
         protected bool? _reuse;
+        protected bool _use_resource_variables;
+        protected bool _keras_style;
 
         public Layer(bool trainable = true,
             string name = null,
             TF_DataType dtype = TF_DataType.DtInvalid,
-            bool? _reuse = null) : base()
+            bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype)
         {
-            this.trainable = trainable;
-            this.stateful = false;
+            this._use_resource_variables = false;
             this._reuse = _reuse;
             this.built = false;
-            this.supports_masking = false;
-            _init_set_name(name);
+            _keras_style = false;
         }
 
-        public Tensor apply(Tensor inputs)
+        public virtual Tensor apply(Tensor inputs, Tensor training = null)
         {
             return __call__(inputs);
         }
@@ -126,18 +114,7 @@ namespace Tensorflow.Layers
             });
         }
 
-        private void _init_set_name(string name)
-        {
-            if (string.IsNullOrEmpty(name))
-                (_name, _base_name) = _make_unique_name();
-        }
 
-        private (string, string) _make_unique_name()
-        {
-            string base_name = "conv2d";
-            string name = base_layer_utils.unique_layer_name(base_name);
-            return (name, base_name);
-        }
 
         protected override string _name_scope()
         {
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
new file mode 100644
index 00000000..750c4ec8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations.Initializers
+{
+    /// <summary>
+    /// Initializer that produces tensors filled with ones.
+    /// </summary>
+    public class Ones : IInitializer
+    {
+        private TF_DataType dtype;
+
+        public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
+        {
+            this.dtype = dtype;
+        }
+
+        public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
+        {
+            if (dtype == TF_DataType.DtInvalid)
+                dtype = this.dtype;
+
+            return array_ops.ones(shape.Dimensions, dtype);
+        }
+
+        public object get_config()
+        {
+            return new { dtype = dtype.name() };
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 0eaf6251..ffdc0d8f 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -92,6 +92,21 @@ namespace Tensorflow
             });
         }
 
+        /// <summary>
+        /// Creates a tensor of the given dimensions with every element set to one.
+        /// </summary>
+        public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
+        {
+            dtype = dtype.as_base_dtype();
+            return with(ops.name_scope(name, "ones", new { dims }), scope =>
+            {
+                name = scope;
+                var shape = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32);
+                var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
+                return output;
+            });
+        }
+
         public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
         {
             if( x == null && y == null)