You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TapeGradientFunctions.cs 7.1 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Graphs;
  6. using static Tensorflow.Binding;
  7. using static Tensorflow.tensorflow;
  8. namespace Tensorflow.Functions
  9. {
  /// <summary>
  /// Caches forward and backward functions compatible with eager gradients.
  /// </summary>
  /// <remarks>
  /// <see cref="Record"/> reads <c>_forward</c>, <c>_forward_graph</c> and <c>_backward</c>;
  /// presumably derived classes populate them from <see cref="ForwardAndBackwardFunctions"/>
  /// (e.g. via <see cref="BuildFunctionsForOutputs"/>) before <see cref="Record"/> is called.
  /// </remarks>
  public abstract class TapeGradientFunctions
  {
      // Attribute keys that link a forward function to its backward counterpart and vice versa.
      const string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
      const string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";

      // Name prefixes for generated functions. Each already ends in '_', so no extra
      // separator is inserted between the prefix and the function name.
      const string _FORWARD_PREFIX = "__forward_";
      const string _BACKWARD_PREFIX = "__backward_";
      const string _INFERENCE_PREFIX = "__inference_";

      protected FuncGraph _func_graph;
      protected EagerDefinedFunction _forward;
      protected FuncGraph _forward_graph;
      protected List<int> _forwardprop_output_indices;
      protected int _num_forwardprop_outputs;
      protected ConcreteFunction _backward;

      /// <summary>
      /// Creates the cache for the computation traced into <paramref name="func_graph"/>.
      /// </summary>
      /// <param name="func_graph">Graph holding the traced forward computation.</param>
      /// <param name="need_gradients_for_jvps">Accepted but not read by this base class.</param>
      public TapeGradientFunctions(FuncGraph func_graph,
          bool need_gradients_for_jvps)
      {
          _func_graph = func_graph;
      }

      /// <summary>
      /// Returns the forward function for <paramref name="inference_args"/> by delegating to
      /// <see cref="ForwardAndBackwardFunctions"/>.
      /// </summary>
      /// <param name="inference_args">Arguments the inference function is called with.</param>
      /// <returns>The forward function produced by <see cref="ForwardAndBackwardFunctions"/>.</returns>
      public EagerDefinedFunction Forward(Tensors inference_args)
      {
          return ForwardAndBackwardFunctions(inference_args);
      }

      /// <summary>
      /// Record the function call operation.
      /// Registers the call of <c>_forward</c> with <c>tf.Runner.RecordGradient</c>, supplying a
      /// wrapper around <c>_backward</c> as the backward function.
      /// </summary>
      /// <param name="flat_outputs">Flattened outputs returned by the forward call.</param>
      /// <param name="inference_args">Inputs the forward function was called with.</param>
      public void Record(Tensors flat_outputs, Tensors inference_args)
      {
          var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
          tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
              getBackwardFunction: () => backward_function);
      }

      /// <summary>
      /// Create a backward function given `outputs` from the forward function.
      /// </summary>
      /// <param name="forward_graph">Graph of the forward function; its outputs line up with <paramref name="outputs"/>.</param>
      /// <param name="backward">Concrete backward function to wrap.</param>
      /// <param name="outputs">Outputs produced by calling the forward function.</param>
      /// <returns>
      /// The wrapping <see cref="BackwardFunction"/>, and the subset of <paramref name="outputs"/>
      /// to record on the tape.
      /// </returns>
      (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
      {
          // Map each forward-graph output (by tensor id) to the corresponding actual output.
          var capture_mapping = new Dictionary<long, Tensor>();
          foreach (var (i, output) in enumerate(outputs))
              capture_mapping[forward_graph.Outputs[i].Id] = output;

          // Every captured input of the backward function must be supplied to CallFlat.
          // Captures that are forward-graph outputs are replaced by the matching actual outputs.
          // Any other capture (e.g. an eager tensor that was never exported as a forward output)
          // is passed through unchanged, instead of being dropped, which would misalign the list.
          var remapped_captures = new Tensors();
          foreach (var capture in backward.CapturedInputs)
          {
              remapped_captures.Add(capture_mapping.TryGetValue(capture.Id, out var mapped) ? mapped : capture);
          }

          // Number of non-capture inputs of the backward function (presumably one gradient per
          // trainable forward output, as built by BuildFunctionsForOutputs).
          var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
          var recorded_outputs = new Tensors();
          var relevant_outputs = outputs;
          var trainable_recorded_outputs = 0;
          // Positions of non-trainable outputs; their incoming gradients are not forwarded.
          var skip_positions = new HashSet<int>();
          foreach (var (output_index, output) in enumerate(relevant_outputs))
          {
              if (trainable_recorded_outputs < backward_function_inputs)
                  recorded_outputs.Add(output);
              if (gradients_util.IsTrainable(output))
                  trainable_recorded_outputs += 1;
              else
                  skip_positions.Add(output_index);
          }

          BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
          {
              var processed_args = new Tensors();
              var input_index = 0;
              foreach (var (output_index, arg) in enumerate(args))
              {
                  if (skip_positions.Contains(output_index))
                      continue;
                  if (arg == null)
                      throw new NotImplementedException(
                          $"Gradient for trainable output {output_index} of {backward.Name} is null; " +
                          "substituting zeros for unconnected gradients is not implemented.");
                  processed_args.Add(arg);
                  input_index += 1;
                  // Stop once every non-capture input of the backward function is filled.
                  if (input_index >= backward_function_inputs)
                      break;
              }
              tf.Logger.Debug($"Invoke backward function: {backward.Name}");
              return backward.CallFlat(processed_args, remapped_captures);
          };
          return (_backward_function_wrapper, recorded_outputs);
      }

      /// <summary>
      /// Builds the forward and backward functions for the trainable subset of <paramref name="outputs"/>.
      /// </summary>
      /// <param name="outputs">Outputs of <c>_func_graph</c> to differentiate.</param>
      /// <param name="inference_args">Inference arguments (not read by this method).</param>
      /// <returns>
      /// The forward function, the forward graph (<c>_func_graph</c>), the backward function, and
      /// forward-prop output indices/count (always <c>null</c> and <c>0</c> here).
      /// </returns>
      /// <remarks>
      /// Side effect: forward-graph tensors captured by the backward graph are appended to
      /// <c>_func_graph.Outputs</c>.
      /// </remarks>
      protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
          BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
      {
          var trainable_outputs = new List<Tensor>();
          foreach (var output in outputs)
          {
              if (gradients_util.IsTrainable(output))
                  trainable_outputs.Add(output);
          }

          var gradients_wrt_outputs = new List<Tensor>();
          var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}");
          Tensor[] gradients_wrt_inputs;
          backwards_graph.as_default();
          try
          {
              // One placeholder per trainable output, to receive that output's incoming gradient.
              foreach (var output in trainable_outputs)
                  gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
              gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
                  _func_graph.Inputs,
                  grad_ys: gradients_wrt_outputs.ToArray(),
                  src_graph: _func_graph);

              // Non-eager forward-graph tensors captured by the backward graph are exported as
              // forward outputs, so they can be mapped back when the backward function is called.
              var captures_from_forward = backwards_graph.external_captures
                  .Where(x => !x.IsEagerTensor && x.graph == _func_graph)
                  .ToArray();
              foreach (var capture in captures_from_forward)
              {
                  if (!_func_graph.Outputs.Contains(capture))
                      _func_graph.Outputs.Add(capture);
              }
          }
          finally
          {
              // Restore the previous default graph even if gradient construction throws.
              backwards_graph.Exit();
          }

          var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}";
          var backward_function_attr = new Dictionary<string, string>
          {
              [FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name
          };
          // Backward inputs: the output gradients first, followed by the graph's internal captures.
          gradients_wrt_outputs.AddRange(backwards_graph.internal_captures);
          backwards_graph.Inputs = gradients_wrt_outputs;
          backwards_graph.Outputs = gradients_wrt_inputs;
          var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);

          var forward_function_attr = new Dictionary<string, string>
          {
              [BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name
          };
          var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
              _func_graph.Inputs, _func_graph.Outputs, forward_function_attr);
          return (forward_function, _func_graph, backward_function, null, 0);
      }

      /// <summary>
      /// Produces the forward function for <paramref name="inference_args"/>.
      /// The base implementation always throws; derived classes are expected to override it.
      /// </summary>
      /// <param name="inference_args">Arguments the inference function is called with.</param>
      /// <returns>The forward function.</returns>
      /// <exception cref="NotImplementedException">Always, unless overridden.</exception>
      public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
      {
          throw new NotImplementedException(
              $"{GetType().Name} must override {nameof(ForwardAndBackwardFunctions)}.");
      }
  }
  147. }