From a7450215ad182b9ddd1fb4315e057a6d4ccdb2e8 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 18 Jun 2019 07:26:20 -0500 Subject: [PATCH] add SlotCreator class. --- .../Operations/resource_variable_ops.cs | 5 ++ src/TensorFlowNET.Core/Train/AdamOptimizer.cs | 14 ++++ src/TensorFlowNET.Core/Train/Optimizer.cs | 74 ++++++++++++++++- src/TensorFlowNET.Core/Train/SlotCreator.cs | 81 +++++++++++++++++++ src/TensorFlowNET.Core/Train/Trackable.cs | 22 +++++ .../Variables/variable_scope.py.cs | 6 +- .../control_flow_ops_test/SwitchTestCase.cs | 4 +- 7 files changed, 201 insertions(+), 5 deletions(-) create mode 100644 src/TensorFlowNET.Core/Train/SlotCreator.cs diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index f8ecf7b9..eb48c5bc 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -16,5 +16,10 @@ namespace Tensorflow value_tensor, name: name); } + + public static bool is_resource_variable(VariableV1 var) + { + return var is ResourceVariable; + } } } diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index 8b14bf50..9d3cab19 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Framework; using static Tensorflow.Python; @@ -63,6 +64,19 @@ namespace Tensorflow.Train return control_flow_ops.group(new[] { var_update, m_t, v_t }); } + protected override void _create_slots(RefVariable[] var_list) + { + var first_var = var_list.OrderBy(x => x.name).First(); + _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var); + _create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var); + + // Create slots for the first and second moments. + foreach(var v in var_list) + { + _zero_slot(v, "m", Name); + } + } + private (RefVariable, RefVariable) _get_beta_accumulators() { ops.init_scope(); diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 9284d5c6..cc61c48b 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Framework; +using Tensorflow.Train; using static Tensorflow.Python; namespace Tensorflow @@ -13,7 +14,7 @@ namespace Tensorflow /// class directly, but instead instantiate one of its subclasses such as /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`. /// - public abstract class Optimizer + public abstract class Optimizer : Trackable { // Values for gate_gradients. public static int GATE_NONE = 0; @@ -27,6 +28,7 @@ namespace Tensorflow public Dictionary> _slots; public Dictionary _non_slot_dict; public Dictionary _deferred_slot_restorations; + SlotCreator slot_creator = new SlotCreator(); public Optimizer(float learning_rate, bool use_locking, string name = null) { @@ -187,9 +189,49 @@ namespace Tensorflow }); } - private void _create_slots(RefVariable[] var_list) + /// + /// Create the beta1 and beta2 accumulators on the same device as the first + /// variable. Sort the var_list to make sure this device is consistent across + /// workers (these need to go on the same PS, otherwise some updates are + /// silently ignored). + /// + /// + protected virtual void _create_slots(RefVariable[] var_list) { + + } + /// + /// Add an extra variable, not associated with a slot. + /// + /// + /// + /// + protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with) + { + // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. + var graph = colocate_with.graph; + var key = $"{name}.{graph.graph_key}"; + var v = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; + if(v == null) + { + _maybe_initialize_trackable(); + v = variable_scope.default_variable_creator( + initial_value, name: name, trainable: false, + use_resource: resource_variable_ops.is_resource_variable( + colocate_with)); + + // Restore this variable by name if necessary, but don't add a + // Trackable dependency. Optimizers return the current graph's + // non-slot variables from _checkpoint_dependencies explicitly rather + // than unconditionally adding dependencies (since there may be multiple + // non-slot variables with the same name in different graphs, trying to + // save all of them would result in errors). + _handle_deferred_dependencies(name, v); + _non_slot_dict[key] = v; + } + + return v; } public virtual Operation _finish(Operation[] update_ops, string name_scope) @@ -341,5 +383,33 @@ namespace Tensorflow { return param; } + + /// + /// Find or create a slot initialized with 0.0. + /// + /// + /// + /// + /// + protected RefVariable _zero_slot(RefVariable var, string slot_name, string op_name) + { + var named_slots = _slot_dict(slot_name); + if (!named_slots.ContainsKey(_var_key(var))) + { + var new_slot_variable = slot_creator.create_zeros_slot(var, op_name); + } + return named_slots[_var_key(var)]; + } + + protected Dictionary _slot_dict(string slot_name) + { + var named_slots = _slots.ContainsKey(slot_name) ? _slots[slot_name] : null; + if(named_slots == null) + { + _slots[slot_name] = new Dictionary(); + } + + return named_slots; + } } } diff --git a/src/TensorFlowNET.Core/Train/SlotCreator.cs b/src/TensorFlowNET.Core/Train/SlotCreator.cs new file mode 100644 index 00000000..1ee8f774 --- /dev/null +++ b/src/TensorFlowNET.Core/Train/SlotCreator.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Initializers; +using static Tensorflow.Python; + +namespace Tensorflow.Train +{ + public class SlotCreator + { + /// + /// Create a slot initialized to 0 with same shape as the primary object. + /// + /// + /// + /// + /// + /// + public RefVariable create_zeros_slot(RefVariable primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true) + { + if (dtype == TF_DataType.DtInvalid) + dtype = primary.dtype; + var slot_shape = primary.shape; + if (slot_shape.is_fully_defined()) + { + var initializer = new Zeros(); + return create_slot_with_initializer( + primary, initializer, slot_shape, dtype, name, + colocate_with_primary: colocate_with_primary); + } + else + { + throw new NotImplementedException("create_zeros_slot is not fully defined."); + } + } + + /// + /// Creates a slot initialized using an `Initializer`. + /// + /// + public RefVariable create_slot_with_initializer(RefVariable primary, IInitializer initializer, TensorShape shape, + TF_DataType dtype, string name, bool colocate_with_primary = true) + { + var validate_shape = shape.is_fully_defined(); + var prefix = primary.op.name; + return with(new variable_scope(prefix + "/" + name), delegate + { + return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype); + }); + } + + /// + /// Helper function for creating a slot variable. + /// + /// + /// + /// + /// + /// + /// + /// + private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape, + TensorShape shape, TF_DataType dtype) + { + bool use_resource = primary is RefVariable; + if (resource_variable_ops.is_resource_variable(primary)) + use_resource = true; + + var slot = variable_scope.get_variable( + scope, + initializer: val, + trainable: false, + use_resource: use_resource, + shape: shape, + dtype: dtype, + validate_shape: validate_shape); + + return slot; + } + } +} diff --git a/src/TensorFlowNET.Core/Train/Trackable.cs b/src/TensorFlowNET.Core/Train/Trackable.cs index c16304a9..c98b2116 100644 --- a/src/TensorFlowNET.Core/Train/Trackable.cs +++ b/src/TensorFlowNET.Core/Train/Trackable.cs @@ -6,6 +6,8 @@ namespace Tensorflow.Train { public abstract class Trackable { + protected int _self_update_uid; + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -32,9 +34,29 @@ namespace Tensorflow.Train return new_variable; } + /// + /// Pop and load any deferred checkpoint restores into `trackable`. + /// + /// + /// + protected void _handle_deferred_dependencies(string name, RefVariable trackable) + { + _maybe_initialize_trackable(); + // TODO + } + protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false) { return checkpointable; } + + /// + /// Initialize dependency management. + /// + protected void _maybe_initialize_trackable() + { + // _self_unconditional_checkpoint_dependencies = [] + _self_update_uid = -1; + } } } diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index c972ae99..0482daa6 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -270,7 +270,11 @@ namespace Tensorflow } // TODO for Switch/Case - public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource) + public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource, + TensorShape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + bool trainable = false, + bool validate_shape = true) { throw new NotImplementedException(); } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs index 0e95fdc8..1b81fc21 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs @@ -16,8 +16,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test public void testResourceReadInLoop() { - var embedding_matrix = variable_scope.get_variable( - "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); + //var embedding_matrix = variable_scope.get_variable( + //"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); Tensor cond(Tensor it, Tensor _) {