
Keras.Layers.BatchNormalization

tags/v0.8.0
Oceania2018 6 years ago
commit b2b083a5b5
7 changed files with 212 additions and 35 deletions

  1. src/TensorFlowNET.Core/APIs/tf.init.cs (+1, -0)
  2. src/TensorFlowNET.Core/APIs/tf.layers.cs (+16, -1)
  3. src/TensorFlowNET.Core/Keras/Engine/Layer.cs (+37, -3)
  4. src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs (+109, -0)
  5. src/TensorFlowNET.Core/Layers/Layer.cs (+8, -31)
  6. src/TensorFlowNET.Core/Operations/Initializers/Ones.cs (+29, -0)
  7. src/TensorFlowNET.Core/Operations/array_ops.py.cs (+12, -0)

+1 -0  src/TensorFlowNET.Core/APIs/tf.init.cs

@@ -8,6 +8,7 @@ namespace Tensorflow
    public static partial class tf
    {
        public static IInitializer zeros_initializer => new Zeros();
        public static IInitializer ones_initializer => new Ones();
        public static IInitializer glorot_uniform_initializer => new GlorotUniform();
        public static variable_scope variable_scope(string name,
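
With this one-liner, tf.ones_initializer can be handed to any API that takes an IInitializer, mirroring tf.zeros_initializer. A usage sketch (illustrative only, not part of this diff; `x` is assumed to be an existing float tensor):

// Illustrative: pass the new ones_initializer explicitly to batch normalization.
var y = tf.layers.batch_normalization(x,
    gamma_initializer: tf.ones_initializer,
    moving_variance_initializer: tf.ones_initializer);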


+16 -1  src/TensorFlowNET.Core/APIs/tf.layers.cs

@@ -83,7 +83,22 @@ namespace Tensorflow
bool renorm = false,
float renorm_momentum = 0.99f)
{
throw new NotImplementedException("batch_normalization");
var layer = new BatchNormalization(
axis: axis,
momentum: momentum,
epsilon: epsilon,
center: center,
scale: scale,
beta_initializer: beta_initializer,
gamma_initializer: gamma_initializer,
moving_mean_initializer: moving_mean_initializer,
moving_variance_initializer: moving_variance_initializer,
renorm: renorm,
renorm_momentum: renorm_momentum,
trainable: trainable,
name: name);

return layer.apply(inputs, training: training);
}
}
}
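
Rough end-to-end sketch of the new code path (illustrative, not part of the diff; `x` is assumed to be an existing float input tensor and `is_training` a scalar bool tensor):

// tf.layers.batch_normalization now constructs a Keras BatchNormalization layer
// and applies it to the inputs in a single call.
var y = tf.layers.batch_normalization(x,
    axis: -1,               // normalize over the last (channel) dimension
    momentum: 0.99f,
    epsilon: 0.001f,
    training: is_training,  // forwarded to layer.apply(inputs, training)
    name: "bn1");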


+37 -3  src/TensorFlowNET.Core/Keras/Engine/Layer.cs

@@ -18,12 +18,33 @@ namespace Tensorflow.Keras.Engine
/// the layer's weights.
/// </summary>
protected bool built;

protected bool trainable;
protected TF_DataType _dtype;
/// <summary>
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
/// </summary>
protected bool stateful;
/// <summary>
/// Provides information about which inputs are compatible with the layer.
/// </summary>
protected InputSpec input_spec;
protected bool supports_masking;
protected List<RefVariable> _trainable_weights;
protected string _name;
protected string _base_name;
protected bool _compute_previous_mask;

public Layer()
public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
this.trainable = trainable;
this._dtype = dtype;
stateful = false;
built = false;
this.supports_masking = false;
_init_set_name(name);
_trainable_weights = new List<RefVariable>();
_compute_previous_mask = false;
}

public Tensor __call__(Tensor inputs,
@@ -97,7 +118,7 @@ namespace Tensorflow.Keras.Engine

protected virtual void build(TensorShape input_shape)
{
throw new NotImplementedException("Layer.build");
}

protected virtual RefVariable add_weight(string name,
@@ -119,5 +140,18 @@ namespace Tensorflow.Keras.Engine

return variable;
}

protected virtual void _init_set_name(string name)
{
if (string.IsNullOrEmpty(name))
(_name, _base_name) = _make_unique_name();
}

protected virtual (string, string) _make_unique_name()
{
string base_name = "conv2d";
string name = base_layer_utils.unique_layer_name(base_name);
return (name, base_name);
}
}
}
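
To show what the expanded base class now gives a subclass (constructor wiring, automatic unique naming, add_weight, build), here is a hypothetical minimal layer, not part of this commit; the names ScaleLayer and alpha are invented for illustration:

using Tensorflow;
using Tensorflow.Keras.Engine;

// Hypothetical example layer: a single trainable scale factor, initialized to 1.
public class ScaleLayer : Layer
{
    private RefVariable alpha;

    // The base constructor now stores trainable/dtype, assigns a unique name,
    // and sets up the _trainable_weights list.
    public ScaleLayer(bool trainable = true, string name = null)
        : base(trainable: trainable, name: name)
    {
    }

    protected override void build(TensorShape input_shape)
    {
        // Weight created the same way BatchNormalization creates gamma/beta below.
        alpha = add_weight("alpha",
            new int[] { 1 },
            initializer: tf.ones_initializer,
            trainable: true);
    }
}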

+109 -0  src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs

@@ -0,0 +1,109 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Layers;

namespace Tensorflow.Keras.Layers
{
    public class BatchNormalization : Layer
    {
        private bool _USE_V2_BEHAVIOR = true;
        private float momentum;
        private float epsilon;
        private bool center;
        private bool scale;
        private bool renorm;
        private bool fused;
        private bool _bessels_correction_test_only;
        private int[] axis;
        private string _data_format;
        private IInitializer beta_initializer;
        private IInitializer gamma_initializer;
        private IInitializer moving_mean_initializer;
        private IInitializer moving_variance_initializer;
        private RefVariable gamma;
        private RefVariable beta;
        private RefVariable moving_mean;

        public BatchNormalization(int axis = -1,
            float momentum = 0.99f,
            float epsilon = 0.001f,
            bool center = true,
            bool scale = true,
            IInitializer beta_initializer = null,
            IInitializer gamma_initializer = null,
            IInitializer moving_mean_initializer = null,
            IInitializer moving_variance_initializer = null,
            bool renorm = false,
            float renorm_momentum = 0.99f,
            bool trainable = true,
            string name = null) : base(trainable: trainable,
                name: name)
        {
            this.axis = new int[] { axis };
            this.momentum = momentum;
            this.epsilon = epsilon;
            this.center = center;
            this.scale = scale;
            if (beta_initializer == null)
                beta_initializer = tf.zeros_initializer;
            if (gamma_initializer == null)
                gamma_initializer = tf.ones_initializer;
            if (moving_mean_initializer == null)
                moving_mean_initializer = tf.zeros_initializer;
            if (moving_variance_initializer == null)
                moving_variance_initializer = tf.ones_initializer;
            this.beta_initializer = beta_initializer;
            this.gamma_initializer = gamma_initializer;
            this.moving_mean_initializer = moving_mean_initializer;
            this.moving_variance_initializer = moving_variance_initializer;
            this.renorm = renorm;
            this.fused = true;
            this.supports_masking = true;
            this._bessels_correction_test_only = true;
        }

        protected override void build(TensorShape input_shape)
        {
            var ndims = input_shape.NDim;
            foreach (var (idx, x) in Python.enumerate(axis))
                if (x < 0)
                    axis[idx] = ndims + x;

            if (fused)
                if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
                    _data_format = "NHWC";

            var param_dtype = _dtype == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : _dtype;
            var param_shape = new int[] { input_shape.Dimensions[axis[0]] };

            if (scale)
                gamma = add_weight("gamma",
                    param_shape,
                    dtype: param_dtype,
                    initializer: gamma_initializer,
                    trainable: true);
            else
                throw new NotImplementedException("add_weight gamma");

            if (center)
                beta = add_weight("beta",
                    param_shape,
                    dtype: param_dtype,
                    initializer: beta_initializer,
                    trainable: true);
            else
                throw new NotImplementedException("add_weight beta");

            if (_scope != null)
            {
            }

            moving_mean = add_weight("moving_mean",
                param_shape,
                dtype: param_dtype);
        }
    }
}
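
For context, how build() resolves the default axis on a typical image batch (illustrative sketch, not from this diff; `x` is assumed to be a 4-D NHWC float tensor):

// axis: -1 on a rank-4 input becomes axis 3 (channels) inside build(),
// so param_shape = { channels } and, since fused is true, _data_format = "NHWC".
var bn = new Tensorflow.Keras.Layers.BatchNormalization(axis: -1, momentum: 0.99f, epsilon: 0.001f);
var y = bn.apply(x);   // apply() (from Tensorflow.Layers.Layer) forwards to __call__,
                       // which is expected to invoke build() on first use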

+8 -31  src/TensorFlowNET.Core/Layers/Layer.cs

@@ -7,39 +7,27 @@ namespace Tensorflow.Layers
{
public class Layer : Keras.Engine.Layer
{
protected bool trainable;
protected string _name;
protected TF_DataType _dtype;
protected Graph _graph;
protected string _base_name;
protected VariableScope _scope;
protected VariableScope _current_scope;
/// <summary>
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
/// </summary>
protected bool stateful;
/// <summary>
/// Provides information about which inputs are compatible with the layer.
/// </summary>
protected InputSpec input_spec;
protected bool supports_masking;
protected bool? _reuse;
protected bool _use_resource_variables;
protected bool _keras_style;

public Layer(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool? _reuse = null) : base()
bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype)
{
this.trainable = trainable;
this.stateful = false;
this._use_resource_variables = false;
this._reuse = _reuse;
this.built = false;
this.supports_masking = false;
_init_set_name(name);
_keras_style = false;
}

public Tensor apply(Tensor inputs)
public virtual Tensor apply(Tensor inputs, Tensor training = null)
{
return __call__(inputs);
}
@@ -126,18 +114,7 @@ namespace Tensorflow.Layers
});
}

private void _init_set_name(string name)
{
if (string.IsNullOrEmpty(name))
(_name, _base_name) = _make_unique_name();
}

private (string, string) _make_unique_name()
{
string base_name = "conv2d";
string name = base_layer_utils.unique_layer_name(base_name);
return (name, base_name);
}

protected override string _name_scope()
{


+29 -0  src/TensorFlowNET.Core/Operations/Initializers/Ones.cs

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Initializers
{
    public class Ones : IInitializer
    {
        private TF_DataType dtype;

        public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
        {
            this.dtype = dtype;
        }

        public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (dtype == TF_DataType.DtInvalid)
                dtype = this.dtype;

            return array_ops.ones(shape.Dimensions, dtype);
        }

        public object get_config()
        {
            return new { dtype = dtype.name() };
        }
    }
}
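
A quick sketch of how the initializer behaves (illustrative; assumes TensorShape exposes a params-int constructor):

IInitializer init = new Tensorflow.Operations.Initializers.Ones();
var a = init.call(new TensorShape(2, 3));   // 2x3 tensor of 1s, TF_FLOAT by default
// Passing DtInvalid (the default) falls back to the dtype captured in the constructor.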

+12 -0  src/TensorFlowNET.Core/Operations/array_ops.py.cs

@@ -92,6 +92,18 @@ namespace Tensorflow
            });
        }

        public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
        {
            dtype = dtype.as_base_dtype();
            return with(ops.name_scope(name, "ones", new { dims }), scope =>
            {
                name = scope;
                var shape = ops.convert_to_tensor(dims, dtype: TF_DataType.TF_INT32);
                var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
                return output;
            });
        }

        public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
        {
            if (x == null && y == null)
