You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

checkpoint.cs 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. using Google.Protobuf;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using Tensorflow.Contexts;
  7. using Tensorflow.Eager;
  8. using Tensorflow.Train;
  9. using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
  10. using static Tensorflow.Binding;
  11. namespace Tensorflow.Checkpoint;
  12. /// <summary>
  13. /// Saves and restores a `Trackable` object and its dependencies.
  14. /// </summary>
  15. public class TrackableSaver
  16. {
  // Object graph view supplied at construction; passed to SaveUtil.serialize_graph_view whenever tensors are gathered.
  17. private ObjectGraphView _graph_view;
  // Identity of the file prefix created under a control dependency on the most recently built save op.
  18. private Tensor _cached_save_operation;
  // Object graph proto recorded when the cached save operation was last rebuilt; compared against the fresh proto on each save.
  19. private TrackableObjectGraph _last_save_object_graph;
  // Lazily created in save() on the session path; handed to gather_serialized_tensors, which feeds the object graph text through it.
  20. private Tensor? _object_graph_feed_tensor = null;
  // Lazily created together with _object_graph_feed_tensor; used as the feed key for the checkpoint file prefix.
  21. private Tensor? _file_prefix_feed_tensor = null;
  // Forwarded to SaveUtil.serialize_graph_view; never assigned in the code visible here, so it stays null.
  22. private Dictionary<Trackable, Trackable>? _object_map = null;
  // Forwarded to SaveUtil.serialize_graph_view as `cache`; never assigned in the code visible here (see the constructor TODO).
  23. private object? _cache = null;
  /// <summary>
  /// Creates a saver bound to the given object graph view.
  /// </summary>
  /// <param name="graph_view">Graph view that later save calls serialize.</param>
  public TrackableSaver(ObjectGraphView graph_view)
  {
      // TODO: cache when not executing eagerly.
      // That covers `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
      // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache` and `_saveables_cache`.
      this._graph_view = graph_view;
  }
  /// <summary>
  /// Serializes the graph view and attaches the object-graph tensor to the serialized tensors
  /// under the <c>Trackable.None</c> entry.
  /// </summary>
  /// <param name="object_graph_tensor">
  /// Optional tensor through which the object graph is fed. When null, a string constant holding
  /// the graph text is created; otherwise the graph text is added to the feed additions keyed by
  /// this tensor.
  /// </param>
  /// <returns>
  /// The serialized tensors, the feed additions, the registered savers and the object graph proto,
  /// as produced by <c>SaveUtil.serialize_graph_view</c> (with the object-graph tensor added).
  /// </returns>
  private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
      gather_serialized_tensors(Tensor? object_graph_tensor = null)
  {
      var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache: _cache);

      // TODO: cache.
      // NOTE(review): the Python reference serializes the proto with SerializeToString();
      // ToString() on a protobuf message is presumably a text rendering rather than the
      // binary wire format - confirm that readers of the checkpoint expect this.
      if (object_graph_tensor is null)
      {
          // tensorflow python: `with ops.device("/cpu:0"):`
          object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
      }
      else
      {
          feed_additions[object_graph_tensor] = graph_proto.ToString();
      }

      // Mirror Python's `serialized_tensors.setdefault(None, {})`: the object graph must be
      // stored even when no entry exists yet for Trackable.None, otherwise it is silently dropped.
      if (!serialized_tensors.TryGetValue(Trackable.None, out var unowned_tensors))
      {
          unowned_tensors = new Dictionary<string, object>();
          serialized_tensors[Trackable.None] = unowned_tensors;
      }

      // The object graph key must not have been produced by serialization itself.
      Debug.Assert(!unowned_tensors.ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
      unowned_tensors[Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;

      return (serialized_tensors, feed_additions, registered_savers, graph_proto);
  }
  /// <summary>
  /// Builds the save operation for a tensor file prefix, or returns the previously built one.
  /// </summary>
  /// <param name="file_prefix">Tensor holding the checkpoint file prefix.</param>
  /// <param name="object_graph_tensor">Forwarded to <c>gather_serialized_tensors</c>; may be null.</param>
  /// <param name="options">Checkpoint options passed to the saver.</param>
  /// <returns>The cached save operation tensor together with the feed additions.</returns>
  /// <exception cref="NotImplementedException">Async checkpointing was requested.</exception>
  private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
  {
      var (tensors_by_object, feeds, savers_by_name, object_graph) = gather_serialized_tensors(object_graph_tensor);

      // Tensors are gathered first, exactly as before, and only then is async mode rejected.
      if (options.experimental_enable_async_checkpoint)
      {
          throw new NotImplementedException();
      }

      // NOTE(review): if TrackableObjectGraph does not overload `!=`, this is a reference
      // comparison and a freshly built proto always forces a rebuild - confirm.
      bool rebuild = _last_save_object_graph != object_graph
          || tf.Context.executing_eagerly()
          || ops.inside_function();
      if (!rebuild)
      {
          return (_cached_save_operation, feeds);
      }

      var device_saver = new MultiDeviceSaver(tensors_by_object, savers_by_name);
      var save_op = device_saver.save(file_prefix, options);

      // tensorflow python: `with ops.device("/cpu:0"):`
      // The identity is created under a control dependency on the save op.
      using (ops.control_dependencies(new object[] { save_op }))
      {
          _cached_save_operation = array_ops.identity(file_prefix);
      }
      _last_save_object_graph = object_graph;

      return (_cached_save_operation, feeds);
  }
  /// <summary>
  /// Builds the save operation for a string file prefix, or returns the previously built one.
  /// </summary>
  /// <param name="file_prefix">Checkpoint file prefix as a plain string.</param>
  /// <param name="object_graph_tensor">Forwarded to <c>gather_serialized_tensors</c>; may be null.</param>
  /// <param name="options">Checkpoint options passed to the saver.</param>
  /// <returns>The cached save operation tensor together with the feed additions.</returns>
  /// <exception cref="NotImplementedException">Async checkpointing was requested.</exception>
  private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
  {
      var (tensors_by_object, feeds, savers_by_name, object_graph) = gather_serialized_tensors(object_graph_tensor);

      // Tensors are gathered first, exactly as before, and only then is async mode rejected.
      if (options.experimental_enable_async_checkpoint)
      {
          throw new NotImplementedException();
      }

      // NOTE(review): if TrackableObjectGraph does not overload `!=`, this is a reference
      // comparison and a freshly built proto always forces a rebuild - confirm.
      bool rebuild = _last_save_object_graph != object_graph
          || tf.Context.executing_eagerly()
          || ops.inside_function();
      if (!rebuild)
      {
          return (_cached_save_operation, feeds);
      }

      var device_saver = new MultiDeviceSaver(tensors_by_object, savers_by_name);
      var save_op = device_saver.save(file_prefix, options);

      // tensorflow python: `with ops.device("/cpu:0"):`
      // The prefix string becomes a constant whose identity depends on the save op.
      using (ops.control_dependencies(new object[] { save_op }))
      {
          _cached_save_operation = array_ops.identity(tf.constant(file_prefix));
      }
      _last_save_object_graph = object_graph;

      return (_cached_save_operation, feeds);
  }
  // TODO: parameter write_done_callback
  /// <summary>
  /// Saves a checkpoint of the tracked object graph.
  /// </summary>
  /// <param name="file_prefix">Prefix of the checkpoint files to write.</param>
  /// <param name="checkpoint_number">When given, appended to the prefix as <c>-{number}</c>.</param>
  /// <param name="session">
  /// Session used to run the save operation when graph building outside a function. Ignored when
  /// executing eagerly or inside a function; a new session is created when it is needed but null.
  /// </param>
  /// <param name="options">Checkpoint options; a default instance is used when null.</param>
  /// <returns>
  /// The result of running the save operation in a session, or the save-path tensor itself
  /// when no session is used.
  /// </returns>
  public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
      CheckpointOptions? options = null)
  {
      options ??= new CheckpointOptions();

      Dictionary<Tensor, string> feed_dict = new();

      // Query the shared context, as save_cached_when_graph_building does. A freshly
      // constructed Context does not reflect the execution mode the program actually selected.
      bool use_session = !tf.Context.executing_eagerly() && !ops.inside_function();

      if (checkpoint_number is not null)
      {
          file_prefix = $"{file_prefix}-{checkpoint_number}";
      }

      Tensor file_prefix_tensor;
      Tensor object_graph_tensor;
      if (use_session)
      {
          if (_object_graph_feed_tensor is null)
          {
              // In python there is `with ops.device("/cpu:0")`.
              // Both feed tensors are created together and reused on later calls.
              _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
              _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
          }
          object_graph_tensor = _object_graph_feed_tensor;
          file_prefix_tensor = _file_prefix_feed_tensor;
          feed_dict[file_prefix_tensor] = file_prefix;
      }
      else
      {
          // In python there is `with ops.device("/cpu:0")`.
          file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
          object_graph_tensor = null;
      }

      var (save_path, new_feed_additions) =
          save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);

      if (new_feed_additions is not null)
      {
          // Indexer assignment mirrors Python's `feed_dict.update(...)`: a key that is already
          // present is overwritten instead of making Add throw.
          foreach (var pair in new_feed_additions)
          {
              feed_dict[pair.Key] = pair.Value;
          }
      }

      if (!use_session)
      {
          // Eager mode or inside a function: any supplied session is ignored.
          return save_path;
      }

      // In python it uses `get_session`, which may yield no session and raise an error; here a
      // session is always created when none was supplied, so that error path cannot occur.
      session ??= new Session();

      var feeds = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
      return session.run((Tensor)save_path, feeds);
  }
  167. }