From c976b818fc1a22fda9df3ec2de535ec143ab96f2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Mar 2019 22:48:36 -0600 Subject: [PATCH] unit test of how to reenter a premade variable scope safely --- src/TensorFlowNET.Core/APIs/tf.init.cs | 16 +++- .../Variables/PureVariableScope.cs | 50 +++++++++--- .../Variables/VariableScope.cs | 11 +-- .../Variables/variable_scope.py.cs | 79 +++++++++++++++---- src/TensorFlowNET.Core/ops.py.cs | 6 ++ test/TensorFlowNET.UnitTest/VariableTest.cs | 18 +++++ 6 files changed, 147 insertions(+), 33 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 32b75807..19876d62 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -9,9 +9,21 @@ namespace Tensorflow public static IInitializer zeros_initializer => new Zeros(); public static IInitializer glorot_uniform_initializer => new GlorotUniform(); - public static variable_scope variable_scope(string name_or_scope, + public static variable_scope variable_scope(string name, string default_name = null, - object values = null) => new variable_scope(name_or_scope, default_name, values); + object values = null, + bool auxiliary_name_scope = true) => new variable_scope(name, + default_name, + values, + auxiliary_name_scope); + + public static variable_scope variable_scope(VariableScope scope, + string default_name = null, + object values = null, + bool auxiliary_name_scope = true) => new variable_scope(scope, + default_name, + values, + auxiliary_name_scope); public class Zeros : IInitializer { diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs index 15401eea..6f97b19d 100644 --- a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs @@ -6,7 +6,8 @@ namespace Tensorflow { public class PureVariableScope : IPython { - private string _name_or_scope; + private string _name; + private VariableScope _scope; private string _new_name; private string _old_name_scope; private bool _reuse; @@ -14,29 +15,56 @@ namespace Tensorflow private VariableScope _old; private _VariableScopeStore _var_scope_store; private VariableScope variable_scope_object; + private VariableScope _cached_variable_scope_object; - public PureVariableScope(string name_or_scope, + public PureVariableScope(string name, string old_name_scope = null, TF_DataType dtype = TF_DataType.DtInvalid) { - _name_or_scope = name_or_scope; + _name = name; _old_name_scope = old_name_scope; _var_store = variable_scope._get_default_variable_store(); _var_scope_store = variable_scope.get_variable_scope_store(); } - public void __enter__() + public PureVariableScope(VariableScope scope, + string old_name_scope = null, + TF_DataType dtype = TF_DataType.DtInvalid) { - _old = _var_scope_store.current_scope; - _new_name = string.IsNullOrEmpty(_old.name) ? _name_or_scope : _old.name + "/" + _name_or_scope; - _reuse = _reuse || _old.resue; - string name_scope = _old_name_scope == null ? _name_or_scope : _old_name_scope; - - variable_scope_object = new VariableScope(_reuse, + _scope = scope; + _old_name_scope = old_name_scope; + _var_store = variable_scope._get_default_variable_store(); + _var_scope_store = variable_scope.get_variable_scope_store(); + _new_name = _scope._name; + + string name_scope = _scope._name_scope; + variable_scope_object = new VariableScope(_reuse, name: _new_name, name_scope: name_scope); - _var_scope_store.open_variable_scope(_new_name); + _cached_variable_scope_object = variable_scope_object; + } + + public void __enter__() + { + _old = _var_scope_store.current_scope; + if(_scope != null) + { + _var_scope_store.open_variable_scope(_new_name); + variable_scope_object = _cached_variable_scope_object; + } + else + { + _new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name; + _reuse = _reuse || _old.resue; + string name_scope = _old_name_scope == null ? _name : _old_name_scope; + + variable_scope_object = new VariableScope(_reuse, + name: _new_name, + name_scope: name_scope); + + _var_scope_store.open_variable_scope(_new_name); + } _var_scope_store.current_scope = variable_scope_object; } diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 7b924ce6..29c03c19 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -14,16 +14,17 @@ namespace Tensorflow public bool resue; private TF_DataType _dtype; - public string name { get; set; } - public string name_scope { get; set; } + public string _name { get; set; } + public string _name_scope { get; set; } + public string original_name_scope => _name_scope; public VariableScope(bool reuse, string name = "", string name_scope = "", TF_DataType dtype = TF_DataType.TF_FLOAT) { - this.name = name; - this.name_scope = name_scope; + _name = name; + _name_scope = name_scope; _reuse = _ReuseMode.AUTO_REUSE; _dtype = dtype; } @@ -37,7 +38,7 @@ namespace Tensorflow VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation= VariableAggregation.NONE) { - string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; + string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name; return with(new ops.name_scope(null), scope => { if (dtype == TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index d59daa35..779e647b 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; namespace Tensorflow @@ -12,42 +13,83 @@ namespace Tensorflow private bool _use_resource; public bool UseResource => _use_resource; - private string _name_or_scope; + private string _name; + private VariableScope _scope; private string _default_name; private object _values; - private string _current_name_scope; + private ops.name_scope _current_name_scope; + private bool _auxiliary_name_scope; private PureVariableScope _cached_pure_variable_scope; - public variable_scope(string name_or_scope, string default_name = "", object values = null) + public variable_scope(string name, + string default_name = "", + object values = null, + bool auxiliary_name_scope = true) { - _name_or_scope = name_or_scope; + _name = name; _default_name = default_name; _values = values; _current_name_scope = null; _use_resource = false; - if (_default_name == null && _name_or_scope == null) - throw new TypeError("If default_name is None then name_or_scope is required"); + if (_default_name == null && _name == null) + throw new TypeError("If default_name is None then name is required"); + + _auxiliary_name_scope = auxiliary_name_scope; + } + + public variable_scope(VariableScope scope, + string default_name = "", + object values = null, + bool auxiliary_name_scope = true) + { + _scope = scope; + _default_name = default_name; + _values = values; + _current_name_scope = null; + + _use_resource = false; + if (_default_name == null && _scope == null) + throw new TypeError("If default_name is None then scope is required"); + + _auxiliary_name_scope = auxiliary_name_scope; } public void __enter__() { - _enter_scope_uncached(); + _scope = _enter_scope_uncached(); } - public VariableScope _enter_scope_uncached() + private VariableScope _enter_scope_uncached() { - ops.name_scope current_name_scope = null; - if(_name_or_scope != null) + ops.name_scope current_name_scope; + if (_auxiliary_name_scope) + // Create a new name scope later + current_name_scope = null; + else { - var name_scope = _name_or_scope; + // Reenter the current name scope + string name_scope = ops.get_name_scope(); + if(!string.IsNullOrEmpty(name_scope)) + // Hack to reenter + name_scope += "/"; + current_name_scope = new ops.name_scope(name_scope); + } + + if (_name != null || _scope != null) + { + var name_scope = _name == null ? _scope._name.Split('/').Last() : _name; if (name_scope != null || current_name_scope != null) current_name_scope = new ops.name_scope(name_scope); current_name_scope.__enter__(); - string current_name_scope_name = current_name_scope; + var current_name_scope_name = current_name_scope; _current_name_scope = current_name_scope; string old_name_scope = current_name_scope_name; - var pure_variable_scope = new PureVariableScope(_name_or_scope, old_name_scope: old_name_scope); + PureVariableScope pure_variable_scope = null; + if(_scope == null) + pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); + else + pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); pure_variable_scope.__enter__(); VariableScope entered_pure_variable_scope = pure_variable_scope; _cached_pure_variable_scope = pure_variable_scope; @@ -149,14 +191,21 @@ namespace Tensorflow return trainable.Value; } + public static implicit operator VariableScope(variable_scope scope) + { + return scope._scope; + } + public void __exit__() { - + if (_current_name_scope != null) + _current_name_scope.__exit__(); } public void Dispose() { - + if (_current_name_scope != null) + _current_name_scope.Dispose(); } } } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index c993ee8c..50fd59c9 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -475,5 +475,11 @@ namespace Tensorflow return name; } } + + public static string get_name_scope() + { + var g = get_default_graph(); + return g.get_name_scope(); + } } } diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 0964aaa0..7713e774 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -47,10 +47,28 @@ namespace TensorFlowNET.UnitTest }); } + /// + /// how to reenter a premade variable scope safely + /// [TestMethod] public void ReenterVariableScope() { + variable_scope vs = null; + with(tf.variable_scope("foo"), v => vs = v); + // Re-enter the variable scope. + with(tf.variable_scope(vs, auxiliary_name_scope: false), v => + { + var vs1 = (VariableScope)v; + // Restore the original name_scope. + with(tf.name_scope(vs1.original_name_scope), delegate + { + var v1 = tf.get_variable("v", new TensorShape(1)); + Assert.AreEqual(v1.name, "foo/v:0"); + var c1 = tf.constant(new int[] { 1 }, name: "c"); + Assert.AreEqual(c1.name, "foo/c:0"); + }); + }); } [TestMethod]