From 45318831148ee931dbb581c59e702c297e6a4d5e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 10 Oct 2019 07:08:46 -0500 Subject: [PATCH] change RefVariable to VariableV1 --- .../Keras/Utils/base_layer_utils.cs | 5 +++- src/TensorFlowNET.Core/Train/AdamOptimizer.cs | 4 +-- src/TensorFlowNET.Core/Train/Optimizer.cs | 10 +++---- src/TensorFlowNET.Core/Train/Trackable.cs | 2 +- .../Variables/ResourceVariable.cs | 28 +++++++++++++++++-- .../Variables/VariableScope.cs | 2 +- .../Variables/VariableV1.cs | 2 +- .../Variables/_VariableStore.cs | 8 +++--- .../Variables/variable_scope.py.cs | 11 ++++++-- src/TensorFlowNET.Core/tensorflow.cs | 8 ++++-- 10 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index c7839b04..477cc56f 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -46,7 +46,10 @@ namespace Tensorflow.Keras.Utils Func init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); var variable_dtype = dtype.as_base_dtype(); - var v = tf.VariableV1(init_val); + var v = tf.VariableV1(init_val, + use_resource: use_resource, + dtype: dtype, + shape: shape); return v; } diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index faf6fec2..39228691 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -143,8 +143,8 @@ namespace Tensorflow.Train { ops.init_scope(); var graph = ops.get_default_graph(); - return (_get_non_slot_variable("beta1_power", graph: graph), - _get_non_slot_variable("beta2_power", graph: graph)); + return (_get_non_slot_variable("beta1_power", graph: graph) as RefVariable, + _get_non_slot_variable("beta2_power", graph: graph) as RefVariable); } public override void _prepare() diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index e0040ecf..524a0e34 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -44,7 +44,7 @@ namespace Tensorflow public Tensor LearningRateTensor => _lr_t; public bool _use_locking; public Dictionary> _slots; - public Dictionary _non_slot_dict; + public Dictionary _non_slot_dict; public Dictionary _deferred_slot_restorations; SlotCreator slot_creator = new SlotCreator(); @@ -58,7 +58,7 @@ namespace Tensorflow _lr = learning_rate; // Dictionary of slots. _slots = new Dictionary>(); - _non_slot_dict = new Dictionary(); + _non_slot_dict = new Dictionary(); _deferred_slot_restorations = new Dictionary(); } @@ -72,7 +72,7 @@ namespace Tensorflow _lr_t = learning_rate; // Dictionary of slots. _slots = new Dictionary>(); - _non_slot_dict = new Dictionary(); + _non_slot_dict = new Dictionary(); _deferred_slot_restorations = new Dictionary(); } @@ -239,7 +239,7 @@ namespace Tensorflow /// /// /// - protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) + protected VariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) { // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. var graph = colocate_with.graph; @@ -333,7 +333,7 @@ namespace Tensorflow return $"{var.op.graph.graph_key}.{var.op.name}"; } - protected RefVariable _get_non_slot_variable(string name, Graph graph = null) + protected VariableV1 _get_non_slot_variable(string name, Graph graph = null) { var key = $"{name}.{graph.graph_key}"; var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; diff --git a/src/TensorFlowNET.Core/Train/Trackable.cs b/src/TensorFlowNET.Core/Train/Trackable.cs index 975546f7..36083d84 100644 --- a/src/TensorFlowNET.Core/Train/Trackable.cs +++ b/src/TensorFlowNET.Core/Train/Trackable.cs @@ -53,7 +53,7 @@ namespace Tensorflow.Train /// /// /// - protected void _handle_deferred_dependencies(string name, RefVariable trackable) + protected void _handle_deferred_dependencies(string name, VariableV1 trackable) { _maybe_initialize_trackable(); // TODO diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 85d2ca56..b548a50f 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; +using static Tensorflow.Binding; namespace Tensorflow { @@ -53,7 +54,8 @@ namespace Tensorflow string name = null, VariableDef variable_def = null, TF_DataType dtype = TF_DataType.DtInvalid, - string import_scope = "") : base(initial_value, + string import_scope = "", + TensorShape shape = null) : base(initial_value, trainable, collections, validate_shape, @@ -69,11 +71,31 @@ namespace Tensorflow } else { - throw new NotImplementedException("ResourceVariable _init_from_args"); - //_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype); + _init_from_args(initial_value: initial_value, + trainable: trainable, + collections: collections, + caching_device: caching_device, + name: name, + dtype: dtype, + shape: shape); } } + private void _init_from_args(object initial_value = null, + bool trainable = true, + List collections = null, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + TensorShape shape = null) + { + var init_from_fn = initial_value.GetType().Name == "Func`1"; + if(collections == null) + collections = new List() { tf.GraphKeys.GLOBAL_VARIABLES }; + + throw new NotImplementedException(""); + } + private void _init_from_proto(VariableDef variable_def, string import_scope = null) { _in_graph_mode = true; diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index ad7750a1..52766e4f 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -71,7 +71,7 @@ namespace Tensorflow trainable: trainable, collections: collections, synchronization: synchronization, - aggregation: aggregation); + aggregation: aggregation) as RefVariable; }); } } diff --git a/src/TensorFlowNET.Core/Variables/VariableV1.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs index 48e1952c..e1247f8d 100644 --- a/src/TensorFlowNET.Core/Variables/VariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs @@ -28,7 +28,7 @@ namespace Tensorflow /// the variable are fixed. The value can be changed using one of the assign methods. /// https://tensorflow.org/guide/variables /// - public class VariableV1 + public abstract class VariableV1 { public virtual string name { get; } public virtual Tensor graph_element { get; } diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index d0fbf161..5b706a95 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -36,7 +36,7 @@ namespace Tensorflow _store_eager_variables = false; } - public RefVariable get_variable(string name, + public VariableV1 get_variable(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, object initializer = null, // IInitializer or Tensor @@ -61,7 +61,7 @@ namespace Tensorflow aggregation: aggregation); } - private RefVariable _true_getter(string name, + private VariableV1 _true_getter(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, object initializer = null, @@ -110,7 +110,7 @@ namespace Tensorflow } } - private RefVariable _get_single_variable(string name, + private VariableV1 _get_single_variable(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, @@ -136,7 +136,7 @@ namespace Tensorflow throw new NotImplementedException("_get_single_variable"); } - RefVariable v = null; + VariableV1 v = null; // Create the tensor to initialize the variable with default value. if (initializer == null) { diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index 4f357b12..f4a01054 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -172,11 +172,12 @@ namespace Tensorflow return $"{prefix}_{idx}"; } - public static RefVariable default_variable_creator(object initial_value, + public static VariableV1 default_variable_creator(object initial_value, string name = null, bool? trainable = null, List collections = null, TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, bool validate_shape = false, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.Auto, @@ -193,7 +194,13 @@ namespace Tensorflow if (use_resource.Value) { - throw new NotImplementedException(); + return new ResourceVariable(initial_value, + trainable: trainable.Value, + validate_shape: validate_shape, + collections: collections, + name: name, + dtype: dtype, + shape: shape); } else { diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 4b260632..39fd2ac9 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -55,7 +55,7 @@ namespace Tensorflow trainable: trainable, validate_shape: validate_shape, name: name, - dtype: dtype); + dtype: dtype) as RefVariable; } public VariableV1 VariableV1(T data, @@ -63,14 +63,16 @@ namespace Tensorflow bool validate_shape = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, - bool use_resource = false) + bool use_resource = false, + int[] shape = null) { return Tensorflow.variable_scope.default_variable_creator(data, trainable: trainable, validate_shape: validate_shape, name: name, dtype: dtype, - use_resource: use_resource); + use_resource: use_resource, + shape: shape); } public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)