|
|
|
@@ -140,7 +140,32 @@ namespace Tensorflow.Graphs |
|
|
|
} |
|
|
|
|
|
|
|
Tensor capture_eager_tensor(Tensor tensor, string name) |
|
|
|
=> throw new NotImplementedException(""); |
|
|
|
{ |
|
|
|
Tensor graph_const = null; |
|
|
|
if (!_captures.ContainsKey(tensor.Id)) |
|
|
|
{ |
|
|
|
graph_const = tf_with(ops.control_dependencies(null), ctl |
|
|
|
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); |
|
|
|
add_capture(tensor, graph_const); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
graph_const = _captures[tensor.Id].Item2; |
|
|
|
} |
|
|
|
|
|
|
|
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => |
|
|
|
{ |
|
|
|
return output_grads; |
|
|
|
}; |
|
|
|
|
|
|
|
tf.Runner.RecordGradient("captured_value", |
|
|
|
new[] { graph_const }, null, |
|
|
|
new[] { tensor }, |
|
|
|
getBackwardFunction: () => _backward_function_wrapper |
|
|
|
/*getForwardFunction: forward_function*/); |
|
|
|
|
|
|
|
return graph_const; |
|
|
|
} |
|
|
|
|
|
|
|
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) |
|
|
|
{ |
|
|
|
|