From c909310838682ecd35f7504f769ae03fad39a6c1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 16 Nov 2019 13:23:44 -0600 Subject: [PATCH] variance_scaling_initializer --- src/TensorFlowNET.Core/APIs/tf.init.cs | 12 ++++++------ src/TensorFlowNET.Core/APIs/tf.random.cs | 9 ++++++--- src/TensorFlowNET.Core/Binding.Util.cs | 3 +++ src/TensorFlowNET.Core/Data/DatasetV2.cs | 8 +++++++- src/TensorFlowNET.Core/Graphs/Graph.cs | 10 ++++++++++ src/TensorFlowNET.Core/Keras/Initializers.cs | 2 +- .../Operations/Initializers/GlorotUniform.cs | 5 ++++- .../Operations/Initializers/VarianceScaling.cs | 17 +++++++++++------ 8 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 15bcd766..db2ea1b1 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -66,20 +66,20 @@ namespace Tensorflow /// /// Initializer capable of adapting its scale to the shape of weights tensors. /// - /// + /// /// /// /// /// /// - public IInitializer variance_scaling_initializer(float scale = 1.0f, - string mode = "fan_in", - string distribution = "truncated_normal", + public IInitializer variance_scaling_initializer(float factor = 1.0f, + string mode = "FAN_IN", + bool uniform = false, int? seed = null, TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling( - scale: scale, + factor: factor, mode: mode, - distribution: distribution, + uniform: uniform, seed: seed, dtype: dtype); } diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index c331eb7f..c11ca791 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -28,21 +28,21 @@ namespace Tensorflow /// /// /// - public Tensor random_normal(int[] shape, + public Tensor random_normal(TensorShape shape, float mean = 0.0f, float stddev = 1.0f, TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); - public Tensor random_uniform(int[] shape, + public Tensor random_uniform(TensorShape shape, float minval = 0, float maxval = 1, TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); - public Tensor truncated_normal(int[] shape, + public Tensor truncated_normal(TensorShape shape, float mean = 0.0f, float stddev = 1.0f, TF_DataType dtype = TF_DataType.TF_FLOAT, @@ -62,5 +62,8 @@ namespace Tensorflow /// public Tensor random_shuffle(Tensor value, int? seed = null, string name = null) => random_ops.random_shuffle(value, seed: seed, name: name); + + public void set_random_seed(int seed) + => ops.get_default_graph().seed = seed; } } diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index ab7a1703..9df8d45c 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -273,6 +273,9 @@ namespace Tensorflow return sum; } + public static double sum(IEnumerable enumerable) + => enumerable.Sum(); + public static double sum(Dictionary values) { return sum(values.Keys); diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 1b0e7c57..0c6f6291 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -1,4 +1,6 @@ -namespace Tensorflow.Data +using System; + +namespace Tensorflow.Data { /// /// Represents a potentially large set of elements. @@ -11,5 +13,9 @@ /// public class DatasetV2 { + public static DatasetV2 from_generator() + { + throw new NotImplementedException(""); + } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a162f54d..1f62295a 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -107,6 +107,16 @@ namespace Tensorflow public bool building_function; + int _seed; + public int seed + { + get => _seed; + set + { + _seed = value; + } + } + public Graph() { _handle = c_api.TF_NewGraph(); diff --git a/src/TensorFlowNET.Core/Keras/Initializers.cs b/src/TensorFlowNET.Core/Keras/Initializers.cs index 1a4fe9e4..b432cc97 100644 --- a/src/TensorFlowNET.Core/Keras/Initializers.cs +++ b/src/TensorFlowNET.Core/Keras/Initializers.cs @@ -27,7 +27,7 @@ namespace Tensorflow.Keras /// public IInitializer he_normal(int? seed = null) { - return new VarianceScaling(scale: 2.0f, mode: "fan_in", distribution: "truncated_normal", seed: seed); + return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed); } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index 1a8b8ba9..f418f8a3 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -22,7 +22,10 @@ namespace Tensorflow.Operations.Initializers string mode = "fan_avg", string distribution = "uniform", int? seed = null, - TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) + TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale, + mode: mode, + seed: seed, + dtype: dtype) { } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index 636b1451..c0bbcd88 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -31,17 +31,22 @@ namespace Tensorflow.Operations.Initializers protected int? _seed; protected TF_DataType _dtype; - public VarianceScaling(float scale = 1.0f, - string mode = "fan_in", - string distribution = "truncated_normal", + public VarianceScaling(float factor = 2.0f, + string mode = "FAN_IN", + bool uniform = false, int? seed = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { - if (scale < 0) + if (!dtype.is_floating()) + throw new TypeError("Cannot create initializer for non-floating point type."); + if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode)) + throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]"); + + if (factor < 0) throw new ValueError("`scale` must be positive float."); - _scale = scale; + + _scale = factor; _mode = mode; - _distribution = distribution; _seed = seed; _dtype = dtype; }