You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

CheckPointUtils.cs 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.IO;
  5. using System.Linq;
  6. using Tensorflow.Train;
  7. using Tensorflow.Training;
  8. using pbc = global::Google.Protobuf.Collections;
  9. namespace Tensorflow.Checkpoint;
  /// <summary>
  /// Static helpers used when building a checkpoint from a trackable object graph:
  /// enumerating reachable objects, assigning node ids, collecting optimizer slot
  /// variables, and resolving object names.
  /// </summary>
  public static class CheckPointUtils
  {
      // NOTE(review): not referenced by any member of this class; kept for parity with the original.
      private static readonly string _ESCAPE_CHAR = ".";

      /// <summary>
      /// Traverses <paramref name="graph_view"/> breadth-first and gathers the bookkeeping
      /// structures derived from that traversal.
      /// </summary>
      /// <param name="graph_view">The object graph view to traverse.</param>
      /// <returns>
      /// A tuple of, in order: the traversed objects; the path to each object as returned by
      /// the traversal; the index of each object within the traversal order; the slot
      /// variables produced by <see cref="serialize_slot_variables"/>; and the string form of
      /// each object's path.
      /// </returns>
      public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
          IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
          IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
      {
          var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();

          // Human-readable name for every object, derived from its path in the graph.
          Dictionary<Trackable, string> object_names = new();
          foreach (var entry in node_paths)
          {
              object_names[entry.Key] = TrackableUtils.object_path_to_string(entry.Value);
          }

          // An object's node id is its position in the breadth-first ordering.
          Dictionary<Trackable, int> node_ids = new();
          for (int index = 0; index < trackable_objects.Count; index++)
          {
              node_ids[trackable_objects[index]] = index;
          }

          var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
          return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
      }

      /// <summary>
      /// Collects slot variable references for every <c>Optimizer</c> found in
      /// <paramref name="trackable_objects"/>.
      /// </summary>
      /// <remarks>
      /// Serialising a slot variable that actually exists is not implemented yet: as soon as
      /// an optimizer returns a non-null slot for some variable, this method throws.
      /// </remarks>
      /// <param name="trackable_objects">All objects of the graph, in node-id order.</param>
      /// <param name="node_ids">Object-to-node-id map (currently unused by this method).</param>
      /// <param name="object_names">Object-to-name map (currently unused by this method).</param>
      /// <returns>The slot variable map; empty with the current implementation.</returns>
      /// <exception cref="NotImplementedException">
      /// Thrown when an optimizer holds a slot variable for one of the traversed variables.
      /// </exception>
      public static
          IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
          serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
          IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
      {
          // Materialise once: the sequence is enumerated repeatedly and indexed below.
          var non_slot_objects = trackable_objects.ToList();
          Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
              slot_variables = new();

          foreach (var trackable in non_slot_objects)
          {
              if (trackable is not Optimizer optim)
              {
                  continue;
              }

              foreach (var slot_name in optim.get_slot_names())
              {
                  // The index doubles as the node id of the original variable; it is kept as
                  // an explicit counter because the pending implementation will need it.
                  for (int original_variable_node_id = 0;
                      original_variable_node_id < non_slot_objects.Count;
                      original_variable_node_id++)
                  {
                      // Only variables can own slots. Skipping everything else avoids the
                      // InvalidCastException the previous unconditional cast produced for
                      // non-variable trackables.
                      if (non_slot_objects[original_variable_node_id] is not IVariableV1 original_variable)
                      {
                          continue;
                      }

                      var slot_variable = optim.get_slot(original_variable, slot_name);
                      if (slot_variable is null)
                      {
                          continue;
                      }

                      // There're some problems about the inherits of `Variable` and `Trackable`.
                      throw new NotImplementedException();
                  }
              }
          }

          return slot_variables;
      }

      /// <summary>
      /// Returns the object that <paramref name="trackable"/> is mapped to in
      /// <paramref name="object_map"/>, or <paramref name="trackable"/> itself when the map
      /// is null or has no entry for it.
      /// </summary>
      /// <param name="trackable">The object to look up.</param>
      /// <param name="object_map">Optional replacement map.</param>
      /// <returns>The mapped object, or the input when no mapping applies.</returns>
      public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
      {
          if (object_map is not null && object_map.TryGetValue(trackable, out var mapped))
          {
              return mapped;
          }

          return trackable;
      }

      /// <summary>
      /// Returns the name of <paramref name="variable"/> when it is a variable, or an empty
      /// string otherwise.
      /// </summary>
      /// <param name="variable">The trackable whose name is requested.</param>
      /// <returns>The variable's <c>Name</c>, or <c>""</c> for non-variables.</returns>
      public static string get_full_name(Trackable variable)
      {
          // TODO: This state is not correct, the whole framework need to be updated in the future.
          bool is_variable = variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable);
          if (!is_variable)
          {
              return "";
          }

          // skip the check of attribute `_save_slice_info` .
          // TODO: Need to be revised!!!
          // The assert only fires in debug builds; the cast below still throws in release
          // builds if the variable is not a BaseResourceVariable.
          Debug.Assert(variable is BaseResourceVariable);
          return ((BaseResourceVariable)variable).Name;
      }

      /// <summary>
      /// Computes which nodes of <paramref name="object_graph_proto"/> have checkpoint values,
      /// either directly (attributes or slot variables) or through a descendant.
      /// </summary>
      /// <remarks>
      /// The computed set is not yet written back to the proto; see the TODO at the end.
      /// </remarks>
      /// <param name="object_graph_proto">The object graph to analyse.</param>
      public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
      {
          HashSet<int> checkpointed_trackables = new();
          // child node id -> ids of the nodes that list it as a child
          Dictionary<int, HashSet<int>> parents = new();

          for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
          {
              var object_proto = object_graph_proto.Nodes[i];

              // skip the process of registered saver.
              bool has_attributes = object_proto.Attributes is not null && object_proto.Attributes.Count > 0;
              bool has_slot_variables = object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0;
              if (has_attributes || has_slot_variables)
              {
                  checkpointed_trackables.Add(i);
              }

              foreach (var child_proto in object_proto.Children)
              {
                  var child = child_proto.NodeId;
                  if (!parents.TryGetValue(child, out var child_parents))
                  {
                      child_parents = new HashSet<int>();
                      parents[child] = child_parents;
                  }
                  child_parents.Add(i);
              }
          }

          // Propagate "has checkpoint values" upwards from each directly-checkpointed node.
          Queue<int> to_visit = new(checkpointed_trackables);
          while (to_visit.Count > 0)
          {
              var trackable = to_visit.Dequeue();
              if (!parents.TryGetValue(trackable, out var current_parents))
              {
                  continue;
              }

              foreach (var parent in current_parents)
              {
                  checkpointed_trackables.Add(parent);
                  if (parents.ContainsKey(parent))
                  {
                      to_visit.Enqueue(parent);
                  }
              }

              // Removing the entry ensures each node's parents are expanded at most once,
              // which also terminates the walk on cyclic graphs.
              parents.Remove(trackable);
          }

          // TODO: Complete it after supporting checkpoint.
          // for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
          // {
          //     object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
          // }
      }

      /// <summary>
      /// Traverse the object graph and list all accessible objects.
      /// </summary>
      /// <param name="graph_view">The object graph view to traverse.</param>
      /// <returns>The objects reachable from the graph view, in breadth-first order.</returns>
      public static IList<Trackable> list_objects(ObjectGraphView graph_view)
      {
          return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
      }

      /// <summary>
      /// Lazily filters <paramref name="full_list"/> down to the objects whose
      /// <c>gather_saveables_for_checkpoint()</c> result is non-null and non-empty.
      /// </summary>
      /// <param name="full_list">The objects to filter.</param>
      /// <returns>A deferred sequence of the objects that have saveables.</returns>
      internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
      {
          return full_list.Where(candidate =>
          {
              var saveables = candidate.gather_saveables_for_checkpoint();
              return saveables is not null && saveables.Count > 0;
          });
      }
  }