|
|
|
@@ -25,6 +25,7 @@ namespace Tensorflow.Functions |
|
|
|
protected List<int> _forwardprop_output_indices; |
|
|
|
protected int _num_forwardprop_outputs; |
|
|
|
protected ConcreteFunction _backward; |
|
|
|
BackwardFunction _backward_function_wrapper; |
|
|
|
|
|
|
|
public TapeGradientFunctions(FuncGraph func_graph, |
|
|
|
bool need_gradients_for_jvps) |
|
|
|
@@ -58,60 +59,66 @@ namespace Tensorflow.Functions |
|
|
|
/// <returns></returns> |
|
|
|
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) |
|
|
|
{ |
|
|
|
var capture_mapping = new Dictionary<long, Tensor>(); |
|
|
|
foreach(var (i, output) in enumerate(outputs)) |
|
|
|
capture_mapping[forward_graph.Outputs[i].Id] = output; |
|
|
|
|
|
|
|
var remapped_captures = new Tensors(); |
|
|
|
foreach(var capture in backward.CapturedInputs) |
|
|
|
{ |
|
|
|
if (capture_mapping.ContainsKey(capture.Id)) |
|
|
|
remapped_captures.Add(capture_mapping[capture.Id]); |
|
|
|
} |
|
|
|
|
|
|
|
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; |
|
|
|
var recorded_outputs = new Tensors(); |
|
|
|
var relevant_outputs = outputs; |
|
|
|
var trainable_recorded_outputs = 0; |
|
|
|
var skip_positions = new List<int>(); |
|
|
|
foreach (var (output_index, output) in enumerate(relevant_outputs)) |
|
|
|
foreach (var (output_index, output) in enumerate(outputs)) |
|
|
|
{ |
|
|
|
if (trainable_recorded_outputs < backward_function_inputs) |
|
|
|
recorded_outputs.Add(output); |
|
|
|
if (gradients_util.IsTrainable(output)) |
|
|
|
trainable_recorded_outputs += 1; |
|
|
|
else |
|
|
|
skip_positions.Add(output_index); |
|
|
|
} |
|
|
|
|
|
|
|
BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => |
|
|
|
if(_backward_function_wrapper == null) |
|
|
|
{ |
|
|
|
var processed_args = new Tensors(); |
|
|
|
var input_index = 0; |
|
|
|
foreach (var (output_index, arg) in enumerate(args)) |
|
|
|
var capture_mapping = new Dictionary<long, Tensor>(); |
|
|
|
foreach (var (i, output) in enumerate(outputs)) |
|
|
|
capture_mapping[forward_graph.Outputs[i].Id] = output; |
|
|
|
|
|
|
|
var remapped_captures = new Tensors(); |
|
|
|
foreach (var capture in backward.CapturedInputs) |
|
|
|
{ |
|
|
|
if (skip_positions.Contains(output_index)) |
|
|
|
continue; |
|
|
|
if (arg == null) |
|
|
|
throw new NotImplementedException(""); |
|
|
|
processed_args.Add(arg); |
|
|
|
input_index += 1; |
|
|
|
if (input_index >= backward_function_inputs) |
|
|
|
break; |
|
|
|
if (capture_mapping.ContainsKey(capture.Id)) |
|
|
|
remapped_captures.Add(capture_mapping[capture.Id]); |
|
|
|
} |
|
|
|
|
|
|
|
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); |
|
|
|
var gradients = backward.CallFlat(processed_args, remapped_captures); |
|
|
|
|
|
|
|
foreach (var unneeded_gradient_index in unneeded_gradients) |
|
|
|
var skip_positions = new List<int>(); |
|
|
|
foreach (var (output_index, output) in enumerate(outputs)) |
|
|
|
{ |
|
|
|
var index = Convert.ToInt32(unneeded_gradient_index); |
|
|
|
if (gradients.Length <= index) |
|
|
|
gradients.Insert(index, null); |
|
|
|
if (!gradients_util.IsTrainable(output)) |
|
|
|
skip_positions.Add(output_index); |
|
|
|
} |
|
|
|
|
|
|
|
return gradients; |
|
|
|
}; |
|
|
|
_backward_function_wrapper = (args, unneeded_gradients) => |
|
|
|
{ |
|
|
|
var processed_args = new Tensors(); |
|
|
|
var input_index = 0; |
|
|
|
foreach (var (output_index, arg) in enumerate(args)) |
|
|
|
{ |
|
|
|
if (skip_positions.Contains(output_index)) |
|
|
|
continue; |
|
|
|
if (arg == null) |
|
|
|
throw new NotImplementedException(""); |
|
|
|
processed_args.Add(arg); |
|
|
|
input_index += 1; |
|
|
|
if (input_index >= backward_function_inputs) |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); |
|
|
|
var gradients = backward.CallFlat(processed_args, remapped_captures); |
|
|
|
|
|
|
|
foreach (var unneeded_gradient_index in unneeded_gradients) |
|
|
|
{ |
|
|
|
var index = Convert.ToInt32(unneeded_gradient_index); |
|
|
|
if (gradients.Length <= index) |
|
|
|
gradients.Insert(index, null); |
|
|
|
} |
|
|
|
|
|
|
|
return gradients; |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
return (_backward_function_wrapper, recorded_outputs); |
|
|
|
} |
|
|
|
|