You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AutoGraphAttribute.cs 4.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. using MethodBoundaryAspect.Fody.Attributes;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Eager;
  6. using Tensorflow.Functions;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow.Graphs
  9. {
  10. /// <summary>
  11. /// func_graph.py func_graph_from_py_func
  12. /// </summary>
  13. [AllowChangingInputArguments]
  14. public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
  15. {
  16. ConcreteFunction function;
  17. Tensors originalInputs;
  18. string func_name;
  19. static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>();
  20. public override void OnEntry(MethodExecutionArgs args)
  21. {
  22. // TODO: func_name can be cache in FullName + Args
  23. func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";
  24. if (functions.ContainsKey(func_name))
  25. {
  26. function = functions[func_name];
  27. if (args.Arguments[0] is Tensors tensor_inputs)
  28. args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs));
  29. else
  30. args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray()));
  31. args.FlowBehavior = FlowBehavior.Return;
  32. return;
  33. }
  34. // make function as an Operation by autograph
  35. // need to restore mode when exits
  36. function = new ConcreteFunction(func_name);
  37. function.Enter();
  38. // convert to Tensors
  39. if (args.Arguments[0] is Tensors inputs)
  40. {
  41. originalInputs = inputs;
  42. var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: "inputs")).ToArray();
  43. args.Arguments[0] = new Tensors(new_inputs);
  44. }
  45. else
  46. {
  47. originalInputs = new Tensors();
  48. // convert args to placeholder
  49. for (var i = 0; i < args.Arguments.Length; i++)
  50. {
  51. if (args.Arguments[i] is EagerTensor tensor)
  52. {
  53. originalInputs.Add(tensor);
  54. args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: "inputs");
  55. }
  56. }
  57. }
  58. }
  59. public override void OnExit(MethodExecutionArgs args)
  60. {
  61. if (args.ReturnValue is Tensors outputs)
  62. {
  63. Tensors inputs = null;
  64. outputs = mark_as_return(outputs);
  65. if (args.Arguments[0] is Tensors inputs1)
  66. inputs = inputs1;
  67. else
  68. inputs = args.Arguments.Select(x => x as Tensor).ToArray();
  69. inputs = inputs.Where(x => x.op.OpType == "Placeholder"
  70. && x.op.name.StartsWith("inputs")).ToArray();
  71. function.ToGraph(inputs, outputs);
  72. }
  73. else if (args.ReturnValue is Tensor output)
  74. {
  75. var inputs = args.Arguments.Select(x => x as Tensor)
  76. .Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs"))
  77. .ToArray();
  78. var outputs2 = array_ops.identity(output);
  79. function.ToGraph(inputs, outputs2);
  80. }
  81. function.Exit();
  82. // cache function.
  83. function.ReturnType = args.ReturnValue.GetType();
  84. functions[func_name] = function;
  85. // run function
  86. args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
  87. }
  88. object ConvertReturnValue(Tensors tensors)
  89. {
  90. if (function.ReturnType == typeof(Tensor))
  91. return (Tensor)tensors;
  92. else
  93. return tensors;
  94. }
  95. /// <summary>
  96. /// Acts like identity but marks the `Tensor` as a return value.
  97. /// </summary>
  98. /// <param name="tensors"></param>
  99. /// <returns></returns>
  100. public Tensors mark_as_return(Tensors tensors)
  101. {
  102. if (tensors == null)
  103. return null;
  104. var result = new Tensors();
  105. foreach (var tensor in tensors)
  106. result.Add(array_ops.identity(tensor));
  107. return result;
  108. }
  109. }
  110. }