diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 54b252fb..809dde46 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -195,6 +195,17 @@ namespace Tensorflow
             return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds;
         }
 
+        public static IEnumerable<(T1, T2)> zip<T1, T2>((T1, T1) t1, (T2, T2) t2)
+        {
+            for (int i = 0; i < 2; i++)
+            {
+                if (i == 0)
+                    yield return (t1.Item1, t2.Item1);
+                else
+                    yield return (t1.Item2, t2.Item2);
+            }
+        }
+
         public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2)
             where T : unmanaged
         {
diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
index 10a37e53..2d905410 100644
--- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
+++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
@@ -1,7 +1,10 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
+using Tensorflow.Keras.Utils;
 using Tensorflow.Train;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Keras.Optimizers
 {
@@ -10,15 +13,119 @@ namespace Tensorflow.Keras.Optimizers
     /// </summary>
     public class OptimizerV2 : Trackable, IOptimizer
     {
+        protected bool _hypers_created;
+        protected virtual string _name { get; }
+
+        ResourceVariable _iterations;
+        List<ResourceVariable> _weight = new List<ResourceVariable>();
+        Dictionary<string, float> _hyper = new Dictionary<string, float>();
+        Dictionary<string, ResourceVariable> _hyper_variables = new Dictionary<string, ResourceVariable>();
+        protected bool _momentum;
+
         public OptimizerV2() : base()
         {
 
         }
 
-        public void apply_gradients((Tensor, Tensor) gradients,
-            (ResourceVariable, ResourceVariable) vars)
+        public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars)
+        {
+            var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
+            tf_with(ops.name_scope(_name), delegate
+            {
+                ops.init_scope();
+                _create_all_weights(var_list);
+                if (grads_and_vars == null || grads_and_vars.Count() == 0)
+                    return control_flow_ops.no_op();
+
+                //var apply_state =
+                _prepare(var_list);
+
+                return control_flow_ops.no_op();
+            });
+        }
+
+        void _prepare(ResourceVariable[] var_list)
+        {
+            foreach(var variable in var_list)
+            {
+
+            }
+        }
+
+        void _create_all_weights(ResourceVariable[] var_list)
+        {
+            if(_iterations == null)
+            {
+                _iterations = add_weight("iter",
+                    shape: new int[0],
+                    dtype: TF_DataType.TF_INT64,
+                    trainable: false,
+                    aggregation: VariableAggregation.OnlyFirstReplica);
+                _weight.Add(_iterations);
+            }
+
+            _create_hypers();
+            _create_slots(var_list);
+        }
+
+        protected void _set_hyper(string name, float value) {
+            _hyper[name] = value;
+        }
+
+        void _create_hypers()
+        {
+            if (_hypers_created)
+                return;
+            foreach (var dict in _hyper)
+            {
+                var name = dict.Key;
+                var value = dict.Value;
+                _hyper_variables[name] = add_weight(
+                    name,
+                    shape: new int[0],
+                    trainable: false,
+                    initializer: tf.constant_initializer(value),
+                    aggregation: VariableAggregation.OnlyFirstReplica);
+            }
+            _hypers_created = true;
+        }
+
+        void _create_slots(ResourceVariable[] var_list)
+        {
+            if(_momentum)
+            {
+                /*for var in var_list:
+                    self.add_slot(var, "momentum")*/
+            }
+        }
+
+        ResourceVariable add_weight(string name,
+            TensorShape shape,
+            TF_DataType dtype = TF_DataType.TF_FLOAT,
+            IInitializer initializer = null,
+            bool trainable = false,
+            VariableSynchronization synchronization = VariableSynchronization.Auto,
+            VariableAggregation aggregation = VariableAggregation.None)
+        {
+            if (initializer == null)
+                initializer = tf.zeros_initializer;
+
+            if (dtype == TF_DataType.DtInvalid)
+                dtype = TF_DataType.TF_FLOAT;
+
+            var variable = _add_variable_with_custom_getter(name: name,
+                shape: shape,
+                getter: base_layer_utils.make_variable,
+                dtype: dtype,
+                overwrite: true,
+                initializer: initializer,
+                trainable: trainable,
+                use_resource: true,
+                synchronization: synchronization,
+                aggregation: aggregation);
+            return variable as ResourceVariable;
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs b/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs
index 2cef9fe8..975854a6 100644
--- a/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs
+++ b/src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs
@@ -6,9 +6,23 @@ namespace Tensorflow.Keras.Optimizers
 {
     public class SGD : OptimizerV2
     {
-        public SGD(float learning_rate) : base()
+        protected override string _name => "SGD";
+
+        bool nesterov;
+
+        public SGD(float learning_rate,
+            float momentum = 0.0f,
+            bool nesterov = false,
+            float decay = 0.0f) : base()
         {
+            _set_hyper("learning_rate", learning_rate);
+            _set_hyper("decay", decay);
+
+            _momentum = momentum > 0;
+
+            _set_hyper("momentum", momentum);
 
+            this.nesterov = nesterov;
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
index 69862ccb..ed672912 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Utils
             Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
 
             var variable_dtype = dtype.as_base_dtype();
-            var v = tf.Variable(init_val,
+            var v = tf.Variable(init_val,
                 dtype: dtype,
                 shape: shape,
                 name: name);
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 3882646c..0c5b06d3 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -140,6 +140,8 @@ namespace Tensorflow
                     return new EagerTensor(val, ctx.device_name);
                 case int[,] val:
                     return new EagerTensor(val, ctx.device_name);
+                case long val:
+                    return new EagerTensor(val, ctx.device_name);
                 case float val:
                     return new EagerTensor(val, ctx.device_name);
                 case float[,] val:
diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs
index d9aeb65b..332e1764 100644
--- a/src/TensorFlowNET.Core/Training/Trackable.cs
+++ b/src/TensorFlowNET.Core/Training/Trackable.cs
@@ -15,6 +15,7 @@
 ******************************************************************************/
 
 using System;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Train
 {
@@ -32,10 +33,20 @@ namespace Tensorflow.Train
             IInitializer initializer = null,
             Func<string, TensorShape, TF_DataType, IInitializer, bool, IVariableV1> getter = null,
             bool overwrite = false,
-            bool trainable = false)
+            bool trainable = false,
+            bool use_resource = false,
+            VariableSynchronization synchronization = VariableSynchronization.Auto,
+            VariableAggregation aggregation = VariableAggregation.None)
         {
-            var checkpoint_initializer = true;
-            var new_variable = getter(name, shape, dtype, initializer, trainable);
+            ops.init_scope();
+            IInitializer checkpoint_initializer = null;
+            if (tf.context.executing_eagerly())
+                ;
+            else
+                checkpoint_initializer = null;
+
+            IVariableV1 new_variable;
+            new_variable = getter(name, shape, dtype, initializer, trainable);
 
             // If we set an initializer and the variable processed it, tracking will not
             // assign again. It will add this variable to our dependencies, and if there
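A minimal, hypothetical usage sketch of the API surface this diff introduces: the new SGD constructor, the new apply_gradients(IEnumerable<(Tensor, ResourceVariable)>) overload, and the new tuple zip helper in Binding. It assumes the caller already has a gradient Tensor and its matching ResourceVariable from elsewhere (how they are produced is outside this change); the class and method names in the sketch are illustrative only, not part of the commit.

    using System;
    using System.Collections.Generic;
    using Tensorflow;
    using Tensorflow.Keras.Optimizers;
    using static Tensorflow.Binding;

    static class OptimizerUsageSketch
    {
        // Apply a single (gradient, variable) pair with the new SGD optimizer.
        public static void ApplyOneStep(Tensor gradient, ResourceVariable variable)
        {
            // New SGD constructor: learning_rate plus optional momentum/nesterov/decay.
            var optimizer = new SGD(0.01f, momentum: 0.9f);

            // New apply_gradients signature: a sequence of (gradient, variable) pairs.
            var grads_and_vars = new List<(Tensor, ResourceVariable)>
            {
                (gradient, variable)
            };
            optimizer.apply_gradients(grads_and_vars);
        }

        // The new Binding.zip overload pairs two 2-tuples element-wise.
        public static void ZipExample()
        {
            foreach (var (g, v) in zip((1, 2), ("a", "b")))
                Console.WriteLine($"{g} -> {v}");   // prints "1 -> a" then "2 -> b"
        }
    }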