diff --git a/src/TensorFlowNET.Core/Gradients/GradientActor.cs b/src/TensorFlowNET.Core/Gradients/GradientActor.cs index e8be5dae..82f37ac3 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientActor.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientActor.cs @@ -48,6 +48,14 @@ namespace Tensorflow.Gradients _recording = true; } + private void _pop_tape() + { + if (!_recording) + throw new ValueError("Tape is not recording."); + _tape.pop_tape(_tape); + _recording = false; + } + /// /// Marks this tensor to be watched by the given tape. /// @@ -59,12 +67,19 @@ namespace Tensorflow.Gradients public Tensor gradient(Tensor target, Tensor sources) { - using (var status = new Status()) + if(_recording) { - c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status); + if (!_persistent) + _pop_tape(); } - - return null; + + using var status = new Status(); + var et = c_api.TFE_TapeGradient(_tape, + new IntPtr[] { (target as EagerTensor).EagerTensorHandle }, 1, + new IntPtr[] { (sources as EagerTensor).EagerTensorHandle }, 1, + status); + status.Check(true); + return et; } public void Dispose() diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs index f47616dd..8bcf7f5f 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -20,6 +20,11 @@ namespace Tensorflow.Gradients c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); } + public void pop_tape(Tape tape) + { + c_api.TFE_TapeSetRemove(tape); + } + public static bool IsDtypeTrainable(DataType dtype) { switch (dtype) diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 84914a76..93c43ca3 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -715,17 +715,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - using (var status = new Status()) + using var status = new Status(); + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Mul", name, new IntPtr[] { - var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Mul", name, new IntPtr[] - { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle - }, 2, status); - status.Check(true); - return new EagerTensor(_result); - } + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, status); + status.Check(true); + return new EagerTensor(_result); } var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });