
OptimizerV2 partially working.

tags/v0.20
Oceania2018 5 years ago
commit 88aa2eb2e0
6 changed files with 152 additions and 7 deletions
  1. src/TensorFlowNET.Core/Binding.Util.cs (+11, -0)
  2. src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs (+109, -2)
  3. src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs (+15, -1)
  4. src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs (+1, -1)
  5. src/TensorFlowNET.Core/Tensors/constant_op.cs (+2, -0)
  6. src/TensorFlowNET.Core/Training/Trackable.cs (+14, -3)

src/TensorFlowNET.Core/Binding.Util.cs (+11, -0)

@@ -195,6 +195,17 @@ namespace Tensorflow
             return (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds;
         }

+        public static IEnumerable<(T1, T2)> zip<T1, T2>((T1, T1) t1, (T2, T2) t2)
+        {
+            for (int i = 0; i < 2; i++)
+            {
+                if (i == 0)
+                    yield return (t1.Item1, t2.Item1);
+                else
+                    yield return (t1.Item2, t2.Item2);
+            }
+        }
+
         public static IEnumerable<(T, T)> zip<T>(NDArray t1, NDArray t2)
             where T : unmanaged
         {

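The new tuple overload of zip pairs the elements of two 2-tuples positionally. A minimal usage sketch of the intended behaviour, assuming "using static Tensorflow.Binding;" is in scope; the values are arbitrary placeholders:

    using System;
    using static Tensorflow.Binding;

    foreach (var (a, b) in zip((1, 2), ("x", "y")))
        Console.WriteLine($"{a} -> {b}");
    // expected output:
    // 1 -> x
    // 2 -> y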

src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs (+109, -2)

@@ -1,7 +1,10 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
+using Tensorflow.Keras.Utils;
 using Tensorflow.Train;
+using static Tensorflow.Binding;

 namespace Tensorflow.Keras.Optimizers
 {
@@ -10,15 +13,119 @@ namespace Tensorflow.Keras.Optimizers
     /// </summary>
     public class OptimizerV2 : Trackable, IOptimizer
     {
+        protected bool _hypers_created;
+        protected virtual string _name { get; }
+
+        ResourceVariable _iterations;
+        List<ResourceVariable> _weight = new List<ResourceVariable>();
+        Dictionary<string, float> _hyper = new Dictionary<string, float>();
+        Dictionary<string, ResourceVariable> _hyper_variables = new Dictionary<string, ResourceVariable>();
+        protected bool _momentum;
+
         public OptimizerV2() : base()
         {

         }

-        public void apply_gradients((Tensor, Tensor) gradients,
-            (ResourceVariable, ResourceVariable) vars)
+        public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars)
+        {
+            var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
+            tf_with(ops.name_scope(_name), delegate
+            {
+                ops.init_scope();
+                _create_all_weights(var_list);
+                if (grads_and_vars == null || grads_and_vars.Count() == 0)
+                    return control_flow_ops.no_op();
+
+                //var apply_state =
+                _prepare(var_list);
+
+                return control_flow_ops.no_op();
+            });
+        }
+
+        void _prepare(ResourceVariable[] var_list)
+        {
+            foreach (var variable in var_list)
+            {
+
+            }
+        }
+
+        void _create_all_weights(ResourceVariable[] var_list)
+        {
+            if (_iterations == null)
+            {
+                _iterations = add_weight("iter",
+                    shape: new int[0],
+                    dtype: TF_DataType.TF_INT64,
+                    trainable: false,
+                    aggregation: VariableAggregation.OnlyFirstReplica);
+                _weight.Add(_iterations);
+            }
+
+            _create_hypers();
+            _create_slots(var_list);
+        }
+
+        protected void _set_hyper(string name, float value)
         {
+            _hyper[name] = value;
+        }
+
+        void _create_hypers()
+        {
+            if (_hypers_created)
+                return;
+            foreach (var dict in _hyper)
+            {
+                var name = dict.Key;
+                var value = dict.Value;
+                _hyper_variables[name] = add_weight(
+                    name,
+                    shape: new int[0],
+                    trainable: false,
+                    initializer: tf.constant_initializer(value),
+                    aggregation: VariableAggregation.OnlyFirstReplica);
+            }
+            _hypers_created = true;
+        }
+
+        void _create_slots(ResourceVariable[] var_list)
+        {
+            if (_momentum)
+            {
+                /*for var in var_list:
+                    self.add_slot(var, "momentum")*/
+            }
+        }
+
+        ResourceVariable add_weight(string name,
+            TensorShape shape,
+            TF_DataType dtype = TF_DataType.TF_FLOAT,
+            IInitializer initializer = null,
+            bool trainable = false,
+            VariableSynchronization synchronization = VariableSynchronization.Auto,
+            VariableAggregation aggregation = VariableAggregation.None)
+        {
+            if (initializer == null)
+                initializer = tf.zeros_initializer;
+
+            if (dtype == TF_DataType.DtInvalid)
+                dtype = TF_DataType.TF_FLOAT;
+
+            var variable = _add_variable_with_custom_getter(name: name,
+                shape: shape,
+                getter: base_layer_utils.make_variable,
+                dtype: dtype,
+                overwrite: true,
+                initializer: initializer,
+                trainable: trainable,
+                use_resource: true,
+                synchronization: synchronization,
+                aggregation: aggregation);
+
+            return variable as ResourceVariable;
         }
     }
 }

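For context, a usage sketch of the new apply_gradients path: on first call the optimizer creates its own "iter" counter and one scalar variable per hyperparameter before running the (still stubbed-out) update. The gradient/variable names below are hypothetical placeholders, assumed to come from a gradient computation elsewhere:

    using System.Collections.Generic;
    using Tensorflow;
    using Tensorflow.Keras.Optimizers;

    // grad_w, w, grad_b, b are hypothetical placeholders (Tensor gradients and
    // the ResourceVariables they update), assumed to be produced elsewhere.
    var optimizer = new SGD(learning_rate: 0.01f, momentum: 0.9f);
    var grads_and_vars = new List<(Tensor, ResourceVariable)>
    {
        (grad_w, w),
        (grad_b, b),
    };
    optimizer.apply_gradients(grads_and_vars);  // first call builds "iter" and the hyper variables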
src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs (+15, -1)

@@ -6,9 +6,23 @@ namespace Tensorflow.Keras.Optimizers
 {
     public class SGD : OptimizerV2
     {
-        public SGD(float learning_rate) : base()
+        protected override string _name => "SGD";
+        bool nesterov;
+
+        public SGD(float learning_rate,
+            float momentum = 0.0f,
+            bool nesterov = false,
+            float decay = 0.0f) : base()
         {
+            _set_hyper("learning_rate", learning_rate);
+            _set_hyper("decay", decay);
+
+            _momentum = momentum > 0;
+
+            _set_hyper("momentum", momentum);
+
+            this.nesterov = nesterov;
         }
     }
 }

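A short sketch of what the reworked constructor does with its arguments, using the hyperparameter names above: learning_rate, decay and momentum are stored in the _hyper dictionary via _set_hyper, and any non-zero momentum flips the protected _momentum flag that _create_slots later checks (slot creation itself is still commented out in this commit):

    using Tensorflow.Keras.Optimizers;

    // No variables are created yet; _create_hypers() materializes one scalar
    // ResourceVariable per stored hyperparameter the first time
    // apply_gradients() runs.
    var sgd = new SGD(learning_rate: 0.01f, momentum: 0.9f, nesterov: true, decay: 1e-4f);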
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs (+1, -1)

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Utils
             Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);

             var variable_dtype = dtype.as_base_dtype();
-            var v = tf.Variable(init_val,
+            var v = tf.Variable(init_val,
                 dtype: dtype,
                 shape: shape,
                 name: name);


src/TensorFlowNET.Core/Tensors/constant_op.cs (+2, -0)

@@ -140,6 +140,8 @@ namespace Tensorflow
                     return new EagerTensor(val, ctx.device_name);
                 case int[,] val:
                     return new EagerTensor(val, ctx.device_name);
+                case long val:
+                    return new EagerTensor(val, ctx.device_name);
                 case float val:
                     return new EagerTensor(val, ctx.device_name);
                 case float[,] val:

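With the new case, the eager constant path also accepts scalar int64 values. A minimal sketch, assuming eager execution is enabled so the EagerTensor branch in the switch above is taken:

    using System;
    using static Tensorflow.Binding;

    var step = tf.constant(42L);    // a C# long now hits the new "case long val" branch
    Console.WriteLine(step.dtype);  // expected: TF_INT64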

src/TensorFlowNET.Core/Training/Trackable.cs (+14, -3)

@@ -15,6 +15,7 @@
 ******************************************************************************/

 using System;
+using static Tensorflow.Binding;

 namespace Tensorflow.Train
 {
@@ -32,10 +33,20 @@ namespace Tensorflow.Train
             IInitializer initializer = null,
             Func<string, int[], TF_DataType, IInitializer, bool, IVariableV1> getter = null,
             bool overwrite = false,
-            bool trainable = false)
+            bool trainable = false,
+            bool use_resource = false,
+            VariableSynchronization synchronization = VariableSynchronization.Auto,
+            VariableAggregation aggregation = VariableAggregation.None)
         {
-            var checkpoint_initializer = true;
-            var new_variable = getter(name, shape, dtype, initializer, trainable);
+            ops.init_scope();
+            IInitializer checkpoint_initializer = null;
+            if (tf.context.executing_eagerly())
+                ;
+            else
+                checkpoint_initializer = null;
+
+            IVariableV1 new_variable;
+            new_variable = getter(name, shape, dtype, initializer, trainable);

             // If we set an initializer and the variable processed it, tracking will not
             // assign again. It will add this variable to our dependencies, and if there

