diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 6c404276..7fbdf229 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -147,6 +147,7 @@ namespace Tensorflow
///
public Graph as_default()
{
+ tf.Context.graph_mode();
return ops.set_default_graph(this);
}
@@ -490,6 +491,7 @@ namespace Tensorflow
protected override void DisposeManagedResources()
{
+ tf.Context.eager_mode();
ops.default_graph_stack.remove(this);
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index 7ddb79c7..fd83ae7e 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -18,9 +18,11 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
+using Tensorflow.Contexts;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
+using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using static Tensorflow.Binding;
@@ -46,7 +48,7 @@ namespace Tensorflow.Keras.Engine
protected bool built;
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;
-
+
///
/// A stateful layer is a layer whose updates are run during inference too,
/// for instance stateful RNNs.
@@ -110,8 +112,11 @@ namespace Tensorflow.Keras.Engine
///
///
///
- public Tensor Apply(Tensor[] inputs, bool is_training = false)
+ public Tensor[] Apply(Tensor[] inputs, bool is_training = false)
{
+ var input = inputs[0];
+ Tensor[] outputs = null;
+
callContext = callContext ?? new ThreadLocal()
{
Value = new CallContext()
@@ -120,7 +125,7 @@ namespace Tensorflow.Keras.Engine
using var ctxManager = CallContext.enter();
string nameScope = "";
- if (tf.Context.executing_eagerly())
+ if (tf.executing_eagerly())
{
nameScope = name;
}
@@ -129,15 +134,21 @@ namespace Tensorflow.Keras.Engine
throw new NotImplementedException("");
}
+ using var graph = tf.keras.backend.get_graph().as_default();
+
tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs);
- call(inputs, is_training: is_training);
+ outputs = call(inputs, is_training: is_training);
+
+ (input, outputs) = _set_connectivity_metadata_(input, outputs);
+ _handle_activity_regularization(inputs[0], outputs);
+ _set_mask_metadata(inputs[0], outputs, null);
});
- throw new NotImplementedException("");
+ return outputs;
}
[Obsolete("User Apply()")]
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
index 90109c1e..c6485427 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Dense.cs
@@ -30,8 +30,9 @@ namespace Tensorflow.Keras.Layers
public class Dense : Layer
{
DenseArgs args;
- protected IVariableV1 kernel;
- protected IVariableV1 bias;
+ IVariableV1 kernel;
+ IVariableV1 bias;
+ Activation activation => args.Activation;
public Dense(DenseArgs args) :
base(args)
@@ -74,15 +75,15 @@ namespace Tensorflow.Keras.Layers
}
else
{
- outputs = gen_math_ops.mat_mul(inputs[0], kernel.Handle);
+ outputs = gen_math_ops.mat_mul(inputs[0], kernel.AsTensor());
}
if (args.UseBias)
outputs = tf.nn.bias_add(outputs, bias);
- //if (args.Activation != null)
- //outputs = args.Activation.Activate(outputs);
+ if (args.Activation != null)
+ outputs = activation(outputs);
- return new[] { outputs, outputs };
+ return new[] { outputs };
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
index 18325a49..e8d16820 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
@@ -36,7 +36,11 @@ namespace Tensorflow.Keras.Utils
ops.init_scope();
- Func init_val = () => args.Initializer.call(args.Shape, dtype: args.DType);
+ Func init_val = () => args.Initializer.Apply(new InitializerArgs
+ {
+ Shape = args.Shape,
+ DType = args.DType
+ });
var variable_dtype = args.DType.as_base_dtype();
var v = tf.Variable(init_val,
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
index 708d9db6..cf230978 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
@@ -29,27 +29,18 @@ namespace Tensorflow.Operations.Initializers
_verify_shape = verify_shape;
}
- public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
- if (dtype == TF_DataType.DtInvalid)
- dtype = this.dtype;
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this.dtype;
- if (!verify_shape.HasValue)
- verify_shape = _verify_shape;
+ if (!args.VerifyShape.HasValue)
+ args.VerifyShape = _verify_shape;
- return constant_op._constant_impl(value, dtype, shape,
+ return constant_op._constant_impl(value, args.DType, args.Shape,
name: "Const",
- verify_shape: verify_shape.Value,
+ verify_shape: args.VerifyShape.Value,
allow_broadcast: false);
}
-
- public object get_config()
- {
- return new
- {
- value,
- dtype = dtype.name()
- };
- }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
index 5d38aa7f..8e59370a 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
@@ -30,18 +30,5 @@ namespace Tensorflow.Operations.Initializers
{
}
-
-#pragma warning disable CS0114 // Member hides inherited member; missing override keyword
- public object get_config()
-#pragma warning restore CS0114 // Member hides inherited member; missing override keyword
- {
- return new
- {
- scale = _scale,
- mode = _mode,
- seed = _seed,
- dtype = _dtype
- };
- }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
index 0ac0865f..50d4d503 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
@@ -18,7 +18,6 @@ namespace Tensorflow
{
public interface IInitializer
{
- Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null);
- object get_config();
+ Tensor Apply(InitializerArgs args);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
new file mode 100644
index 00000000..561664bc
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Initializers/InitializerArgs.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class InitializerArgs
+ {
+ public TensorShape Shape { get; set; }
+ public TF_DataType DType { get; set; }
+ public bool? VerifyShape { get; set; } = null;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
index 83e5b57d..02d3c93b 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
@@ -25,17 +25,12 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}
- public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
- if (dtype == TF_DataType.DtInvalid)
- dtype = this.dtype;
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this.dtype;
- return array_ops.ones(shape.dims, dtype);
- }
-
- public object get_config()
- {
- return new { dtype = dtype.name() };
+ return array_ops.ones(args.Shape, dtype);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
index a3e2063f..31473912 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
@@ -38,22 +38,11 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}
- public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
- if (dtype == TF_DataType.DtInvalid)
- dtype = this.dtype;
- return random_ops.random_normal(shape, mean, stddev, dtype, seed: seed);
- }
-
- public object get_config()
- {
- return new
- {
- mean,
- stddev,
- seed,
- dtype
- };
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this.dtype;
+ return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
index bd082214..c2e9889b 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
@@ -27,32 +27,23 @@ namespace Tensorflow.Operations.Initializers
#pragma warning disable CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0
private float maxval;
#pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0
-#pragma warning disable CS0649 // Field 'RandomUniform.dtype' is never assigned to, and will always have its default value
private TF_DataType dtype;
-#pragma warning restore CS0649 // Field 'RandomUniform.dtype' is never assigned to, and will always have its default value
- public RandomUniform()
+ public RandomUniform(TF_DataType dtype = TF_DataType.DtInvalid)
{
-
+ this.dtype = dtype;
}
- public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
- return random_ops.random_uniform(shape,
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this.dtype;
+
+ return random_ops.random_uniform(args.Shape,
minval: minval,
maxval: maxval,
dtype: dtype,
seed: seed);
}
-
- public object get_config()
- {
- return new {
- minval,
- maxval,
- seed,
- dtype
- };
- }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
index 7d635f0c..e656f7ea 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
@@ -34,20 +34,11 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}
- public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
- return random_ops.truncated_normal(shape, mean, stddev, dtype : dtype, seed: seed);
- }
-
- public object get_config()
- {
- return new
- {
- mean = mean,
- stddev = stddev,
- seed = seed,
- dtype = dtype.name()
- };
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this.dtype;
+ return random_ops.truncated_normal(args.Shape, mean, stddev, dtype : dtype, seed: seed);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
index 41b6689c..12d6cb68 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
@@ -53,10 +53,13 @@ namespace Tensorflow.Operations.Initializers
_uniform = uniform;
}
- public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this._dtype;
+
float n = 0;
- var (fan_in, fan_out) = _compute_fans(shape);
+ var (fan_in, fan_out) = _compute_fans(args.Shape);
if (_mode == "FAN_IN")
n = fan_in;
else if (_mode == "FAN_OUT")
@@ -67,13 +70,12 @@ namespace Tensorflow.Operations.Initializers
if(_uniform)
{
var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
- return random_ops.random_uniform(shape, -limit, limit,
- dtype, seed: _seed);
+ return random_ops.random_uniform(args.Shape, -limit, limit, args.DType);
}
else
{
var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
- return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype,
+ return random_ops.truncated_normal(args.Shape, 0.0f, trunc_stddev, args.DType,
seed: _seed);
}
}
@@ -98,18 +100,5 @@ namespace Tensorflow.Operations.Initializers
return (fan_in, fan_out);
}
}
-
- public virtual object get_config()
- {
- return new
- {
- scale = _scale,
- mode = _mode,
- distribution = _distribution,
- seed = _seed,
- uniform = _uniform,
- dtype = _dtype
- };
- }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
index bea9cf71..67e5d424 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
@@ -25,17 +25,12 @@ namespace Tensorflow.Operations.Initializers
this.dtype = dtype;
}
- public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, bool? verify_shape = null)
+ public Tensor Apply(InitializerArgs args)
{
- if (dtype == TF_DataType.DtInvalid)
- dtype = this.dtype;
+ if (args.DType == TF_DataType.DtInvalid)
+ args.DType = this.dtype;
- return array_ops.zeros(shape, dtype);
- }
-
- public object get_config()
- {
- return new { dtype = dtype.name() };
+ return array_ops.zeros(args.Shape, dtype);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 6ec7e261..d88dca8c 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -955,7 +955,7 @@ namespace Tensorflow
///
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null)
{
- if (tf.Context.executing_eagerly())
+ if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MatMul", name,
diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs
index af98802f..f3442be8 100644
--- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs
@@ -38,7 +38,7 @@ namespace Tensorflow
if (!seed2.HasValue)
seed2 = 0;
- if (tf.Context.executing_eagerly())
+ if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"RandomStandardNormal", name,
@@ -98,6 +98,19 @@ namespace Tensorflow
if (!seed2.HasValue)
seed2 = 0;
+ if (tf.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "RandomUniform", name,
+ null,
+ shape,
+ "seed", seed,
+ "seed2", seed2,
+ "dtype", dtype);
+
+ return results[0];
+ }
+
var _op = tf.OpDefLib._apply_op_helper("RandomUniform",
name: name,
args: new { shape, dtype, seed, seed2});
diff --git a/src/TensorFlowNET.Core/Operations/random_ops.cs b/src/TensorFlowNET.Core/Operations/random_ops.cs
index e92d5619..64530396 100644
--- a/src/TensorFlowNET.Core/Operations/random_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/random_ops.cs
@@ -72,10 +72,11 @@ namespace Tensorflow
return tf_with(ops.name_scope(name, "random_uniform", new { shape, minval, maxval }), scope =>
{
name = scope;
+ var (seed1, seed2) = random_seed.get_seed(seed);
var tensorShape = tensor_util.shape_tensor(shape);
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
- var rnd = gen_random_ops.random_uniform(tensorShape, dtype);
+ var rnd = gen_random_ops.random_uniform(tensorShape, dtype, seed: seed1, seed2: seed2);
return math_ops.add(rnd * (maxTensor - minTensor), minTensor, name: name);
});
}
diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
index 291ad99b..fb76188b 100644
--- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs
+++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
@@ -162,7 +162,11 @@ namespace Tensorflow
}
else
{
- Func init_val = () => initializer.call(shape, dtype);
+ Func init_val = () => initializer.Apply(new InitializerArgs
+ {
+ Shape = shape,
+ DType = dtype
+ });
var variable_dtype = dtype.as_base_dtype();
v = variable_scope.default_variable_creator(init_val,
diff --git a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
similarity index 87%
rename from test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs
rename to test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
index 58c7845f..0bae85bb 100644
--- a/test/TensorFlowNET.UnitTest/Keras/EmbeddingTest.cs
+++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
@@ -11,10 +11,10 @@ using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest.Keras
{
///
- /// https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/Embedding
+ /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
///
- [TestClass, Ignore]
- public class EmbeddingTest : GraphModeTestBase
+ [TestClass]
+ public class LayersTest : GraphModeTestBase
{
[TestMethod]
public void Embedding()