diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs index 06f51352..4a801b32 100644 --- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs +++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs @@ -46,8 +46,7 @@ namespace Tensorflow.Train var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)); var m = get_slot(var, "m"); var m_scaled_g_values = grad * (1 - beta1_t); - var mul = m * beta1_t; - var m_t = state_ops.assign(m, mul, use_locking: _use_locking); + var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking); with(ops.control_dependencies(new[] { m_t }), delegate { m_t = scatter_add(m, indices, m_scaled_g_values); diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index a5a4ab69..e19ace4d 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -67,6 +67,26 @@ namespace Tensorflow return _result[0]; } + public static Tensor assign(RefVariable @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking }); + + var _result = _op.outputs; + var _inputs_flat = _op.inputs; + + var _attrs = new Dictionary(); + _attrs["T"] = _op.get_attr("T"); + _attrs["validate_shape"] = _op.get_attr("validate_shape"); + _attrs["use_locking"] = _op.get_attr("use_locking"); + + _execute.record_gradient("Assign", _inputs_flat, _attrs, _result, name); + + return _result[0]; + } + public static Tensor assign_sub(RefVariable @ref, Tensor value, bool use_locking = false, diff --git a/src/TensorFlowNET.Core/Variables/state_ops.cs b/src/TensorFlowNET.Core/Variables/state_ops.cs index 22894fe0..d5acf767 100644 --- a/src/TensorFlowNET.Core/Variables/state_ops.cs +++ b/src/TensorFlowNET.Core/Variables/state_ops.cs @@ -40,6 +40,18 @@ namespace Tensorflow //return @ref.assign(value, name: name); } + public static Tensor assign(RefVariable @ref, object value, + bool validate_shape = true, + bool use_locking = true, + string name = null) + { + return gen_state_ops.assign(@ref, + value, + validate_shape: validate_shape, + use_locking: use_locking, + name: name); + } + public static Tensor assign_sub(RefVariable @ref, Tensor value, bool use_locking = false,