diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index b2c02dfe..bc2eebb4 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -140,7 +140,32 @@ namespace Tensorflow.Graphs } Tensor capture_eager_tensor(Tensor tensor, string name) - => throw new NotImplementedException(""); + { + Tensor graph_const = null; + if (!_captures.ContainsKey(tensor.Id)) + { + graph_const = tf_with(ops.control_dependencies(null), ctl + => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); + add_capture(tensor, graph_const); + } + else + { + graph_const = _captures[tensor.Id].Item2; + } + + BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => + { + return output_grads; + }; + + tf.Runner.RecordGradient("captured_value", + new[] { graph_const }, null, + new[] { tensor }, + getBackwardFunction: () => _backward_function_wrapper + /*getForwardFunction: forward_function*/); + + return graph_const; + } Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) {