using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// Updated base class for optimizers.
    /// </summary>
    public class OptimizerV2 : Trackable, IOptimizer
    {
        OptimizerV2Args args;
        protected bool _hypers_created;
        protected virtual string _name { get; }

        IVariableV1 _iterations;
        protected ResourceVariable iterations => _iterations as ResourceVariable;
        List<IVariableV1> _weights;
        Dictionary<string, float> _hyper;
        Dictionary<string, IVariableV1> _hyper_variables;
        protected bool _momentum;
        protected float _initial_decay = 0.0f;
        protected bool _use_locking = true;

        public IVariableV1 lr => _hyper_variables["learning_rate"];
        Dictionary<string, Dictionary<string, IVariableV1>> _slots;
        List<string> _slot_names;

        public OptimizerV2(OptimizerV2Args args) : base()
        {
            this.args = args;
            _weights = new List<IVariableV1>();
            _hyper = new Dictionary<string, float>();
            _hyper_variables = new Dictionary<string, IVariableV1>();
            _slots = new Dictionary<string, Dictionary<string, IVariableV1>>();
            _slot_names = new List<string>();

            _set_hyper("learning_rate", args.LearningRate);
            _set_hyper("decay", args.InitialDecay);
        }

        public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
            string name = null,
            bool experimental_aggregate_gradients = true)
            => apply_gradients(new[] { grads_and_vars },
                name: name,
                experimental_aggregate_gradients: experimental_aggregate_gradients);

        /// <summary>
        /// Apply gradients to variables.
        /// </summary>
        /// <param name="grads_and_vars">Pairs of (gradient, variable) to apply.</param>
        /// <param name="name">Optional name scope for the update ops.</param>
        /// <param name="experimental_aggregate_gradients"></param>
        public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
            string name = null,
            bool experimental_aggregate_gradients = true)
        {
            var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
            tf_with(ops.name_scope(_name), delegate
            {
                ops.init_scope();
                _create_all_weights(var_list);

                if (grads_and_vars == null || grads_and_vars.Count() == 0)
                    return control_flow_ops.no_op();

                var apply_state = _prepare(var_list);
                // if(experimental_aggregate_gradients)
                {
                    // var reduced_grads = _aggregate_gradients(grads_and_vars);
                    _distributed_apply(grads_and_vars, name, apply_state);
                }

                return null;
            });
        }

        void apply_grad_to_update_var(ResourceVariable var, Tensor grad,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            _resource_apply_dense(var, grad, apply_state);
            // if var.constraint is not None:
            //     with ops.control_dependencies([update_op]):
            //         return var.assign(var.constraint(var))
        }

        protected virtual Operation _resource_apply_dense(IVariableV1 var,
            Tensor grad,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
        {
            throw new NotImplementedException("_resource_apply_dense");
        }

        void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
            string name,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
        {
            tf_with(ops.name_scope(name, "", new { skip_on_eager = true }), delegate
            {
                foreach (var (grad, var) in grads_and_vars)
                {
                    tf_with(ops.name_scope("update"), delegate
                    {
                        apply_grad_to_update_var(var, grad, _apply_state);
                    });
                }

                _iterations.assign_add(ops.convert_to_tensor(1, dtype: _iterations.dtype));
            });
        }

        public Tensor[] _aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars)
        {
            return grads_and_vars.Select(x => x.Item1).ToArray();
        }

        public Tensor[] _clip_gradients(Tensor[] grads)
        {
            return grads;
        }

        protected IVariableV1 get_slot(IVariableV1 var, string slot_name)
        {
            var slot_dict = _slots[var.UniqueId];
            return slot_dict[slot_name];
        }

        /// <summary>
        /// Builds one state dictionary per distinct (device, dtype) pair; subclasses fill it in _prepare_local.
        /// </summary>
        Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_list)
        {
            var _apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>();
            var keys = var_list.Select(x => new DeviceDType
            {
                Device = x.Device,
                DType = x.dtype.as_base_dtype()
            }).Distinct(new DeviceDType()).ToArray();

            foreach (var device_dtype in keys)
            {
                _apply_state[device_dtype] = new Dictionary<string, Tensor>();
                _prepare_local(device_dtype, _apply_state);
            }

            return _apply_state;
        }

        protected Dictionary<string, Tensor> _fallback_apply_state(string var_device, TF_DataType var_dtype)
        {
            throw new NotImplementedException("");
        }

        protected virtual void _prepare_local(DeviceDType device_dtype,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
        {
            if (_hyper.ContainsKey("learning_rate"))
            {
                var lr_t = array_ops.identity(_decayed_lr(device_dtype.DType));
                _apply_state[device_dtype]["lr_t"] = lr_t;
            }
        }

        Tensor _decayed_lr(TF_DataType var_dtype)
        {
            var lr_t = _get_hyper("learning_rate", var_dtype);
            if (_initial_decay > 0.0f)
            {
                // Learning-rate decay schedules are not implemented yet.
                throw new NotImplementedException("");
            }
            return lr_t;
        }

        protected Tensor _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid)
        {
            var value = _hyper_variables[name];
            return math_ops.cast(value, dtype);
        }

        void _create_all_weights(IVariableV1[] var_list)
        {
            // Lazily creates the iteration counter, hyper-parameter variables and slot variables.
            if (_iterations == null)
            {
                _iterations = add_weight("iter",
                    shape: new int[0],
                    dtype: TF_DataType.TF_INT64,
                    trainable: false,
                    aggregation: VariableAggregation.OnlyFirstReplica);
                _weights.Add(_iterations);
            }

            _create_hypers();
            _create_slots(var_list);
        }

        protected void _set_hyper(string name, float value)
        {
            _hyper[name] = value;
        }

        void _create_hypers()
        {
            if (_hypers_created)
                return;

            foreach (var dict in _hyper)
            {
                var name = dict.Key;
                var value = dict.Value;
                _hyper_variables[name] = add_weight(
                    name,
                    shape: new int[0],
                    trainable: false,
                    initializer: tf.constant_initializer(value),
                    aggregation: VariableAggregation.OnlyFirstReplica);
            }

            _hypers_created = true;
        }

        protected virtual void _create_slots(IVariableV1[] var_list)
        {
            if (_momentum)
            {
                /*for var in var_list:
                      self.add_slot(var, "momentum")*/
            }
        }

        /// <summary>
        /// Creates (or returns the existing) per-variable slot variable, e.g. a momentum accumulator.
        /// </summary>
        protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
        {
            if (initializer == null)
                initializer = tf.zeros_initializer;

            if (!_slot_names.Contains(slot_name))
                _slot_names.append(slot_name);

            if (!_slots.ContainsKey(var.UniqueId))
                _slots[var.UniqueId] = new Dictionary<string, IVariableV1>();

            var slot_dict = _slots[var.UniqueId];
            if (!slot_dict.ContainsKey(slot_name))
            {
                var weight = tf.Variable(initializer,
                    dtype: var.dtype,
                    trainable: false,
                    shape: var.shape,
                    name: $"{var.Name}/{slot_name}");
                slot_dict[slot_name] = weight;
                _weights.append(weight);
                return weight;
            }
            else
            {
                return slot_dict[slot_name];
            }
        }

        /// <summary>
        /// Creates a non-trainable variable owned and tracked by the optimizer (e.g. "iter" or a hyper-parameter).
        /// </summary>
        ResourceVariable add_weight(string name,
            Shape shape,
            TF_DataType dtype = TF_DataType.TF_FLOAT,
            IInitializer initializer = null,
            bool trainable = false,
            VariableSynchronization synchronization = VariableSynchronization.Auto,
            VariableAggregation aggregation = VariableAggregation.None)
        {
            if (initializer == null)
                initializer = tf.zeros_initializer;

            if (dtype == TF_DataType.DtInvalid)
                dtype = TF_DataType.TF_FLOAT;

            var variable = _add_variable_with_custom_getter(new VariableArgs
            {
                Name = name,
                Shape = shape,
                Getter = base_layer_utils.make_variable,
                DType = dtype,
                Overwrite = true,
                Initializer = initializer,
                Trainable = trainable,
                UseResource = true,
                Synchronization = synchronization,
                Aggregation = aggregation
            });

            return variable as ResourceVariable;
        }
    }
}
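
namespace Tensorflow.Keras.Optimizers
{
    // Minimal sketch of how a concrete optimizer plugs into the hooks above. The class name
    // `MomentumSketch`, the 0.9 momentum value and the omitted update rule are illustrative
    // assumptions, not one of the library's built-in optimizers: _set_hyper registers a scalar
    // hyper-parameter, _create_slots reserves one accumulator per trainable variable, and
    // _resource_apply_dense reads the per-(device, dtype) state prepared by _prepare_local.
    public class MomentumSketch : OptimizerV2
    {
        protected override string _name => "MomentumSketch";

        public MomentumSketch(OptimizerV2Args args) : base(args)
        {
            _momentum = true;
            _set_hyper("momentum", 0.9f);   // becomes a non-trainable weight in _create_hypers()
        }

        protected override void _create_slots(IVariableV1[] var_list)
        {
            // One zero-initialised accumulator per variable, retrievable later via get_slot().
            foreach (var v in var_list)
                add_slot(v, "momentum");
        }

        protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad,
            Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
        {
            // Look up the state prepared for this variable's (device, dtype) pair and the slot
            // reserved in _create_slots; the actual momentum update op is deliberately omitted.
            var key = apply_state.Keys.First(k => k.Device == var.Device
                && k.DType == var.dtype.as_base_dtype());
            var lr_t = apply_state[key]["lr_t"];
            var accumulator = get_slot(var, "momentum");
            throw new NotImplementedException("momentum update rule omitted in this sketch");
        }
    }
}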