Browse Source

Implement FuncGraph.capture_eager_tensor.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
2093577a98
1 changed files with 26 additions and 1 deletions
  1. +26
    -1
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs

+ 26
- 1
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -140,7 +140,32 @@ namespace Tensorflow.Graphs
}

Tensor capture_eager_tensor(Tensor tensor, string name)
=> throw new NotImplementedException("");
{
Tensor graph_const = null;
if (!_captures.ContainsKey(tensor.Id))
{
graph_const = tf_with(ops.control_dependencies(null), ctl
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
add_capture(tensor, graph_const);
}
else
{
graph_const = _captures[tensor.Id].Item2;
}

BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
return output_grads;
};

tf.Runner.RecordGradient("captured_value",
new[] { graph_const }, null,
new[] { tensor },
getBackwardFunction: () => _backward_function_wrapper
/*getForwardFunction: forward_function*/);

return graph_const;
}

Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
{


Loading…
Cancel
Save