Browse Source

add SlotCreator class.

tags/v0.9
Oceania2018 6 years ago
parent
commit
a7450215ad
7 changed files with 201 additions and 5 deletions
  1. +5
    -0
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  2. +14
    -0
      src/TensorFlowNET.Core/Train/AdamOptimizer.cs
  3. +72
    -2
      src/TensorFlowNET.Core/Train/Optimizer.cs
  4. +81
    -0
      src/TensorFlowNET.Core/Train/SlotCreator.cs
  5. +22
    -0
      src/TensorFlowNET.Core/Train/Trackable.cs
  6. +5
    -1
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  7. +2
    -2
      test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs

+ 5
- 0
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -16,5 +16,10 @@ namespace Tensorflow
value_tensor,
name: name);
}

public static bool is_resource_variable(VariableV1 var)
{
return var is ResourceVariable;
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Train/AdamOptimizer.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Python;
@@ -63,6 +64,19 @@ namespace Tensorflow.Train
return control_flow_ops.group(new[] { var_update, m_t, v_t });
}

protected override void _create_slots(RefVariable[] var_list)
{
var first_var = var_list.OrderBy(x => x.name).First();
_create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var);
_create_non_slot_variable(initial_value: _beta2, name: "beta2_power", colocate_with: first_var);

// Create slots for the first and second moments.
foreach(var v in var_list)
{
_zero_slot(v, "m", Name);
}
}

private (RefVariable, RefVariable) _get_beta_accumulators()
{
ops.init_scope();


+ 72
- 2
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using Tensorflow.Train;
using static Tensorflow.Python;

namespace Tensorflow
@@ -13,7 +14,7 @@ namespace Tensorflow
/// class directly, but instead instantiate one of its subclasses such as
/// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
/// </summary>
public abstract class Optimizer
public abstract class Optimizer : Trackable
{
// Values for gate_gradients.
public static int GATE_NONE = 0;
@@ -27,6 +28,7 @@ namespace Tensorflow
public Dictionary<string, Dictionary<string, RefVariable>> _slots;
public Dictionary<string, RefVariable> _non_slot_dict;
public Dictionary<string, object> _deferred_slot_restorations;
SlotCreator slot_creator = new SlotCreator();

public Optimizer(float learning_rate, bool use_locking, string name = null)
{
@@ -187,9 +189,49 @@ namespace Tensorflow
});
}

private void _create_slots(RefVariable[] var_list)
/// <summary>
/// Create the beta1 and beta2 accumulators on the same device as the first
/// variable. Sort the var_list to make sure this device is consistent across
/// workers (these need to go on the same PS, otherwise some updates are
/// silently ignored).
/// </summary>
/// <param name="var_list"></param>
protected virtual void _create_slots(RefVariable[] var_list)
{
}

/// <summary>
/// Add an extra variable, not associated with a slot.
/// </summary>
/// <param name="initial_value"></param>
/// <param name="name"></param>
/// <param name="colocate_with"></param>
protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with)
{
// Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
var graph = colocate_with.graph;
var key = $"{name}.{graph.graph_key}";
var v = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
if(v == null)
{
_maybe_initialize_trackable();
v = variable_scope.default_variable_creator(
initial_value, name: name, trainable: false,
use_resource: resource_variable_ops.is_resource_variable(
colocate_with));

// Restore this variable by name if necessary, but don't add a
// Trackable dependency. Optimizers return the current graph's
// non-slot variables from _checkpoint_dependencies explicitly rather
// than unconditionally adding dependencies (since there may be multiple
// non-slot variables with the same name in different graphs, trying to
// save all of them would result in errors).
_handle_deferred_dependencies(name, v);
_non_slot_dict[key] = v;
}

return v;
}

public virtual Operation _finish(Operation[] update_ops, string name_scope)
@@ -341,5 +383,33 @@ namespace Tensorflow
{
return param;
}

/// <summary>
/// Find or create a slot initialized with 0.0.
/// </summary>
/// <param name="var"></param>
/// <param name="slot_name"></param>
/// <param name="op_name"></param>
/// <returns></returns>
protected RefVariable _zero_slot(RefVariable var, string slot_name, string op_name)
{
var named_slots = _slot_dict(slot_name);
if (!named_slots.ContainsKey(_var_key(var)))
{
var new_slot_variable = slot_creator.create_zeros_slot(var, op_name);
}
return named_slots[_var_key(var)];
}

protected Dictionary<string, RefVariable> _slot_dict(string slot_name)
{
var named_slots = _slots.ContainsKey(slot_name) ? _slots[slot_name] : null;
if(named_slots == null)
{
_slots[slot_name] = new Dictionary<string, RefVariable>();
}

return named_slots;
}
}
}

+ 81
- 0
src/TensorFlowNET.Core/Train/SlotCreator.cs View File

@@ -0,0 +1,81 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.Initializers;
using static Tensorflow.Python;

namespace Tensorflow.Train
{
public class SlotCreator
{
/// <summary>
/// Create a slot initialized to 0 with same shape as the primary object.
/// </summary>
/// <param name="primary"></param>
/// <param name="name"></param>
/// <param name="dtype"></param>
/// <param name="colocate_with_primary"></param>
/// <returns></returns>
public RefVariable create_zeros_slot(RefVariable primary, string name, TF_DataType dtype = TF_DataType.DtInvalid, bool colocate_with_primary = true)
{
if (dtype == TF_DataType.DtInvalid)
dtype = primary.dtype;
var slot_shape = primary.shape;
if (slot_shape.is_fully_defined())
{
var initializer = new Zeros();
return create_slot_with_initializer(
primary, initializer, slot_shape, dtype, name,
colocate_with_primary: colocate_with_primary);
}
else
{
throw new NotImplementedException("create_zeros_slot is not fully defined.");
}
}

/// <summary>
/// Creates a slot initialized using an `Initializer`.
/// </summary>
/// <returns></returns>
public RefVariable create_slot_with_initializer(RefVariable primary, IInitializer initializer, TensorShape shape,
TF_DataType dtype, string name, bool colocate_with_primary = true)
{
var validate_shape = shape.is_fully_defined();
var prefix = primary.op.name;
return with(new variable_scope(prefix + "/" + name), delegate
{
return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype);
});
}

/// <summary>
/// Helper function for creating a slot variable.
/// </summary>
/// <param name="primary"></param>
/// <param name="val"></param>
/// <param name="scope"></param>
/// <param name="validate_shape"></param>
/// <param name="shape"></param>
/// <param name="dtype"></param>
/// <returns></returns>
private RefVariable _create_slot_var(VariableV1 primary, IInitializer val, string scope, bool validate_shape,
TensorShape shape, TF_DataType dtype)
{
bool use_resource = primary is RefVariable;
if (resource_variable_ops.is_resource_variable(primary))
use_resource = true;

var slot = variable_scope.get_variable(
scope,
initializer: val,
trainable: false,
use_resource: use_resource,
shape: shape,
dtype: dtype,
validate_shape: validate_shape);

return slot;
}
}
}

+ 22
- 0
src/TensorFlowNET.Core/Train/Trackable.cs View File

@@ -6,6 +6,8 @@ namespace Tensorflow.Train
{
public abstract class Trackable
{
protected int _self_update_uid;

/// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`.
/// </summary>
@@ -32,9 +34,29 @@ namespace Tensorflow.Train
return new_variable;
}

/// <summary>
/// Pop and load any deferred checkpoint restores into `trackable`.
/// </summary>
/// <param name="name"></param>
/// <param name="trackable"></param>
protected void _handle_deferred_dependencies(string name, RefVariable trackable)
{
_maybe_initialize_trackable();
// TODO
}

protected RefVariable _track_checkpointable(RefVariable checkpointable, string name, bool overwrite = false)
{
return checkpointable;
}

/// <summary>
/// Initialize dependency management.
/// </summary>
protected void _maybe_initialize_trackable()
{
// _self_unconditional_checkpoint_dependencies = []
_self_update_uid = -1;
}
}
}

+ 5
- 1
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -270,7 +270,11 @@ namespace Tensorflow
}

// TODO for Switch/Case
public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource)
public static RefVariable get_variable(string embeddingMatrix, IInitializer initializer, bool use_resource,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool trainable = false,
bool validate_shape = true)
{
throw new NotImplementedException();
}


+ 2
- 2
test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs View File

@@ -16,8 +16,8 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
public void testResourceReadInLoop()
{
var embedding_matrix = variable_scope.get_variable(
"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true);
//var embedding_matrix = variable_scope.get_variable(
//"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true);
Tensor cond(Tensor it, Tensor _)
{


Loading…
Cancel
Save