diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index a0a151ef..6c61b12a 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -14,6 +14,8 @@ namespace Tensorflow.Keras.Layers
     /// A layer is a class implementing common neural networks operations, such
     /// as convolution, batch norm, etc. These operations require managing weights,
     /// losses, updates, and inter-layer connectivity.
+    ///
+    /// tensorflow\python\keras\engine\base_layer.py
     /// </summary>
     public class Layer : AutoTrackable
     {
@@ -55,9 +57,14 @@ namespace Tensorflow.Keras.Layers
         {
             this.trainable = trainable;
             this._dtype = dtype;
+            // A stateful layer is a layer whose updates are run during inference too,
+            // for instance stateful RNNs.
             stateful = false;
+            // Indicates whether `build` needs to be called upon layer call, to create
+            // the layer's weights.
             built = false;
             this.supports_masking = false;
+
             _init_set_name(name);
             _trainable_weights = new List();
             _compute_previous_mask = false;
@@ -154,7 +161,8 @@ namespace Tensorflow.Keras.Layers
                 if (_dtype == TF_DataType.DtInvalid)
                     _dtype = input.dtype;
 
-                build(input.GetShape());
+                var input_shapes = input.GetShape();
+                build(input_shapes);
                 built = true;
             }
 
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index ff0fd28e..463da7dc 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -22,8 +22,13 @@ namespace Tensorflow.Layers
             TF_DataType dtype = TF_DataType.DtInvalid,
             bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype)
         {
+            // For backwards compatibility, legacy layers do not use `ResourceVariable`
+            // by default.
             this._use_resource_variables = false;
             this._reuse = _reuse;
+
+            // Avoid an incorrect lint error
+            _trainable_weights = new List();
             this.built = false;
             _keras_style = false;
         }
@@ -130,13 +135,12 @@ namespace Tensorflow.Layers
                     initializer: initializer,
                     trainable: trainable,
                     getter: (name1, shape1, dtype1, initializer1, trainable1) =>
-                    {
-                        return tf.get_variable(name1,
+                        tf.get_variable(name1,
                             shape: new TensorShape(shape1),
                             dtype: dtype1,
                             initializer: initializer1,
-                            trainable: trainable1);
-                    });
+                            trainable: trainable1)
+                    );
 
                     //if (init_graph != null)
                     //var trainable_variables = variables.trainable_variables();
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
index e9b79614..501d2575 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
         public static Tensor operator -(Tensor x, Tensor y) => BinaryOpWrapper("sub", x, y);
         public static Tensor operator -(Tensor x, int y) => BinaryOpWrapper("sub", x, y);
         public static Tensor operator -(Tensor x, double y) => BinaryOpWrapper("sub", x, y);
-        public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("Sub", x, y);
+        public static Tensor operator -(float x, Tensor y) => BinaryOpWrapper("sub", x, y);
 
         public static Tensor operator *(float x, Tensor y) => BinaryOpWrapper("mul", x, y);
         public static Tensor operator *(double x, Tensor y) => BinaryOpWrapper("mul", x, y);
diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
index 3476fc99..ff7ae22a 100644
--- a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 
 namespace Tensorflow
@@ -16,7 +17,8 @@ namespace Tensorflow
         private _VariableScopeStore _var_scope_store;
         private VariableScope variable_scope_object;
         private VariableScope _cached_variable_scope_object;
-
+        private VariableScope _last_variable_scope_object;
+        private Dictionary<string, int> _old_subscopes;
         public PureVariableScope(string name,
             string old_name_scope = null,
             TF_DataType dtype = TF_DataType.DtInvalid)
@@ -51,6 +53,7 @@ namespace Tensorflow
             if(_scope != null)
             {
                 _var_scope_store.open_variable_scope(_new_name);
+                _old_subscopes = _var_scope_store.variable_scopes_count.ToDictionary(kv => kv.Key, kv => kv.Value);
                 variable_scope_object = _cached_variable_scope_object;
             }
             else
@@ -66,6 +69,7 @@ namespace Tensorflow
                 _var_scope_store.open_variable_scope(_new_name);
             }
             _var_scope_store.current_scope = variable_scope_object;
+            _last_variable_scope_object = variable_scope_object;
         }
 
         public void Dispose()
@@ -75,7 +79,12 @@ namespace Tensorflow
 
         public void __exit__()
         {
-
+            // If jumping out from a non-prolonged scope, restore counts.
+            if (_scope != null)
+                _var_scope_store.variable_scopes_count = _old_subscopes;
+            else
+                _var_scope_store.close_variable_subscopes(_new_name);
+            _var_scope_store.current_scope = _old;
         }
 
         public static implicit operator VariableScope(PureVariableScope scope)
diff --git a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
index 8edc9a0c..6a3879db 100644
--- a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
+++ b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
@@ -7,7 +7,7 @@ namespace Tensorflow
     public class _VariableScopeStore
     {
         public VariableScope current_scope { get; set; }
-        private Dictionary<string, int> variable_scopes_count;
+        public Dictionary<string, int> variable_scopes_count;
 
         public _VariableScopeStore()
         {
@@ -23,6 +23,19 @@ namespace Tensorflow
                 variable_scopes_count[scope_name] = 1;
         }
 
+        /// <summary>
+        /// Resets to zero the use count of every scope whose name starts with
+        /// <paramref name="scope_name"/> + "/"; a null name resets every recorded scope.
+        /// </summary>
+        public void close_variable_subscopes(string scope_name)
+        {
+            // Walk a snapshot of the keys: writing through the indexer while enumerating
+            // the live Keys collection is not supported and can throw InvalidOperationException.
+            foreach (var k in new List<string>(variable_scopes_count.Keys))
+                if (scope_name == null || k.StartsWith(scope_name + "/", System.StringComparison.Ordinal))
+                    variable_scopes_count[k] = 0;
+        }
+
         public int variable_scope_count(string scope_name)
         {
             if (variable_scopes_count.ContainsKey(scope_name))
diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
index 09429915..c972ae99 100644
--- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
@@ -106,13 +106,17 @@ namespace Tensorflow
 
             if (_name != null || _scope != null)
             {
-                var name_scope = _name == null ? _scope.name.Split('/').Last() : _name;
+                // _scope is null when the scope is entered by name only, so fall back to _name.
+                var name_scope = _scope != null ? _scope.name.Split('/').Last() : _name;
                 if (current_name_scope == null)
                     current_name_scope = ops.name_scope(name_scope);
                 current_name_scope.__enter__();
                 var current_name_scope_name = current_name_scope;
                 _current_name_scope = current_name_scope;
                 string old_name_scope = current_name_scope_name;
+                // Use the scope's original name scope only when a scope object was supplied.
+                if (_scope != null)
+                    old_name_scope = _scope.original_name_scope;
 
                 if(_scope == null)
                     pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope);
@@ -139,6 +143,11 @@ namespace Tensorflow
             }
         }
 
+        /// <summary>
+        /// Get a name with the given prefix unique in the current variable scope.
+        /// </summary>
+        /// <param name="prefix">Base name to make unique.</param>
+        /// <returns>The prefix itself if its scope count is zero, otherwise the prefix with the first unused "_N" suffix (N starting at 1).</returns>
         public static string _get_unique_variable_scope(string prefix)
         {
             var var_scope_store = get_variable_scope_store();
@@ -146,7 +155,10 @@ namespace Tensorflow
             string name = !string.IsNullOrEmpty(current_scope.name) ? current_scope.name + "/" + prefix : prefix;
             if (var_scope_store.variable_scope_count(name) == 0)
                 return prefix;
-            throw new NotImplementedException("_get_unique_variable_scope");
+            var idx = 1;
+            while (var_scope_store.variable_scope_count($"{name}_{idx}") > 0)
+                idx += 1;
+            return $"{prefix}_{idx}";
         }
 
         public static RefVariable default_variable_creator(object initial_value,
@@ -250,6 +262,7 @@ namespace Tensorflow
 
         public void __exit__()
         {
+            _cached_pure_variable_scope.__exit__();
             if (_current_name_scope != null)
                 _current_name_scope.__exit__();
         }
diff --git a/tensorflowlib/runtimes/win-x64/native/tensorflow.dll b/tensorflowlib/runtimes/win-x64/native/tensorflow.dll
index 0be52700..d4c2474c 100644
Binary files a/tensorflowlib/runtimes/win-x64/native/tensorflow.dll and b/tensorflowlib/runtimes/win-x64/native/tensorflow.dll differ