Browse Source

Change RefVariable to IVariableV1.

tags/v0.20
Oceania2018 5 years ago
parent
commit
f849af095d
8 changed files with 49 additions and 49 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.state.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Training/Optimizer.cs
  3. +16
    -3
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Variables/IVariableV1.cs
  5. +21
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  6. +4
    -15
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  7. +0
    -21
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs
  8. +2
    -6
      src/TensorFlowNET.Core/Variables/state_ops.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.state.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow
{
public partial class tensorflow
{
public Tensor assign_add<T>(RefVariable @ref, T value,
public Tensor assign_add<T>(IVariableV1 @ref, T value,
bool use_locking = false, string name = null)
=> state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
}


+ 3
- 3
src/TensorFlowNET.Core/Training/Optimizer.cs View File

@@ -106,7 +106,7 @@ namespace Tensorflow
/// was not `None`, that operation also increments `global_step`.
/// </returns>
public Operation minimize(Tensor loss,
RefVariable global_step = null,
IVariableV1 global_step = null,
List<ResourceVariable> var_list=null,
GateGradientType gate_gradients = GateGradientType.GATE_OP,
int? aggregation_method=null,
@@ -142,7 +142,7 @@ namespace Tensorflow
/// <returns>
/// An `Operation` that applies the specified gradients. If `global_step`
/// was not None, that operation also increments `global_step`.</returns>
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
{
// No DistributionStrategy case.
var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
@@ -192,7 +192,7 @@ namespace Tensorflow
{
tf_with(ops.control_dependencies(new object[] {_finish(update_ops.ToArray(), "update")}), dep =>
{
ops.colocate_with(global_step);
// ops.colocate_with(global_step);
// TODO: port this if branch once ResourceVariable has been ported!
//if (global_step is ResourceVariable)
//{


+ 16
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -122,15 +122,28 @@ namespace Tensorflow
return array_ops.identity(value);
});

public Operation assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
{
var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle,
ops.convert_to_tensor(delta, dtype: dtype), name: name);
/*if (read_value)
return _lazy_read(assign_add_op);*/
return assign_add_op;
}

public override string ToString()
=> $"tf.Variable '{Name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}";
{
if (tf.context.executing_eagerly())
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}";
else
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}";
}

public NDArray numpy() => read_value().numpy();

protected override void DisposeUnmanagedResources(IntPtr handle)
{
// delete
// c_api.TFE_DeleteResourceVariable(handle);
}
}
}

+ 2
- 0
src/TensorFlowNET.Core/Variables/IVariableV1.cs View File

@@ -37,5 +37,7 @@ namespace Tensorflow
public Operation Op { get; }
public Tensor GraphElement { get; }
public Graph Graph { get; }
public TF_DataType dtype { get; }
public Operation assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
}
}

+ 21
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -401,5 +401,26 @@ namespace Tensorflow
read_value,
initial_value);
}

// Update 'ref' by adding 'value' to it.
// This operation outputs "ref" after the update is done.
// This makes it easier to chain operations that need to use the reset value.
// Args:
// ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
// Should be from a `Variable` node.
// value: A `Tensor`. Must have the same type as `ref`.
// The value to be added to the variable.
// use_locking: An optional `bool`. Defaults to `False`.
// If True, the addition will be protected by a lock;
// otherwise the behavior is undefined, but may exhibit less contention.
// name: A name for the operation(optional).
// Returns:
// A mutable `Tensor`. Has the same type as `ref`.
public Operation assign_add<T>(T value, bool use_locking = false, string name = null, bool read_value = true)
{
var variable = this;
var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking });
return _op;
}
}
}

+ 4
- 15
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -139,9 +139,8 @@ namespace Tensorflow
tf_with(ops.name_scope("Assign"), scope1 =>
{
string n = scope1;
initializer_op = gen_resource_variable_ops.assign_variable_op(handle,
variables._try_guard_against_uninitialized_dependencies(name, _initial_value),
name: n);
var _initial_value2 = variables._try_guard_against_uninitialized_dependencies(name, _initial_value);
initializer_op = gen_resource_variable_ops.assign_variable_op(handle, _initial_value2, name: n);
});
}

@@ -149,7 +148,8 @@ namespace Tensorflow
// messages.
tf_with(ops.name_scope("Read"), delegate
{
var value = _read_variable_op();
var value = gen_resource_variable_ops.read_variable_op(handle, _dtype);
// _maybe_set_handle_data(dtype, handle, value);
_graph_element = value;
});

@@ -233,16 +233,5 @@ namespace Tensorflow
return array_ops.identity(value);
});
}

public override string ToString()
{
return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}";
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
// delete
// c_api.TFE_DeleteResourceVariable(handle);
}
}
}

+ 0
- 21
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -120,27 +120,6 @@ namespace Tensorflow
return _op.outputs[0];
}


// Update 'ref' by adding 'value' to it.
// This operation outputs "ref" after the update is done.
// This makes it easier to chain operations that need to use the reset value.
// Args:
// ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
// Should be from a `Variable` node.
// value: A `Tensor`. Must have the same type as `ref`.
// The value to be added to the variable.
// use_locking: An optional `bool`. Defaults to `False`.
// If True, the addition will be protected by a lock;
// otherwise the behavior is undefined, but may exhibit less contention.
// name: A name for the operation(optional).
// Returns:
// A mutable `Tensor`. Has the same type as `ref`.
public static Tensor assign_add<T>(RefVariable @ref, T value, bool use_locking = false, string name = null)
{
var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking });
return _op.outputs[0];
}

/// <summary>
/// Adds sparse updates to a variable reference.
/// </summary>


+ 2
- 6
src/TensorFlowNET.Core/Variables/state_ops.cs View File

@@ -106,15 +106,11 @@ namespace Tensorflow
// Returns:
// Same as "ref". Returned as a convenience for operations that want
// to use the new value after the variable has been updated.
public static Tensor assign_add<T>(RefVariable @ref,
public static Operation assign_add<T>(IVariableV1 @ref,
T value,
bool use_locking = false,
string name = null)
{
if (@ref.dtype.is_ref_dtype())
return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
throw new NotImplementedException("assign_add");
}
=> @ref.assign_add(value, use_locking: use_locking, name: name);

public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
{


Loading…
Cancel
Save