| @@ -140,7 +140,32 @@ namespace Tensorflow.Graphs | |||||
| } | } | ||||
| Tensor capture_eager_tensor(Tensor tensor, string name) | Tensor capture_eager_tensor(Tensor tensor, string name) | ||||
| => throw new NotImplementedException(""); | |||||
| { | |||||
| Tensor graph_const = null; | |||||
| if (!_captures.ContainsKey(tensor.Id)) | |||||
| { | |||||
| graph_const = tf_with(ops.control_dependencies(null), ctl | |||||
| => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||||
| add_capture(tensor, graph_const); | |||||
| } | |||||
| else | |||||
| { | |||||
| graph_const = _captures[tensor.Id].Item2; | |||||
| } | |||||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
| { | |||||
| return output_grads; | |||||
| }; | |||||
| tf.Runner.RecordGradient("captured_value", | |||||
| new[] { graph_const }, null, | |||||
| new[] { tensor }, | |||||
| getBackwardFunction: () => _backward_function_wrapper | |||||
| /*getForwardFunction: forward_function*/); | |||||
| return graph_const; | |||||
| } | |||||
| Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | ||||
| { | { | ||||