From 958fdb87913188c46cdc28bf54911b16b51717b7 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 8 Mar 2019 23:31:13 -0600 Subject: [PATCH] gen_random_ops.truncated_normal --- src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 12 +++++++++- src/TensorFlowNET.Core/Keras/Initializers.cs | 2 +- src/TensorFlowNET.Core/Keras/backend.cs | 14 +++++++++++ src/TensorFlowNET.Core/Layers/Layer.cs | 2 +- .../Initializers/VarianceScaling.cs | 3 ++- .../Operations/gen_random_ops.py.cs | 23 +++++++++++++++++++ .../Operations/random_ops.py.cs | 21 +++++++++++++++++ .../Checkpointable/CheckpointableBase.cs | 11 ++++++++- .../Variables/RefVariable.cs | 2 +- 9 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/backend.cs diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 3c5825e2..e35343d1 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -19,6 +19,13 @@ namespace Tensorflow.Keras.Engine /// protected bool built; + protected List _trainable_weights; + + public Layer() + { + _trainable_weights = new List(); + } + public Tensor __call__(Tensor inputs, VariableScope scope = null) { @@ -36,6 +43,7 @@ namespace Tensorflow.Keras.Engine if (!built) { _maybe_build(inputs); + built = true; } }); @@ -65,13 +73,15 @@ namespace Tensorflow.Keras.Engine bool? trainable = null, Func getter = null) { - _add_variable_with_custom_getter(name, + var variable = _add_variable_with_custom_getter(name, shape, dtype: dtype, getter: getter, overwrite: true, initializer: initializer, trainable: trainable.Value); + backend.track_variable(variable); + _trainable_weights.Add(variable); } } } diff --git a/src/TensorFlowNET.Core/Keras/Initializers.cs b/src/TensorFlowNET.Core/Keras/Initializers.cs index cea77ae9..27cd384e 100644 --- a/src/TensorFlowNET.Core/Keras/Initializers.cs +++ b/src/TensorFlowNET.Core/Keras/Initializers.cs @@ -14,7 +14,7 @@ namespace Tensorflow.Keras /// public IInitializer he_normal(int? seed = null) { - return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed); + return new VarianceScaling(scale: 2.0f, mode: "fan_in", distribution: "truncated_normal", seed: seed); } } } diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs new file mode 100644 index 00000000..51d74c04 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/backend.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public class backend + { + public static void track_variable(RefVariable v) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index b8e618f1..afe17dbb 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -29,7 +29,7 @@ namespace Tensorflow.Layers public Layer(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, - bool? _reuse = null) + bool? _reuse = null) : base() { this.trainable = trainable; this.stateful = false; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index 7a8d9af8..16149261 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -43,7 +43,8 @@ namespace Tensorflow.Operations.Initializers if (_distribution == "normal" || _distribution == "truncated_normal") { - throw new NotImplementedException("truncated_normal"); + float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f; + return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed); } else if (_distribution == "untruncated_normal") { diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs index 354177f9..5e9689ab 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs @@ -53,5 +53,28 @@ namespace Tensorflow return _op.outputs[0]; } + + /// + /// Outputs random values from a truncated normal distribution. + /// + /// + /// + /// + /// + /// + /// + public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) + { + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + + var _op = _op_def_lib._apply_op_helper("TruncatedNormal", + name: name, + args: new { shape, dtype, seed, seed2 }); + + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs index 62d4346a..2162c1a4 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs @@ -64,6 +64,27 @@ namespace Tensorflow }); } + public static Tensor truncated_normal(int[] shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) + { + return with(ops.name_scope(name, "truncated_normal", new { shape, mean, stddev }), scope => + { + name = scope; + var shape_tensor = _ShapeTensor(shape); + var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean"); + var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev"); + var (seed1, seed2) = random_seed.get_seed(seed); + var rnd = gen_random_ops.truncated_normal(shape_tensor, dtype, seed: seed1, seed2: seed2); + var mul = rnd * stddev_tensor; + var value = math_ops.add(mul, mean_tensor, name: name); + return value; + }); + } + private static Tensor _ShapeTensor(int[] shape) { return ops.convert_to_tensor(shape, name: "shape"); diff --git a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs index 558a7177..7a61ec5b 100644 --- a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs +++ b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs @@ -19,7 +19,16 @@ namespace Tensorflow bool trainable = false) { var new_variable = getter(name, shape, dtype, initializer, trainable); - throw new NotImplementedException("_add_variable_with_custom_getter"); + if (!overwrite || new_variable is RefVariable) + return _track_checkpointable(new_variable, name: name, + overwrite: overwrite); + else + return new_variable; + } + + protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false) + { + return checkpointable; } } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index df068477..05982d95 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -136,8 +136,8 @@ namespace Tensorflow { _initial_value = (initial_value as Func)(); _initial_value = ops.convert_to_tensor(_initial_value, name: "initial_value", dtype: dtype); - _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); }); + _variable = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); } // Or get the initial value from a Tensor or Python object. else