diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index 15bcd766..db2ea1b1 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -66,20 +66,20 @@ namespace Tensorflow
///
/// Initializer capable of adapting its scale to the shape of weights tensors.
///
- ///
+ ///
///
///
///
///
///
- public IInitializer variance_scaling_initializer(float scale = 1.0f,
- string mode = "fan_in",
- string distribution = "truncated_normal",
+ public IInitializer variance_scaling_initializer(float factor = 1.0f,
+ string mode = "FAN_IN",
+ bool uniform = false,
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) => new VarianceScaling(
- scale: scale,
+ factor: factor,
mode: mode,
- distribution: distribution,
+ uniform: uniform,
seed: seed,
dtype: dtype);
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs
index c331eb7f..c11ca791 100644
--- a/src/TensorFlowNET.Core/APIs/tf.random.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.random.cs
@@ -28,21 +28,21 @@ namespace Tensorflow
///
///
///
- public Tensor random_normal(int[] shape,
+ public Tensor random_normal(TensorShape shape,
float mean = 0.0f,
float stddev = 1.0f,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
- public Tensor random_uniform(int[] shape,
+ public Tensor random_uniform(TensorShape shape,
float minval = 0,
float maxval = 1,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
- public Tensor truncated_normal(int[] shape,
+ public Tensor truncated_normal(TensorShape shape,
float mean = 0.0f,
float stddev = 1.0f,
TF_DataType dtype = TF_DataType.TF_FLOAT,
@@ -62,5 +62,8 @@ namespace Tensorflow
///
public Tensor random_shuffle(Tensor value, int? seed = null, string name = null)
=> random_ops.random_shuffle(value, seed: seed, name: name);
+
+ public void set_random_seed(int seed)
+ => ops.get_default_graph().seed = seed;
}
}
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index ab7a1703..9df8d45c 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -273,6 +273,9 @@ namespace Tensorflow
return sum;
}
+ public static double sum(IEnumerable enumerable)
+ => enumerable.Sum();
+
public static double sum(Dictionary values)
{
return sum(values.Keys);
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 1b0e7c57..0c6f6291 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -1,4 +1,6 @@
-namespace Tensorflow.Data
+using System;
+
+namespace Tensorflow.Data
{
///
/// Represents a potentially large set of elements.
@@ -11,5 +13,9 @@
///
public class DatasetV2
{
+ public static DatasetV2 from_generator()
+ {
+ throw new NotImplementedException("");
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index a162f54d..1f62295a 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -107,6 +107,16 @@ namespace Tensorflow
public bool building_function;
+ int _seed;
+ public int seed
+ {
+ get => _seed;
+ set
+ {
+ _seed = value;
+ }
+ }
+
public Graph()
{
_handle = c_api.TF_NewGraph();
diff --git a/src/TensorFlowNET.Core/Keras/Initializers.cs b/src/TensorFlowNET.Core/Keras/Initializers.cs
index 1a4fe9e4..b432cc97 100644
--- a/src/TensorFlowNET.Core/Keras/Initializers.cs
+++ b/src/TensorFlowNET.Core/Keras/Initializers.cs
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras
///
public IInitializer he_normal(int? seed = null)
{
- return new VarianceScaling(scale: 2.0f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
+ return new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
index 1a8b8ba9..f418f8a3 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
@@ -22,7 +22,10 @@ namespace Tensorflow.Operations.Initializers
string mode = "fan_avg",
string distribution = "uniform",
int? seed = null,
- TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
+ TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
+ mode: mode,
+ seed: seed,
+ dtype: dtype)
{
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
index 636b1451..c0bbcd88 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
@@ -31,17 +31,22 @@ namespace Tensorflow.Operations.Initializers
protected int? _seed;
protected TF_DataType _dtype;
- public VarianceScaling(float scale = 1.0f,
- string mode = "fan_in",
- string distribution = "truncated_normal",
+ public VarianceScaling(float factor = 2.0f,
+ string mode = "FAN_IN",
+ bool uniform = false,
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
- if (scale < 0)
+ if (!dtype.is_floating())
+ throw new TypeError("Cannot create initializer for non-floating point type.");
+ if (!new string[] { "FAN_IN", "FAN_OUT", "FAN_AVG" }.Contains(mode))
+ throw new TypeError($"Unknown {mode} %s [FAN_IN, FAN_OUT, FAN_AVG]");
+
+ if (factor < 0)
throw new ValueError("`scale` must be positive float.");
- _scale = scale;
+
+ _scale = factor;
_mode = mode;
- _distribution = distribution;
_seed = seed;
_dtype = dtype;
}