Browse Source

change RefVariable to VariableV1

tags/v0.12
Oceania2018 6 years ago
parent
commit
4531883114
10 changed files with 57 additions and 23 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Train/AdamOptimizer.cs
  3. +5
    -5
      src/TensorFlowNET.Core/Train/Optimizer.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Train/Trackable.cs
  5. +25
    -3
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Variables/VariableV1.cs
  8. +4
    -4
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  9. +9
    -2
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  10. +5
    -3
      src/TensorFlowNET.Core/tensorflow.cs

+ 4
- 1
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -46,7 +46,10 @@ namespace Tensorflow.Keras.Utils
Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype); Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);


var variable_dtype = dtype.as_base_dtype(); var variable_dtype = dtype.as_base_dtype();
var v = tf.VariableV1(init_val);
var v = tf.VariableV1(init_val,
use_resource: use_resource,
dtype: dtype,
shape: shape);


return v; return v;
} }


+ 2
- 2
src/TensorFlowNET.Core/Train/AdamOptimizer.cs View File

@@ -143,8 +143,8 @@ namespace Tensorflow.Train
{ {
ops.init_scope(); ops.init_scope();
var graph = ops.get_default_graph(); var graph = ops.get_default_graph();
return (_get_non_slot_variable("beta1_power", graph: graph),
_get_non_slot_variable("beta2_power", graph: graph));
return (_get_non_slot_variable("beta1_power", graph: graph) as RefVariable,
_get_non_slot_variable("beta2_power", graph: graph) as RefVariable);
} }


public override void _prepare() public override void _prepare()


+ 5
- 5
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow
public Tensor LearningRateTensor => _lr_t; public Tensor LearningRateTensor => _lr_t;
public bool _use_locking; public bool _use_locking;
public Dictionary<string, Dictionary<string, RefVariable>> _slots; public Dictionary<string, Dictionary<string, RefVariable>> _slots;
public Dictionary<string, RefVariable> _non_slot_dict;
public Dictionary<string, VariableV1> _non_slot_dict;
public Dictionary<string, object> _deferred_slot_restorations; public Dictionary<string, object> _deferred_slot_restorations;
SlotCreator slot_creator = new SlotCreator(); SlotCreator slot_creator = new SlotCreator();


@@ -58,7 +58,7 @@ namespace Tensorflow
_lr = learning_rate; _lr = learning_rate;
// Dictionary of slots. // Dictionary of slots.
_slots = new Dictionary<string, Dictionary<string, RefVariable>>(); _slots = new Dictionary<string, Dictionary<string, RefVariable>>();
_non_slot_dict = new Dictionary<string, RefVariable>();
_non_slot_dict = new Dictionary<string, VariableV1>();
_deferred_slot_restorations = new Dictionary<string, object>(); _deferred_slot_restorations = new Dictionary<string, object>();
} }


@@ -72,7 +72,7 @@ namespace Tensorflow
_lr_t = learning_rate; _lr_t = learning_rate;
// Dictionary of slots. // Dictionary of slots.
_slots = new Dictionary<string, Dictionary<string, RefVariable>>(); _slots = new Dictionary<string, Dictionary<string, RefVariable>>();
_non_slot_dict = new Dictionary<string, RefVariable>();
_non_slot_dict = new Dictionary<string, VariableV1>();
_deferred_slot_restorations = new Dictionary<string, object>(); _deferred_slot_restorations = new Dictionary<string, object>();
} }


@@ -239,7 +239,7 @@ namespace Tensorflow
/// <param name="initial_value"></param> /// <param name="initial_value"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <param name="colocate_with"></param> /// <param name="colocate_with"></param>
protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with)
protected VariableV1 _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with)
{ {
// Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
var graph = colocate_with.graph; var graph = colocate_with.graph;
@@ -333,7 +333,7 @@ namespace Tensorflow
return $"{var.op.graph.graph_key}.{var.op.name}"; return $"{var.op.graph.graph_key}.{var.op.name}";
} }


protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
protected VariableV1 _get_non_slot_variable(string name, Graph graph = null)
{ {
var key = $"{name}.{graph.graph_key}"; var key = $"{name}.{graph.graph_key}";
var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null; var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;


+ 1
- 1
src/TensorFlowNET.Core/Train/Trackable.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Train
/// </summary> /// </summary>
/// <param name="name"></param> /// <param name="name"></param>
/// <param name="trackable"></param> /// <param name="trackable"></param>
protected void _handle_deferred_dependencies(string name, RefVariable trackable)
protected void _handle_deferred_dependencies(string name, VariableV1 trackable)
{ {
_maybe_initialize_trackable(); _maybe_initialize_trackable();
// TODO // TODO


+ 25
- 3
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -16,6 +16,7 @@


using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -53,7 +54,8 @@ namespace Tensorflow
string name = null, string name = null,
VariableDef variable_def = null, VariableDef variable_def = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
string import_scope = "") : base(initial_value,
string import_scope = "",
TensorShape shape = null) : base(initial_value,
trainable, trainable,
collections, collections,
validate_shape, validate_shape,
@@ -69,11 +71,31 @@ namespace Tensorflow
} }
else else
{ {
throw new NotImplementedException("ResourceVariable _init_from_args");
//_init_from_args(initial_value, trainable, collections, validate_shape, caching_device, name, dtype);
_init_from_args(initial_value: initial_value,
trainable: trainable,
collections: collections,
caching_device: caching_device,
name: name,
dtype: dtype,
shape: shape);
} }
} }


private void _init_from_args(object initial_value = null,
bool trainable = true,
List<string> collections = null,
string caching_device = "",
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape shape = null)
{
var init_from_fn = initial_value.GetType().Name == "Func`1";
if(collections == null)
collections = new List<string>() { tf.GraphKeys.GLOBAL_VARIABLES };
throw new NotImplementedException("");
}

private void _init_from_proto(VariableDef variable_def, string import_scope = null) private void _init_from_proto(VariableDef variable_def, string import_scope = null)
{ {
_in_graph_mode = true; _in_graph_mode = true;


+ 1
- 1
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow
trainable: trainable, trainable: trainable,
collections: collections, collections: collections,
synchronization: synchronization, synchronization: synchronization,
aggregation: aggregation);
aggregation: aggregation) as RefVariable;
}); });
} }
} }


+ 1
- 1
src/TensorFlowNET.Core/Variables/VariableV1.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
/// the variable are fixed. The value can be changed using one of the assign methods. /// the variable are fixed. The value can be changed using one of the assign methods.
/// https://tensorflow.org/guide/variables /// https://tensorflow.org/guide/variables
/// </summary> /// </summary>
public class VariableV1
public abstract class VariableV1
{ {
public virtual string name { get; } public virtual string name { get; }
public virtual Tensor graph_element { get; } public virtual Tensor graph_element { get; }


+ 4
- 4
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow
_store_eager_variables = false; _store_eager_variables = false;
} }


public RefVariable get_variable(string name,
public VariableV1 get_variable(string name,
TensorShape shape = null, TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
object initializer = null, // IInitializer or Tensor object initializer = null, // IInitializer or Tensor
@@ -61,7 +61,7 @@ namespace Tensorflow
aggregation: aggregation); aggregation: aggregation);
} }


private RefVariable _true_getter(string name,
private VariableV1 _true_getter(string name,
TensorShape shape = null, TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
object initializer = null, object initializer = null,
@@ -110,7 +110,7 @@ namespace Tensorflow
} }
} }


private RefVariable _get_single_variable(string name,
private VariableV1 _get_single_variable(string name,
TensorShape shape = null, TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null, IInitializer initializer = null,
@@ -136,7 +136,7 @@ namespace Tensorflow
throw new NotImplementedException("_get_single_variable"); throw new NotImplementedException("_get_single_variable");
} }


RefVariable v = null;
VariableV1 v = null;
// Create the tensor to initialize the variable with default value. // Create the tensor to initialize the variable with default value.
if (initializer == null) if (initializer == null)
{ {


+ 9
- 2
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -172,11 +172,12 @@ namespace Tensorflow
return $"{prefix}_{idx}"; return $"{prefix}_{idx}";
} }


public static RefVariable default_variable_creator(object initial_value,
public static VariableV1 default_variable_creator(object initial_value,
string name = null, string name = null,
bool? trainable = null, bool? trainable = null,
List<string> collections = null, List<string> collections = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null,
bool validate_shape = false, bool validate_shape = false,
bool ? use_resource = null, bool ? use_resource = null,
VariableSynchronization synchronization = VariableSynchronization.Auto, VariableSynchronization synchronization = VariableSynchronization.Auto,
@@ -193,7 +194,13 @@ namespace Tensorflow


if (use_resource.Value) if (use_resource.Value)
{ {
throw new NotImplementedException();
return new ResourceVariable(initial_value,
trainable: trainable.Value,
validate_shape: validate_shape,
collections: collections,
name: name,
dtype: dtype,
shape: shape);
} }
else else
{ {


+ 5
- 3
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow
trainable: trainable, trainable: trainable,
validate_shape: validate_shape, validate_shape: validate_shape,
name: name, name: name,
dtype: dtype);
dtype: dtype) as RefVariable;
} }


public VariableV1 VariableV1<T>(T data, public VariableV1 VariableV1<T>(T data,
@@ -63,14 +63,16 @@ namespace Tensorflow
bool validate_shape = true, bool validate_shape = true,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
bool use_resource = false)
bool use_resource = false,
int[] shape = null)
{ {
return Tensorflow.variable_scope.default_variable_creator(data, return Tensorflow.variable_scope.default_variable_creator(data,
trainable: trainable, trainable: trainable,
validate_shape: validate_shape, validate_shape: validate_shape,
name: name, name: name,
dtype: dtype, dtype: dtype,
use_resource: use_resource);
use_resource: use_resource,
shape: shape);
} }


public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) public unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)


Loading…
Cancel
Save