diff --git a/src/TensorFlowNET.Core/APIs/tf.state.cs b/src/TensorFlowNET.Core/APIs/tf.state.cs index c57d03c6..cb5cb4f5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.state.cs +++ b/src/TensorFlowNET.Core/APIs/tf.state.cs @@ -18,7 +18,7 @@ namespace Tensorflow { public partial class tensorflow { - public Tensor assign_add(RefVariable @ref, T value, + public Tensor assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null) => state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); } diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index e3c4ca03..ebe9690d 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -106,7 +106,7 @@ namespace Tensorflow /// was not `None`, that operation also increments `global_step`. /// public Operation minimize(Tensor loss, - RefVariable global_step = null, + IVariableV1 global_step = null, List var_list=null, GateGradientType gate_gradients = GateGradientType.GATE_OP, int? aggregation_method=null, @@ -142,7 +142,7 @@ namespace Tensorflow /// /// An `Operation` that applies the specified gradients. If `global_step` /// was not None, that operation also increments `global_step`. - public Operation apply_gradients(Tuple[] grads_and_vars, RefVariable global_step = null, string name = null) + public Operation apply_gradients(Tuple[] grads_and_vars, IVariableV1 global_step = null, string name = null) { // No DistributionStrategy case. var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>(); @@ -192,7 +192,7 @@ namespace Tensorflow { tf_with(ops.control_dependencies(new object[] {_finish(update_ops.ToArray(), "update")}), dep => { - ops.colocate_with(global_step); + // ops.colocate_with(global_step); // TODO: port this if branch once ResourceVariable has been ported! //if (global_step is ResourceVariable) //{ diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index e0af8b70..822d1f66 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -122,15 +122,28 @@ namespace Tensorflow return array_ops.identity(value); }); + public Operation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true) + { + var assign_add_op = gen_resource_variable_ops.assign_add_variable_op(Handle, + ops.convert_to_tensor(delta, dtype: dtype), name: name); + + /*if (read_value) + return _lazy_read(assign_add_op);*/ + return assign_add_op; + } + public override string ToString() - => $"tf.Variable '{Name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}"; + { + if (tf.context.executing_eagerly()) + return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}"; + else + return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}"; + } public NDArray numpy() => read_value().numpy(); protected override void DisposeUnmanagedResources(IntPtr handle) { - // delete - // c_api.TFE_DeleteResourceVariable(handle); } } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index af49d09d..e7389522 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -37,5 +37,7 @@ namespace Tensorflow public Operation Op { get; } public Tensor GraphElement { get; } public Graph Graph { get; } + public TF_DataType dtype { get; } + public Operation assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 35d617fd..38d124ac 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -401,5 +401,26 @@ namespace Tensorflow read_value, initial_value); } + + // Update 'ref' by adding 'value' to it. + // This operation outputs "ref" after the update is done. + // This makes it easier to chain operations that need to use the reset value. + // Args: + // ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + // Should be from a `Variable` node. + // value: A `Tensor`. Must have the same type as `ref`. + // The value to be added to the variable. + // use_locking: An optional `bool`. Defaults to `False`. + // If True, the addition will be protected by a lock; + // otherwise the behavior is undefined, but may exhibit less contention. + // name: A name for the operation(optional). + // Returns: + // A mutable `Tensor`. Has the same type as `ref`. + public Operation assign_add(T value, bool use_locking = false, string name = null, bool read_value = true) + { + var variable = this; + var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { variable, value, use_locking }); + return _op; + } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index c2c586b9..adee9779 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -139,9 +139,8 @@ namespace Tensorflow tf_with(ops.name_scope("Assign"), scope1 => { string n = scope1; - initializer_op = gen_resource_variable_ops.assign_variable_op(handle, - variables._try_guard_against_uninitialized_dependencies(name, _initial_value), - name: n); + var _initial_value2 = variables._try_guard_against_uninitialized_dependencies(name, _initial_value); + initializer_op = gen_resource_variable_ops.assign_variable_op(handle, _initial_value2, name: n); }); } @@ -149,7 +148,8 @@ namespace Tensorflow // messages. tf_with(ops.name_scope("Read"), delegate { - var value = _read_variable_op(); + var value = gen_resource_variable_ops.read_variable_op(handle, _dtype); + // _maybe_set_handle_data(dtype, handle, value); _graph_element = value; }); @@ -233,16 +233,5 @@ namespace Tensorflow return array_ops.identity(value); }); } - - public override string ToString() - { - return $"tf.Variable: '{Name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}"; - } - - protected override void DisposeUnmanagedResources(IntPtr handle) - { - // delete - // c_api.TFE_DeleteResourceVariable(handle); - } } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index 192227a7..9a566d70 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -120,27 +120,6 @@ namespace Tensorflow return _op.outputs[0]; } - - // Update 'ref' by adding 'value' to it. - // This operation outputs "ref" after the update is done. - // This makes it easier to chain operations that need to use the reset value. - // Args: - // ref: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. - // Should be from a `Variable` node. - // value: A `Tensor`. Must have the same type as `ref`. - // The value to be added to the variable. - // use_locking: An optional `bool`. Defaults to `False`. - // If True, the addition will be protected by a lock; - // otherwise the behavior is undefined, but may exhibit less contention. - // name: A name for the operation(optional). - // Returns: - // A mutable `Tensor`. Has the same type as `ref`. - public static Tensor assign_add(RefVariable @ref, T value, bool use_locking = false, string name = null) - { - var _op = tf._op_def_lib._apply_op_helper("AssignAdd", name: name, args: new { @ref, value, use_locking }); - return _op.outputs[0]; - } - /// /// Adds sparse updates to a variable reference. /// diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index b87512c3..4ad626ef 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -106,15 +106,11 @@ namespace Tensorflow // Returns: // Same as "ref". Returned as a convenience for operations that want // to use the new value after the variable has been updated. - public static Tensor assign_add(RefVariable @ref, + public static Operation assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null) - { - if (@ref.dtype.is_ref_dtype()) - return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name); - throw new NotImplementedException("assign_add"); - } + => @ref.assign_add(value, use_locking: use_locking, name: name); public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null) {