From de7a6941a189d9f1610f80a8f024f40eb42496f4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 7 Feb 2019 23:21:24 -0600 Subject: [PATCH] fix name_scope stack can't restore when name is empty. --- src/TensorFlowNET.Core/Graphs/Graph.cs | 4 +- .../Operations/array_ops.py.cs | 22 +++----- .../TensorFlowNET.Core.csproj | 6 +++ .../Variables/VariableScope.cs | 17 ++++-- .../Variables/_VariableStore.cs | 52 ++++++++++++++++--- .../Variables/tf.variable.cs | 14 +++-- .../Variables/variable_scope.py.cs | 21 ++++++-- 7 files changed, 101 insertions(+), 35 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 30d528df..aa3eb26e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -197,7 +197,9 @@ namespace Tensorflow { string new_stack = ""; - if (name.EndsWith("/")) + if (string.IsNullOrEmpty(name)) + new_stack = ""; + else if (name.EndsWith("/")) new_stack = ops._name_from_scope_name(name); else new_stack = unique_name(name); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 23a6eef3..28ff42cf 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -9,32 +9,24 @@ namespace Tensorflow { public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") { - Tensor output = null; - dtype = dtype.as_base_dtype(); - Python.with(new ops.name_scope(name, "zeros", shape), self => + return Python.with(new ops.name_scope(name, "zeros", shape), scope => { - name = self as ops.name_scope; + name = scope; switch (dtype) { case TF_DataType.TF_BOOL: - output = _constant_if_small(false, shape, dtype, name); - break; + return _constant_if_small(false, shape, dtype, name); case TF_DataType.TF_DOUBLE: - output = _constant_if_small(0.0D, shape, dtype, name); - break; + return _constant_if_small(0.0D, shape, dtype, name); case TF_DataType.TF_FLOAT: - output = _constant_if_small(0.0F, shape, dtype, name); - break; + return _constant_if_small(0.0F, shape, dtype, name); case TF_DataType.TF_INT32: - output = _constant_if_small(0, shape, dtype, name); - break; + return _constant_if_small(0, shape, dtype, name); default: - break; + throw new TypeError("can't find type for zeros"); } }); - - return output; } private static Tensor _constant_if_small(T value, Shape shape, TF_DataType dtype, string name) diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 182b457a..d398c163 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -31,6 +31,12 @@ TensorFlow 1.13 RC. true + + + + + + diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 2d2efdfa..c8d99036 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -13,28 +13,35 @@ namespace Tensorflow private TF_DataType _dtype; public string name { get; set; } - public VariableScope() + public VariableScope(TF_DataType dtype = TF_DataType.TF_FLOAT) { _reuse = _ReuseMode.AUTO_REUSE; + _dtype = dtype; } public RefVariable get_variable(_VariableStore var_store, string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, + IInitializer initializer = null, + bool? trainable = null, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation= VariableAggregation.NONE) { string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; - return Python.with(new ops.name_scope(""), scope => + return Python.with(new ops.name_scope(""), scope => { if (dtype == TF_DataType.DtInvalid) dtype = _dtype; - return var_store.get_variable(full_name); - + return var_store.get_variable(full_name, + shape: shape, + dtype: dtype, + initializer: initializer, + trainable: trainable, + synchronization: synchronization, + aggregation: aggregation); }); - } } } diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 27f16c55..2c22f25c 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -24,7 +24,7 @@ namespace Tensorflow TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, IInitializer initializer = null, - bool trainable = false, + bool? trainable = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) @@ -44,14 +44,23 @@ namespace Tensorflow private RefVariable _true_getter(string name, TensorShape shape = null, - TF_DataType dtype = TF_DataType.DtInvalid, + TF_DataType dtype = TF_DataType.TF_FLOAT, IInitializer initializer = null, - bool trainable = false, + bool? trainable = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { - return _get_single_variable(name: name); + bool is_scalar = shape.NDim == 0; + + return _get_single_variable(name: name, + shape: shape, + dtype: dtype, + initializer: initializer, + trainable: trainable, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); } private RefVariable _get_single_variable(string name, @@ -59,11 +68,14 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, bool reuse = false, - bool trainable = false, + bool? trainable = null, bool validate_shape = false, + bool? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { + bool initializing_from_value = false; + if (_vars.ContainsKey(name)) { if (!reuse) @@ -74,7 +86,35 @@ namespace Tensorflow throw new NotImplementedException("_get_single_variable"); } - throw new NotImplementedException("_get_single_variable"); + Tensor init_val = null; + ops.init_scope(); + { + if (initializing_from_value) + { + + } + else + { + init_val = initializer.call(shape, dtype); + var variable_dtype = dtype.as_base_dtype(); + } + } + + // Create the variable. + if (use_resource == null) + use_resource = false; + + var v = variable_scope.default_variable_creator(init_val, + name: name, + trainable: trainable, + dtype: TF_DataType.DtInvalid, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + + _vars[name] = v; + + return v; } } } diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs index 6515399a..2e7eefec 100644 --- a/src/TensorFlowNET.Core/Variables/tf.variable.cs +++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs @@ -12,14 +12,22 @@ namespace Tensorflow return variables.variables_initializer(g.ToArray()); } - public static RefVariable get_variable(string name, - TensorShape shape = null, + public static RefVariable get_variable(string name, + TensorShape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, + bool? trainable = null, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) { + var scope = variable_scope.get_variable_scope(); var store = variable_scope._get_default_variable_store(); - return variable_scope.get_variable_scope().get_variable(store, name, shape: shape); + return scope.get_variable(store, + name, + shape: shape, + dtype: dtype, + initializer: initializer, + trainable: trainable); } } } diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index b7d2662a..6e8f5c38 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -10,9 +10,16 @@ namespace Tensorflow public static string _VARSCOPESTORE_KEY = "__varscope"; public static bool _DEFAULT_USE_RESOURCE = false; - public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO) + public static RefVariable default_variable_creator(object initial_value, + string name = "", + bool? trainable = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool validate_shape = false, + bool ? use_resource = null, + VariableSynchronization synchronization = VariableSynchronization.AUTO, + VariableAggregation aggregation = VariableAggregation.NONE) { - var trainable = _get_trainable_value(synchronization); + trainable = _get_trainable_value(synchronization, trainable); if (!use_resource.HasValue) { use_resource = get_variable_scope().use_resource; @@ -77,18 +84,22 @@ namespace Tensorflow return ret; } - public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true) + public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) { if (synchronization == VariableSynchronization.ON_READ) { - if (trainable) + if (trainable.Value) throw new ValueError("Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ."); } + else if (!trainable.HasValue) + { + trainable = true; + } - return trainable; + return trainable.Value; } } }