| @@ -9,9 +9,21 @@ namespace Tensorflow | |||||
| public static IInitializer zeros_initializer => new Zeros(); | public static IInitializer zeros_initializer => new Zeros(); | ||||
| public static IInitializer glorot_uniform_initializer => new GlorotUniform(); | public static IInitializer glorot_uniform_initializer => new GlorotUniform(); | ||||
| public static variable_scope variable_scope(string name_or_scope, | |||||
| public static variable_scope variable_scope(string name, | |||||
| string default_name = null, | string default_name = null, | ||||
| object values = null) => new variable_scope(name_or_scope, default_name, values); | |||||
| object values = null, | |||||
| bool auxiliary_name_scope = true) => new variable_scope(name, | |||||
| default_name, | |||||
| values, | |||||
| auxiliary_name_scope); | |||||
| public static variable_scope variable_scope(VariableScope scope, | |||||
| string default_name = null, | |||||
| object values = null, | |||||
| bool auxiliary_name_scope = true) => new variable_scope(scope, | |||||
| default_name, | |||||
| values, | |||||
| auxiliary_name_scope); | |||||
| public class Zeros : IInitializer | public class Zeros : IInitializer | ||||
| { | { | ||||
| @@ -6,7 +6,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class PureVariableScope : IPython | public class PureVariableScope : IPython | ||||
| { | { | ||||
| private string _name_or_scope; | |||||
| private string _name; | |||||
| private VariableScope _scope; | |||||
| private string _new_name; | private string _new_name; | ||||
| private string _old_name_scope; | private string _old_name_scope; | ||||
| private bool _reuse; | private bool _reuse; | ||||
| @@ -14,29 +15,56 @@ namespace Tensorflow | |||||
| private VariableScope _old; | private VariableScope _old; | ||||
| private _VariableScopeStore _var_scope_store; | private _VariableScopeStore _var_scope_store; | ||||
| private VariableScope variable_scope_object; | private VariableScope variable_scope_object; | ||||
| private VariableScope _cached_variable_scope_object; | |||||
| public PureVariableScope(string name_or_scope, | |||||
| public PureVariableScope(string name, | |||||
| string old_name_scope = null, | string old_name_scope = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid) | TF_DataType dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| _name_or_scope = name_or_scope; | |||||
| _name = name; | |||||
| _old_name_scope = old_name_scope; | _old_name_scope = old_name_scope; | ||||
| _var_store = variable_scope._get_default_variable_store(); | _var_store = variable_scope._get_default_variable_store(); | ||||
| _var_scope_store = variable_scope.get_variable_scope_store(); | _var_scope_store = variable_scope.get_variable_scope_store(); | ||||
| } | } | ||||
| public void __enter__() | |||||
| public PureVariableScope(VariableScope scope, | |||||
| string old_name_scope = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | { | ||||
| _old = _var_scope_store.current_scope; | |||||
| _new_name = string.IsNullOrEmpty(_old.name) ? _name_or_scope : _old.name + "/" + _name_or_scope; | |||||
| _reuse = _reuse || _old.resue; | |||||
| string name_scope = _old_name_scope == null ? _name_or_scope : _old_name_scope; | |||||
| variable_scope_object = new VariableScope(_reuse, | |||||
| _scope = scope; | |||||
| _old_name_scope = old_name_scope; | |||||
| _var_store = variable_scope._get_default_variable_store(); | |||||
| _var_scope_store = variable_scope.get_variable_scope_store(); | |||||
| _new_name = _scope._name; | |||||
| string name_scope = _scope._name_scope; | |||||
| variable_scope_object = new VariableScope(_reuse, | |||||
| name: _new_name, | name: _new_name, | ||||
| name_scope: name_scope); | name_scope: name_scope); | ||||
| _var_scope_store.open_variable_scope(_new_name); | |||||
| _cached_variable_scope_object = variable_scope_object; | |||||
| } | |||||
| public void __enter__() | |||||
| { | |||||
| _old = _var_scope_store.current_scope; | |||||
| if(_scope != null) | |||||
| { | |||||
| _var_scope_store.open_variable_scope(_new_name); | |||||
| variable_scope_object = _cached_variable_scope_object; | |||||
| } | |||||
| else | |||||
| { | |||||
| _new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name; | |||||
| _reuse = _reuse || _old.resue; | |||||
| string name_scope = _old_name_scope == null ? _name : _old_name_scope; | |||||
| variable_scope_object = new VariableScope(_reuse, | |||||
| name: _new_name, | |||||
| name_scope: name_scope); | |||||
| _var_scope_store.open_variable_scope(_new_name); | |||||
| } | |||||
| _var_scope_store.current_scope = variable_scope_object; | _var_scope_store.current_scope = variable_scope_object; | ||||
| } | } | ||||
| @@ -14,16 +14,17 @@ namespace Tensorflow | |||||
| public bool resue; | public bool resue; | ||||
| private TF_DataType _dtype; | private TF_DataType _dtype; | ||||
| public string name { get; set; } | |||||
| public string name_scope { get; set; } | |||||
| public string _name { get; set; } | |||||
| public string _name_scope { get; set; } | |||||
| public string original_name_scope => _name_scope; | |||||
| public VariableScope(bool reuse, | public VariableScope(bool reuse, | ||||
| string name = "", | string name = "", | ||||
| string name_scope = "", | string name_scope = "", | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT) | TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| { | { | ||||
| this.name = name; | |||||
| this.name_scope = name_scope; | |||||
| _name = name; | |||||
| _name_scope = name_scope; | |||||
| _reuse = _ReuseMode.AUTO_REUSE; | _reuse = _ReuseMode.AUTO_REUSE; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| } | } | ||||
| @@ -37,7 +38,7 @@ namespace Tensorflow | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation= VariableAggregation.NONE) | VariableAggregation aggregation= VariableAggregation.NONE) | ||||
| { | { | ||||
| string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | |||||
| string full_name = !string.IsNullOrEmpty(this._name) ? this._name + "/" + name : name; | |||||
| return with(new ops.name_scope(null), scope => | return with(new ops.name_scope(null), scope => | ||||
| { | { | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -12,42 +13,83 @@ namespace Tensorflow | |||||
| private bool _use_resource; | private bool _use_resource; | ||||
| public bool UseResource => _use_resource; | public bool UseResource => _use_resource; | ||||
| private string _name_or_scope; | |||||
| private string _name; | |||||
| private VariableScope _scope; | |||||
| private string _default_name; | private string _default_name; | ||||
| private object _values; | private object _values; | ||||
| private string _current_name_scope; | |||||
| private ops.name_scope _current_name_scope; | |||||
| private bool _auxiliary_name_scope; | |||||
| private PureVariableScope _cached_pure_variable_scope; | private PureVariableScope _cached_pure_variable_scope; | ||||
| public variable_scope(string name_or_scope, string default_name = "", object values = null) | |||||
| public variable_scope(string name, | |||||
| string default_name = "", | |||||
| object values = null, | |||||
| bool auxiliary_name_scope = true) | |||||
| { | { | ||||
| _name_or_scope = name_or_scope; | |||||
| _name = name; | |||||
| _default_name = default_name; | _default_name = default_name; | ||||
| _values = values; | _values = values; | ||||
| _current_name_scope = null; | _current_name_scope = null; | ||||
| _use_resource = false; | _use_resource = false; | ||||
| if (_default_name == null && _name_or_scope == null) | |||||
| throw new TypeError("If default_name is None then name_or_scope is required"); | |||||
| if (_default_name == null && _name == null) | |||||
| throw new TypeError("If default_name is None then name is required"); | |||||
| _auxiliary_name_scope = auxiliary_name_scope; | |||||
| } | |||||
| public variable_scope(VariableScope scope, | |||||
| string default_name = "", | |||||
| object values = null, | |||||
| bool auxiliary_name_scope = true) | |||||
| { | |||||
| _scope = scope; | |||||
| _default_name = default_name; | |||||
| _values = values; | |||||
| _current_name_scope = null; | |||||
| _use_resource = false; | |||||
| if (_default_name == null && _scope == null) | |||||
| throw new TypeError("If default_name is None then scope is required"); | |||||
| _auxiliary_name_scope = auxiliary_name_scope; | |||||
| } | } | ||||
| public void __enter__() | public void __enter__() | ||||
| { | { | ||||
| _enter_scope_uncached(); | |||||
| _scope = _enter_scope_uncached(); | |||||
| } | } | ||||
| public VariableScope _enter_scope_uncached() | |||||
| private VariableScope _enter_scope_uncached() | |||||
| { | { | ||||
| ops.name_scope current_name_scope = null; | |||||
| if(_name_or_scope != null) | |||||
| ops.name_scope current_name_scope; | |||||
| if (_auxiliary_name_scope) | |||||
| // Create a new name scope later | |||||
| current_name_scope = null; | |||||
| else | |||||
| { | { | ||||
| var name_scope = _name_or_scope; | |||||
| // Reenter the current name scope | |||||
| string name_scope = ops.get_name_scope(); | |||||
| if(!string.IsNullOrEmpty(name_scope)) | |||||
| // Hack to reenter | |||||
| name_scope += "/"; | |||||
| current_name_scope = new ops.name_scope(name_scope); | |||||
| } | |||||
| if (_name != null || _scope != null) | |||||
| { | |||||
| var name_scope = _name == null ? _scope._name.Split('/').Last() : _name; | |||||
| if (name_scope != null || current_name_scope != null) | if (name_scope != null || current_name_scope != null) | ||||
| current_name_scope = new ops.name_scope(name_scope); | current_name_scope = new ops.name_scope(name_scope); | ||||
| current_name_scope.__enter__(); | current_name_scope.__enter__(); | ||||
| string current_name_scope_name = current_name_scope; | |||||
| var current_name_scope_name = current_name_scope; | |||||
| _current_name_scope = current_name_scope; | _current_name_scope = current_name_scope; | ||||
| string old_name_scope = current_name_scope_name; | string old_name_scope = current_name_scope_name; | ||||
| var pure_variable_scope = new PureVariableScope(_name_or_scope, old_name_scope: old_name_scope); | |||||
| PureVariableScope pure_variable_scope = null; | |||||
| if(_scope == null) | |||||
| pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope); | |||||
| else | |||||
| pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope); | |||||
| pure_variable_scope.__enter__(); | pure_variable_scope.__enter__(); | ||||
| VariableScope entered_pure_variable_scope = pure_variable_scope; | VariableScope entered_pure_variable_scope = pure_variable_scope; | ||||
| _cached_pure_variable_scope = pure_variable_scope; | _cached_pure_variable_scope = pure_variable_scope; | ||||
| @@ -149,14 +191,21 @@ namespace Tensorflow | |||||
| return trainable.Value; | return trainable.Value; | ||||
| } | } | ||||
| public static implicit operator VariableScope(variable_scope scope) | |||||
| { | |||||
| return scope._scope; | |||||
| } | |||||
| public void __exit__() | public void __exit__() | ||||
| { | { | ||||
| if (_current_name_scope != null) | |||||
| _current_name_scope.__exit__(); | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| if (_current_name_scope != null) | |||||
| _current_name_scope.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -475,5 +475,11 @@ namespace Tensorflow | |||||
| return name; | return name; | ||||
| } | } | ||||
| } | } | ||||
| public static string get_name_scope() | |||||
| { | |||||
| var g = get_default_graph(); | |||||
| return g.get_name_scope(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -47,10 +47,28 @@ namespace TensorFlowNET.UnitTest | |||||
| }); | }); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// how to reenter a premade variable scope safely | |||||
| /// </summary> | |||||
| [TestMethod] | [TestMethod] | ||||
| public void ReenterVariableScope() | public void ReenterVariableScope() | ||||
| { | { | ||||
| variable_scope vs = null; | |||||
| with(tf.variable_scope("foo"), v => vs = v); | |||||
| // Re-enter the variable scope. | |||||
| with(tf.variable_scope(vs, auxiliary_name_scope: false), v => | |||||
| { | |||||
| var vs1 = (VariableScope)v; | |||||
| // Restore the original name_scope. | |||||
| with(tf.name_scope(vs1.original_name_scope), delegate | |||||
| { | |||||
| var v1 = tf.get_variable("v", new TensorShape(1)); | |||||
| Assert.AreEqual(v1.name, "foo/v:0"); | |||||
| var c1 = tf.constant(new int[] { 1 }, name: "c"); | |||||
| Assert.AreEqual(c1.name, "foo/c:0"); | |||||
| }); | |||||
| }); | |||||
| } | } | ||||
| [TestMethod] | [TestMethod] | ||||