| @@ -48,6 +48,14 @@ namespace Tensorflow.Gradients | |||||
| _recording = true; | _recording = true; | ||||
| } | } | ||||
| private void _pop_tape() | |||||
| { | |||||
| if (!_recording) | |||||
| throw new ValueError("Tape is not recording."); | |||||
| _tape.pop_tape(_tape); | |||||
| _recording = false; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Marks this tensor to be watched by the given tape. | /// Marks this tensor to be watched by the given tape. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -59,12 +67,19 @@ namespace Tensorflow.Gradients | |||||
| public Tensor gradient(Tensor target, Tensor sources) | public Tensor gradient(Tensor target, Tensor sources) | ||||
| { | { | ||||
| using (var status = new Status()) | |||||
| if(_recording) | |||||
| { | { | ||||
| c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status); | |||||
| if (!_persistent) | |||||
| _pop_tape(); | |||||
| } | } | ||||
| return null; | |||||
| using var status = new Status(); | |||||
| var et = c_api.TFE_TapeGradient(_tape, | |||||
| new IntPtr[] { (target as EagerTensor).EagerTensorHandle }, 1, | |||||
| new IntPtr[] { (sources as EagerTensor).EagerTensorHandle }, 1, | |||||
| status); | |||||
| status.Check(true); | |||||
| return et; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||
| @@ -20,6 +20,11 @@ namespace Tensorflow.Gradients | |||||
| c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); | c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle); | ||||
| } | } | ||||
| public void pop_tape(Tape tape) | |||||
| { | |||||
| c_api.TFE_TapeSetRemove(tape); | |||||
| } | |||||
| public static bool IsDtypeTrainable(DataType dtype) | public static bool IsDtypeTrainable(DataType dtype) | ||||
| { | { | ||||
| switch (dtype) | switch (dtype) | ||||
| @@ -715,17 +715,15 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (tf.context.executing_eagerly()) | if (tf.context.executing_eagerly()) | ||||
| { | { | ||||
| using (var status = new Status()) | |||||
| using var status = new Status(); | |||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Mul", name, new IntPtr[] | |||||
| { | { | ||||
| var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||||
| "Mul", name, new IntPtr[] | |||||
| { | |||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| }, 2, status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(_result); | |||||
| } | |||||
| (x as EagerTensor).EagerTensorHandle, | |||||
| (y as EagerTensor).EagerTensorHandle | |||||
| }, 2, status); | |||||
| status.Check(true); | |||||
| return new EagerTensor(_result); | |||||
| } | } | ||||
| var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | ||||