| @@ -16,5 +16,10 @@ namespace Tensorflow | |||
| value_tensor, | |||
| name: name); | |||
| } | |||
/// <summary>
/// Tells whether the supplied variable is a ResourceVariable.
/// </summary>
/// <param name="var">Variable to inspect; a null reference yields false.</param>
/// <returns>True when <paramref name="var"/> is a ResourceVariable instance.</returns>
public static bool is_resource_variable(VariableV1 var)
    => var is ResourceVariable;
| } | |||
| } | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Python; | |||
| @@ -63,6 +64,19 @@ namespace Tensorflow.Train | |||
| return control_flow_ops.group(new[] { var_update, m_t, v_t }); | |||
| } | |||
/// <summary>
/// Creates the state this optimizer keeps next to the trained variables:
/// the "beta1_power" and "beta2_power" non-slot variables, plus an "m" and a
/// "v" zero-initialized slot for every variable in <paramref name="var_list"/>.
/// </summary>
/// <param name="var_list">Variables to create slots for; must contain at least one element.</param>
protected override void _create_slots(RefVariable[] var_list)
{
    // Pick the anchor variable with an ordinal comparison so the choice does
    // not depend on the culture of the machine running the code.
    var first_var = var_list.OrderBy(x => x.name, StringComparer.Ordinal).First();
    _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var);
    _create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var);

    // Create slots for the first and second moments.
    foreach (var v in var_list)
    {
        _zero_slot(v, "m", Name);
        _zero_slot(v, "v", Name);
    }
}
| private (RefVariable, RefVariable) _get_beta_accumulators() | |||
| { | |||
| ops.init_scope(); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Python; | |||
| namespace Tensorflow | |||
| @@ -13,7 +14,7 @@ namespace Tensorflow | |||
| /// class directly, but instead instantiate one of its subclasses such as | |||
| /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`. | |||
| /// </summary> | |||
| public abstract class Optimizer | |||
| public abstract class Optimizer : Trackable | |||
| { | |||
| // Values for gate_gradients. | |||
| public static int GATE_NONE = 0; | |||
| @@ -27,6 +28,7 @@ namespace Tensorflow | |||
| public Dictionary<string, Dictionary<string, RefVariable>> _slots; | |||
| public Dictionary<string, RefVariable> _non_slot_dict; | |||
| public Dictionary<string, object> _deferred_slot_restorations; | |||
| SlotCreator slot_creator = new SlotCreator(); | |||
| public Optimizer(float learning_rate, bool use_locking, string name = null) | |||
| { | |||
| @@ -187,9 +189,49 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| private void _create_slots(RefVariable[] var_list) | |||
/// <summary>
/// Create all slots needed by the variables.
/// The base implementation creates nothing; optimizers that keep extra
/// per-variable state override this hook.
/// </summary>
/// <param name="var_list">Variables for which slots may be created.</param>
protected virtual void _create_slots(RefVariable[] var_list)
{
    // No slots needed by default.
}
/// <summary>
/// Returns the extra (non-slot) variable cached for <paramref name="name"/> and the
/// graph of <paramref name="colocate_with"/>, creating and caching it when absent.
/// </summary>
/// <param name="initial_value">Value handed to the variable creator when a new variable is built.</param>
/// <param name="name">Variable name; joined with the graph key to form the cache key.</param>
/// <param name="colocate_with">Variable whose graph and resource-ness the new variable follows.</param>
/// <returns>The cached variable, or the one that was just created.</returns>
protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with)
{
    // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    var owner_graph = colocate_with.graph;
    var cache_key = $"{name}.{owner_graph.graph_key}";

    RefVariable cached;
    if (_non_slot_dict.TryGetValue(cache_key, out cached) && cached != null)
    {
        return cached;
    }

    _maybe_initialize_trackable();
    RefVariable created = variable_scope.default_variable_creator(
        initial_value,
        name: name,
        trainable: false,
        use_resource: resource_variable_ops.is_resource_variable(colocate_with));

    // Restore by name if necessary, without registering a Trackable
    // dependency: several graphs may each own a non-slot variable with this
    // name, and saving all of them would fail.
    _handle_deferred_dependencies(name, created);
    _non_slot_dict[cache_key] = created;
    return created;
}
| public virtual Operation _finish(Operation[] update_ops, string name_scope) | |||
| @@ -341,5 +383,33 @@ namespace Tensorflow | |||
| { | |||
| return param; | |||
| } | |||
/// <summary>
/// Find or create a slot initialized with 0.0.
/// </summary>
/// <param name="var">Primary variable the slot belongs to.</param>
/// <param name="slot_name">Name of the slot family (for example "m").</param>
/// <param name="op_name">Name passed to the slot creator for the new slot.</param>
/// <returns>The existing slot for <paramref name="var"/>, or the newly created one.</returns>
protected RefVariable _zero_slot(RefVariable var, string slot_name, string op_name)
{
    var named_slots = _slot_dict(slot_name);
    var slot_key = _var_key(var);

    RefVariable slot;
    if (!named_slots.TryGetValue(slot_key, out slot))
    {
        slot = slot_creator.create_zeros_slot(var, op_name);
        // Register the new slot; without this the lookup below could never succeed.
        named_slots[slot_key] = slot;
    }
    return slot;
}
/// <summary>
/// Returns the dictionary holding every slot registered under
/// <paramref name="slot_name"/>, creating and registering an empty one on first use.
/// </summary>
/// <param name="slot_name">Name of the slot family.</param>
/// <returns>The per-variable slot dictionary stored in <c>_slots</c>; never null.</returns>
protected Dictionary<string, RefVariable> _slot_dict(string slot_name)
{
    Dictionary<string, RefVariable> named_slots;
    if (!_slots.TryGetValue(slot_name, out named_slots) || named_slots == null)
    {
        // Hand back the same instance that is stored, so entries added by the
        // caller are visible on the next lookup.
        named_slots = new Dictionary<string, RefVariable>();
        _slots[slot_name] = named_slots;
    }
    return named_slots;
}
| } | |||
| } | |||
| @@ -0,0 +1,81 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Operations.Initializers; | |||
| using static Tensorflow.Python; | |||
| namespace Tensorflow.Train | |||
| { | |||
/// <summary>
/// Builds slot variables (extra variables associated with a primary variable).
/// </summary>
public class SlotCreator
{
    /// <summary>
    /// Create a slot initialized to 0 with same shape as the primary object.
    /// </summary>
    /// <param name="primary">Variable the slot is created for.</param>
    /// <param name="name">Slot name, appended to the primary's op name to form the scope.</param>
    /// <param name="dtype">Slot dtype; <c>DtInvalid</c> means "use the primary's dtype".</param>
    /// <param name="colocate_with_primary">Forwarded to <see cref="create_slot_with_initializer"/>.</param>
    /// <returns>The slot variable.</returns>
    /// <exception cref="NotImplementedException">The primary's shape is not fully defined.</exception>
    public RefVariable create_zeros_slot(RefVariable primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true)
    {
        if (dtype == TF_DataType.DtInvalid)
        {
            dtype = primary.dtype;
        }

        var slot_shape = primary.shape;
        if (!slot_shape.is_fully_defined())
        {
            throw new NotImplementedException("create_zeros_slot is not fully defined.");
        }

        var initializer = new Zeros();
        return create_slot_with_initializer(
            primary, initializer, slot_shape, dtype, name,
            colocate_with_primary: colocate_with_primary);
    }

    /// <summary>
    /// Creates a slot initialized using an `Initializer`, inside a variable scope
    /// named after the primary's op and the slot name.
    /// </summary>
    /// <param name="primary">Variable the slot is created for.</param>
    /// <param name="initializer">Initializer for the slot's value.</param>
    /// <param name="shape">Shape of the slot.</param>
    /// <param name="dtype">Dtype of the slot.</param>
    /// <param name="name">Slot name, used as the last scope component.</param>
    /// <param name="colocate_with_primary">Accepted for API compatibility; not read by this implementation.</param>
    /// <returns>The slot variable.</returns>
    public RefVariable create_slot_with_initializer(RefVariable primary, IInitializer initializer, TensorShape shape,
        TF_DataType dtype, string name, bool colocate_with_primary = true)
    {
        var validate_shape = shape.is_fully_defined();
        var prefix = primary.op.name;
        return with(new variable_scope(prefix + "/" + name), delegate
        {
            return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype);
        });
    }

    /// <summary>
    /// Helper function for creating a slot variable.
    /// </summary>
    /// <param name="primary">Variable the slot is created for; decides whether a resource variable is requested.</param>
    /// <param name="val">Initializer for the slot.</param>
    /// <param name="scope">Forwarded as the first argument of <c>variable_scope.get_variable</c>.</param>
    /// <param name="validate_shape">Forwarded to <c>variable_scope.get_variable</c>.</param>
    /// <param name="shape">Shape of the slot.</param>
    /// <param name="dtype">Dtype of the slot.</param>
    /// <returns>The non-trainable slot variable.</returns>
    private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape,
        TensorShape shape, TF_DataType dtype)
    {
        // The slot is a resource variable only when its primary is one; a plain
        // RefVariable primary must not force use_resource to true.
        bool use_resource = resource_variable_ops.is_resource_variable(primary);

        var slot = variable_scope.get_variable(
            scope,
            initializer: val,
            trainable: false,
            use_resource: use_resource,
            shape: shape,
            dtype: dtype,
            validate_shape: validate_shape);

        return slot;
    }
}
| } | |||
| @@ -6,6 +6,8 @@ namespace Tensorflow.Train | |||
| { | |||
| public abstract class Trackable | |||
| { | |||
| protected int _self_update_uid; | |||
| /// <summary> | |||
| /// Restore-on-create for a variable to be saved with this `Checkpointable`. | |||
| /// </summary> | |||
| @@ -32,9 +34,29 @@ namespace Tensorflow.Train | |||
| return new_variable; | |||
| } | |||
/// <summary>
/// Intended to pop and load any deferred checkpoint restores into
/// <paramref name="trackable"/>. For now it only makes sure dependency
/// management has been initialized; the restore step is still to be written.
/// </summary>
/// <param name="name">Name the restores would be looked up under (currently unused).</param>
/// <param name="trackable">Variable that would receive the restored values (currently unused).</param>
protected void _handle_deferred_dependencies(string name, RefVariable trackable)
{
    this._maybe_initialize_trackable();
    // TODO
}
/// <summary>
/// Placeholder for dependency tracking: hands the given variable straight back.
/// </summary>
/// <param name="checkpointable">Variable to track; returned unchanged.</param>
/// <param name="name">Dependency name (currently unused).</param>
/// <param name="overwrite">Overwrite flag (currently unused).</param>
/// <returns><paramref name="checkpointable"/> itself.</returns>
protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
    => checkpointable;
// True once _maybe_initialize_trackable has run for this instance.
private bool _self_trackable_initialized;

/// <summary>
/// Initialize dependency management.
/// Only the first call has an effect, so calling it again later does not
/// reset <c>_self_update_uid</c>.
/// </summary>
protected void _maybe_initialize_trackable()
{
    if (_self_trackable_initialized)
    {
        return;
    }

    // _self_unconditional_checkpoint_dependencies = []
    _self_update_uid = -1;
    _self_trackable_initialized = true;
}
| } | |||
| } | |||
| @@ -270,7 +270,11 @@ namespace Tensorflow | |||
| } | |||
| // TODO for Switch/Case | |||
| public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource) | |||
/// <summary>
/// Placeholder overload for obtaining a variable from an initializer; it has no
/// implementation yet and always throws.
/// </summary>
/// <param name="embeddingMatrix">First positional argument (callers pass a name or scope string).</param>
/// <param name="initializer">Initializer for the variable.</param>
/// <param name="use_resource">Whether a resource variable is requested.</param>
/// <param name="shape">Optional shape.</param>
/// <param name="dtype">Optional dtype.</param>
/// <param name="trainable">Whether the variable is trainable.</param>
/// <param name="validate_shape">Whether the shape should be validated.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource,
    TensorShape shape = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    bool trainable = false,
    bool validate_shape = true)
    => throw new NotImplementedException();
| @@ -16,8 +16,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| public void testResourceReadInLoop() | |||
| { | |||
| var embedding_matrix = variable_scope.get_variable( | |||
| "embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); | |||
| //var embedding_matrix = variable_scope.get_variable( | |||
| //"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); | |||
| Tensor cond(Tensor it, Tensor _) | |||
| { | |||