using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint;

/// <summary>
/// Helpers for traversing a trackable object graph and collecting the per-object
/// information (node ids, path names, slot variables) used when writing a checkpoint.
/// </summary>
public static class CheckPointUtils
{
    // Not referenced anywhere in this file.
    private const string _ESCAPE_CHAR = ".";

    /// <summary>
    /// Traverses <paramref name="graph_view"/> breadth-first and gathers everything needed to
    /// describe each reachable object.
    /// </summary>
    /// <param name="graph_view">The object graph to traverse.</param>
    /// <returns>
    /// A tuple of:
    /// (1) the trackable objects in traversal order;
    /// (2) the path to each object, as returned by the traversal;
    /// (3) each object's node id, equal to its index in (1);
    /// (4) the slot-variable references produced by <see cref="serialize_slot_variables"/>;
    /// (5) each object's path rendered as a string via <c>TrackableUtils.object_path_to_string</c>.
    /// </returns>
    public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
        IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
        IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
    {
        var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();

        Dictionary<Trackable, string> object_names = new();
        foreach (var pair in node_paths)
        {
            object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
        }

        // Node ids are simply positions in the traversal order.
        Dictionary<Trackable, int> node_ids = new();
        for (int i = 0; i < trackable_objects.Count; i++)
        {
            node_ids[trackable_objects[i]] = i;
        }

        var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
        return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
    }

    /// <summary>
    /// Collects slot-variable references for every <see cref="Optimizer"/> contained in
    /// <paramref name="trackable_objects"/>.
    /// </summary>
    /// <param name="trackable_objects">All trackables in the graph, in node-id order.</param>
    /// <param name="node_ids">Node id of each trackable (not used yet).</param>
    /// <param name="object_names">String path of each trackable (not used yet).</param>
    /// <returns>
    /// A map from trackable to its slot-variable references. With the current partial
    /// implementation this is always empty when the method returns normally.
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// Thrown as soon as an optimizer reports a non-null slot for one of the tracked variables,
    /// because serializing slot variables is not supported yet.
    /// </exception>
    public static IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
        serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
        IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
    {
        // Materialize once: the list is rescanned for every (optimizer, slot name) pair, and
        // the index of each entry is its node id.
        var non_slot_objects = trackable_objects.ToList();
        Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
            slot_variables = new();

        foreach (var trackable in non_slot_objects)
        {
            if (trackable is not Optimizer optim)
            {
                continue;
            }

            foreach (var slot_name in optim.get_slot_names())
            {
                for (int original_variable_node_id = 0; original_variable_node_id < non_slot_objects.Count; original_variable_node_id++)
                {
                    // Only variables can own slots. Skip every other trackable (including the
                    // optimizer itself) rather than casting it, which would throw InvalidCastException.
                    if (non_slot_objects[original_variable_node_id] is not IVariableV1 original_variable)
                    {
                        continue;
                    }

                    var slot_variable = optim.get_slot(original_variable, slot_name);
                    if (slot_variable is null)
                    {
                        continue;
                    }

                    // TODO: there are open problems with the inheritance relationship between
                    // `Variable` and `Trackable`, so a slot variable cannot be serialized yet.
                    throw new NotImplementedException(
                        "Serializing optimizer slot variables into a checkpoint is not supported yet.");
                }
            }
        }
        return slot_variables;
    }

    /// <summary>
    /// Returns the replacement for <paramref name="trackable"/> recorded in
    /// <paramref name="object_map"/>, or <paramref name="trackable"/> itself when the map is
    /// null or has no entry for it.
    /// </summary>
    public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
    {
        return object_map is not null && object_map.TryGetValue(trackable, out var mapped)
            ? mapped
            : trackable;
    }

    /// <summary>
    /// Returns the name of <paramref name="variable"/> if it is a variable, otherwise an empty string.
    /// </summary>
    /// <remarks>
    /// Any variable that passes the first check is assumed to be a <see cref="BaseResourceVariable"/>.
    /// That assumption is only asserted in debug builds. In release builds, an
    /// <see cref="IVariableV1"/> of another type makes the cast throw <see cref="InvalidCastException"/>.
    /// </remarks>
    public static string get_full_name(Trackable variable)
    {
        // TODO: This state is not correct, the whole framework needs to be updated in the future.
        if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
        {
            return "";
        }

        // The check of attribute `_save_slice_info` is skipped.
        // TODO: Needs to be revised.
        Debug.Assert(variable is BaseResourceVariable);
        return ((BaseResourceVariable)variable).Name;
    }

    /// <summary>
    /// Works out which nodes of <paramref name="object_graph_proto"/> have checkpointed values.
    /// A node qualifies directly when it has attributes or slot variables, and transitively when
    /// any node reachable through its children qualifies.
    /// </summary>
    /// <remarks>
    /// The computed set is currently discarded: writing the flag back to the proto is still a TODO,
    /// so this method does not modify <paramref name="object_graph_proto"/>.
    /// </remarks>
    public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
    {
        HashSet<int> checkpointed_trackables = new();
        // Reverse edges: child node id -> ids of the nodes that list it as a child.
        Dictionary<int, HashSet<int>> parents = new();

        for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
        {
            var object_proto = object_graph_proto.Nodes[i];

            // Registered savers are skipped; only attributes and slot variables count here.
            if (object_proto.Attributes?.Count > 0 || object_proto.SlotVariables?.Count > 0)
            {
                checkpointed_trackables.Add(i);
            }

            foreach (var child_proto in object_proto.Children)
            {
                var child = child_proto.NodeId;
                if (!parents.TryGetValue(child, out var child_parents))
                {
                    child_parents = new HashSet<int>();
                    parents[child] = child_parents;
                }
                child_parents.Add(i);
            }
        }

        // Walk the reverse edges upward from every directly checkpointed node, marking each
        // ancestor as checkpointed.
        Queue<int> to_visit = new(checkpointed_trackables);
        while (to_visit.Count > 0)
        {
            var trackable = to_visit.Dequeue();
            if (!parents.TryGetValue(trackable, out var current_parents))
            {
                // Either no node lists this one as a child, or it was already processed
                // (its entry is removed below).
                continue;
            }

            foreach (var parent in current_parents)
            {
                checkpointed_trackables.Add(parent);
                if (parents.ContainsKey(parent))
                {
                    to_visit.Enqueue(parent);
                }
            }

            // Removing the entry marks this node as processed, so duplicate queue entries are
            // skipped and cycles terminate.
            parents.Remove(trackable);
        }

        // TODO: Once checkpoint values are supported, store the result on each node, i.e. set
        // Nodes[i].has_checkpoint_values to checkpointed_trackables.Contains(i) for every i.
    }

    /// <summary>
    /// Traverses the object graph and lists all accessible objects.
    /// </summary>
    /// <param name="graph_view">The object graph to traverse.</param>
    /// <returns>The reachable trackable objects in breadth-first traversal order.</returns>
    public static IList<Trackable> list_objects(ObjectGraphView graph_view)
    {
        return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
    }

    /// <summary>
    /// Filters <paramref name="full_list"/> to the trackables whose
    /// <c>gather_saveables_for_checkpoint()</c> result is non-null and non-empty.
    /// </summary>
    /// <remarks>
    /// The result is a lazy LINQ query: the filter, including the calls to
    /// <c>gather_saveables_for_checkpoint()</c>, runs again on every enumeration.
    /// </remarks>
    internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
    {
        return full_list.Where(x => x.gather_saveables_for_checkpoint() is { Count: > 0 });
    }
}