| @@ -0,0 +1,253 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using pbc = global::Google.Protobuf.Collections; | |||||
| namespace Tensorflow.Checkpoint | |||||
| { | |||||
| internal record class TrackableData( | |||||
| // A trackable in the root Trackable object graph. | |||||
| Trackable trackable, | |||||
| // The index at which the Trackable appears in TrackableObjectGraph.nodes. | |||||
| int node_id, | |||||
| // The BFS-generated path from the root object / used to generate readable checkpoint keys. | |||||
| string object_name, | |||||
| // A list of ObjectReference for each child connected to this Trackable. | |||||
| pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto, | |||||
| // A list of SlotVariableReference to save to the object (only valid for Optimizer objects). | |||||
| pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot_variable_proto, | |||||
        // The object to save to checkpoint. Usually this is the same as `trackable`,
        // but can differ when the caller wants to specify a different object to
        // save. For example, when saving checkpoints asynchronously, variables are
        // copied to the CPU. `object_to_save` is set as the copied variable.
| Trackable object_to_save | |||||
| ); | |||||
| public static class SaveUtil | |||||
| { | |||||
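        /// <summary>
        /// Serializes the checkpoint state of every trackable reachable from `graph_view`: the tensors to
        /// write, the session feed additions, the registered savers and the `TrackableObjectGraph` proto.
        /// </summary>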
| public static (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||||
| { | |||||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||||
| var (tensor_trackables, pystate_trackables, registered_trackables) = split_trackables(trackable_data); | |||||
| var object_graph_proto = fill_object_graph_proto(trackable_data); | |||||
| var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto); | |||||
| var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto); | |||||
| Dictionary<Tensor, string> feed_additions; | |||||
| if(cache is null) | |||||
| { | |||||
| feed_additions = null; | |||||
| serialized_tensors = serialized_tensors.Concat(get_and_write_tensors_to_serialize(pystate_trackables, node_ids, call_with_mapped_captures, | |||||
| cache, object_graph_proto)).ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| else | |||||
| { | |||||
| feed_additions = null; | |||||
                // TODO: deal with cache.
                throw new NotImplementedException();
| } | |||||
| CheckPointUtils.add_checkpoint_values_check(object_graph_proto); | |||||
| return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); | |||||
| } | |||||
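        /// <summary>
        /// Traverses the object graph breadth-first and collects, for every trackable, its node id,
        /// readable object name, child references and slot-variable references.
        /// </summary>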
| private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map) | |||||
| { | |||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||||
| Dictionary<Trackable, string> object_names = new(); | |||||
| foreach(var pair in node_paths) | |||||
| { | |||||
| object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); | |||||
| } | |||||
| Dictionary<Trackable, int> node_ids = new(); | |||||
| for(int i = 0; i < trackable_objects.Count; i++) | |||||
| { | |||||
| node_ids[trackable_objects[i]] = i; | |||||
| } | |||||
| var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); | |||||
| List<TrackableData> trackable_data = new(); | |||||
| foreach(var trackable in trackable_objects) | |||||
| { | |||||
| pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_proto = new(); | |||||
| foreach(var child in graph_view.list_children(trackable)) | |||||
| { | |||||
| children_proto.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() | |||||
| { | |||||
| NodeId = node_ids[child.Refer], | |||||
| LocalName = child.Name | |||||
| }); | |||||
| } | |||||
| slot_variables.TryGetValue(trackable, out var slot_variable); | |||||
| trackable_data.Add(new TrackableData( | |||||
| trackable: trackable, | |||||
| node_id: node_ids[trackable], | |||||
| object_name: object_names[trackable], | |||||
| children_proto: children_proto, | |||||
| slot_variable_proto: slot_variable??new pbc.RepeatedField<TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>(), | |||||
| object_to_save: CheckPointUtils.get_mapped_trackable(trackable, object_map) | |||||
| )); | |||||
| } | |||||
| return (trackable_data, node_ids); | |||||
| } | |||||
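        /// <summary>
        /// Builds a `TrackableObjectGraph` proto whose nodes follow the BFS order produced by
        /// `gather_trackable_data`.
        /// </summary>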
| private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData> trackable_data) | |||||
| { | |||||
| TrackableObjectGraph object_graph_proto = new(); | |||||
| for(int i = 0; i < trackable_data.Count; i++) | |||||
| { | |||||
| var td = trackable_data[i]; | |||||
| Debug.Assert(td.node_id == i); | |||||
| object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto)); | |||||
| } | |||||
| return object_graph_proto; | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates dictionary of tensors to checkpoint, and updates the proto. | |||||
| /// </summary> | |||||
| /// <param name="tensor_trackables"></param> | |||||
| /// <param name="node_ids"></param> | |||||
| /// <param name="call_with_mapped_captures"></param> | |||||
| /// <param name="cache"></param> | |||||
| /// <param name="object_graph_proto"></param> | |||||
| private static IDictionary<Trackable, IDictionary<string, object>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids, | |||||
| bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) | |||||
| { | |||||
| Dictionary<Trackable, IDictionary<string, object>> serialized_tensors = new(); | |||||
| foreach(var td in tensor_trackables) | |||||
| { | |||||
| // TODO: deal with cache. | |||||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||||
| var trackable = td.object_to_save; | |||||
| IDictionary<string, object> tensor_dict; | |||||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||||
| { | |||||
| (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); | |||||
| } | |||||
| else | |||||
| { | |||||
| tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); | |||||
| } | |||||
| if(trackable is not null) | |||||
| { | |||||
| serialized_tensors[trackable] = tensor_dict; | |||||
| } | |||||
| else | |||||
| { | |||||
| serialized_tensors[Trackable.None] = tensor_dict; | |||||
| } | |||||
| } | |||||
| return serialized_tensors; | |||||
| } | |||||
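        /// <summary>
        /// Serializes a single trackable via `serialize_to_tensors`, rewrites the local names into
        /// checkpoint keys, and records each serialized tensor in the object graph proto.
        /// </summary>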
| private static IDictionary<string, object> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
| { | |||||
| var trackable = trackable_data.object_to_save; | |||||
            // TODO: complete it. Note that in Python, `call_with_mapped_captures` is actually a callable rather than a bool.
| IDictionary<string, object> ret_tensor_dict; | |||||
| if (call_with_mapped_captures) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| ret_tensor_dict = trackable.serialize_to_tensors(); | |||||
| } | |||||
| // TODO: revise the types and complete it | |||||
| Dictionary<string, object> tensor_dict = new(); | |||||
| foreach(var pair in ret_tensor_dict) | |||||
| { | |||||
| var local_name = TrackableUtils.escape_local_name(pair.Key); | |||||
| var maybe_tensor = pair.Value; | |||||
| var checkpoint_key = TrackableUtils.checkpoint_key(trackable_data.object_name, local_name); | |||||
| tensor_dict[checkpoint_key] = maybe_tensor; | |||||
| if(maybe_tensor is SaveSpec) | |||||
| { | |||||
| ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||||
| } | |||||
| if(object_graph_proto is not null) | |||||
| { | |||||
| object_graph_proto.Nodes[trackable_data.node_id].Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() | |||||
| { | |||||
| Name = local_name, | |||||
| CheckpointKey = checkpoint_key, | |||||
| FullName = CheckPointUtils.get_full_name(trackable) | |||||
| }); | |||||
| } | |||||
| } | |||||
| return tensor_dict; | |||||
| } | |||||
| /// <summary> | |||||
| /// Gets tensors to serialize from a Trackable with legacy SaveableObjects. | |||||
| /// </summary> | |||||
| /// <param name="trackable_data"></param> | |||||
| /// <param name="node_ids"></param> | |||||
| /// <param name="call_with_mapped_captures"></param> | |||||
| /// <param name="object_graph_proto"></param> | |||||
| /// <returns></returns> | |||||
| private static (Trackable, IDictionary<string, object>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids, | |||||
| bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) | |||||
| { | |||||
| Dictionary<Trackable, string> object_names = new(); | |||||
| object_names[trackable_data.trackable] = trackable_data.object_name; | |||||
| Dictionary<Trackable, Trackable> object_map = new(); | |||||
| object_map[trackable_data.trackable] = trackable_data.object_to_save; | |||||
| var (checkpoint_factory_map, _) = SaveUtilV1.get_checkpoint_factories_and_keys(object_names, object_map); | |||||
| var (named_saveable_objects, _) = SaveUtilV1.generate_saveable_objects(checkpoint_factory_map, object_graph_proto, node_ids, object_map, | |||||
| call_with_mapped_captures, saveables_cache: null); | |||||
| var trackable = new SaveableCompatibilityConverter(trackable_data.object_to_save, named_saveable_objects); | |||||
| return (trackable, trackable.serialize_to_tensors()); | |||||
| } | |||||
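        /// <summary>
        /// Groups trackables by registered saver name. Writing the saver names into the proto is still
        /// pending (see the TODO below), so only the in-memory dictionary is produced for now.
        /// </summary>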
| private static IDictionary<string, IDictionary<string, Trackable>> get_and_write_registered_savers(IDictionary<string, IList<TrackableData>> registered_trackables, TrackableObjectGraph object_graph_proto) | |||||
| { | |||||
| Dictionary<string, IDictionary<string, Trackable>> registered_savers = new(); | |||||
| foreach(var pair in registered_trackables) | |||||
| { | |||||
| foreach(var td in pair.Value) | |||||
| { | |||||
                    if (!registered_savers.ContainsKey(pair.Key))
                    {
                        registered_savers[pair.Key] = new Dictionary<string, Trackable>();
                    }
                    registered_savers[pair.Key][td.object_name] = td.object_to_save;
| var object_proto = object_graph_proto.Nodes[td.node_id]; | |||||
| // TODO: add APIs and complete it. Now the `TrackableObjectGraph.Types.TrackableObject` lacks `registered_savers`. | |||||
| } | |||||
| } | |||||
| return registered_savers; | |||||
| } | |||||
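        /// <summary>
        /// Partitions the trackables into tensor-based, python-state and registered-saver groups.
        /// Currently everything is treated as tensor-based (see the TODO below).
        /// </summary>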
| private static (IList<TrackableData>, IList<TrackableData>, IDictionary<string, IList<TrackableData>>) split_trackables(IEnumerable<TrackableData> trackable_data) | |||||
| { | |||||
| List<TrackableData> tensor_trackables = new(); | |||||
            List<TrackableData> py_state_trackables = new(); // Skip `PyState` processing due to lack of API. This is only a placeholder.
| Dictionary<string, IList<TrackableData>> registered_trackables = new(); | |||||
| foreach(var td in trackable_data) | |||||
| { | |||||
| // TODO: deal with registration. | |||||
| tensor_trackables.Add(td); | |||||
| } | |||||
| return (tensor_trackables, py_state_trackables, registered_trackables); | |||||
| } | |||||
| } | |||||
| } | |||||
@@ -7,6 +7,7 @@ using Tensorflow.Train;
 using Tensorflow.Training;
 using pbc = global::Google.Protobuf.Collections;
 using static Tensorflow.Binding;
+using Google.Protobuf;

 namespace Tensorflow.Checkpoint;
@@ -47,19 +48,16 @@ public static class SaveUtilV1
        IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
        object? saveables_cache = null)
    {
-       Graph target_context;
        if (to_graph is not null)
        {
-           using (to_graph.as_default())
-           {
-               var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
+           to_graph.as_default();
+           var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
                object_map, call_with_mapped_captures, saveables_cache);
-               // tensorflow python: `with ops.device("/cpu:0")`
-               var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
-               named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
-               return (named_saveable_objects, registered_savers);
-           }
+           // tensorflow python: `with ops.device("/cpu:0")`
+           var serialized = graph_proto.ToByteString().ToString();
+           var object_graph_tensor = constant_op.constant(serialized, TF_DataType.TF_STRING);
+           named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
+           return (named_saveable_objects, registered_savers);
        }
        else
        {
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Checkpoint | |||||
| { | |||||
| internal static class SaveableCompat | |||||
| { | |||||
| public static string? get_saveable_name(Trackable cls_or_obj) | |||||
| { | |||||
| // TODO: implement it with Attribute. | |||||
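            // A possible shape for the attribute-based lookup once such an attribute exists
            // (the `SaveableNameAttribute` type below is hypothetical, not part of this code base):
            //
            //   var attr = cls_or_obj.GetType()
            //       .GetCustomAttributes(typeof(SaveableNameAttribute), inherit: true)
            //       .OfType<SaveableNameAttribute>()
            //       .FirstOrDefault();
            //   return attr?.Name;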
| return null; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,109 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public class TrackableSaver | |||||
| { | |||||
| private ObjectGraphView _graph_view; | |||||
| private EagerTensor _cached_save_operation; | |||||
| private TrackableObjectGraph _last_save_object_graph; | |||||
| private Tensor? _object_graph_feed_tensor = null; | |||||
| private Tensor? _file_prefix_feed_tensor = null; | |||||
| public TrackableSaver(ObjectGraphView graph_view) | |||||
| { | |||||
| _graph_view = graph_view; | |||||
| // TODO: cache when not executing eagerly. | |||||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||||
| } | |||||
| private void gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| private (EagerTensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // TODO: parameter write_done_callback | |||||
| public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, | |||||
| CheckpointOptions? options = null) | |||||
| { | |||||
| if (options is null) | |||||
| { | |||||
| options = new CheckpointOptions(); | |||||
| } | |||||
| Dictionary<Tensor, string> feed_dict = new(); | |||||
| bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); | |||||
| if (checkpoint_number is not null) | |||||
| { | |||||
| file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | |||||
| } | |||||
| Tensor file_prefix_tensor; | |||||
| Tensor object_graph_tensor; | |||||
| if (use_session) | |||||
| { | |||||
| if (_object_graph_feed_tensor is null) | |||||
| { | |||||
| // In python there is `with ops.device("/cpu:0")`. | |||||
| _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); | |||||
| _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); | |||||
| } | |||||
| object_graph_tensor = _object_graph_feed_tensor; | |||||
| file_prefix_tensor = _file_prefix_feed_tensor; | |||||
| feed_dict[file_prefix_tensor] = file_prefix; | |||||
| } | |||||
| else | |||||
| { | |||||
| // In python there is `with ops.device("/cpu:0")`. | |||||
| file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); | |||||
| object_graph_tensor = null; | |||||
| } | |||||
| var (save_path, new_feed_additions) = | |||||
| save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); | |||||
| if (new_feed_additions is not null) | |||||
| { | |||||
| foreach (var pair in new_feed_additions) | |||||
| { | |||||
| feed_dict.Add(pair.Key, pair.Value); | |||||
| } | |||||
| } | |||||
| if(!use_session) | |||||
| { | |||||
| session = null; | |||||
| } | |||||
| else if (session is null) | |||||
| { | |||||
| session = new Session(); // In python it uses `get_session`. | |||||
| } | |||||
| if (session is not null) | |||||
| { | |||||
| var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); | |||||
| return session.run((Tensor)save_path, s); | |||||
| } | |||||
| else if (use_session) | |||||
| { | |||||
| throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + | |||||
| "in graph mode without a default session. Please use " + | |||||
| "`with tf.Session():` to create a session."); | |||||
| } | |||||
| else | |||||
| { | |||||
| return save_path; | |||||
| } | |||||
| } | |||||
| } | |||||
@@ -21,9 +21,14 @@ public class TrackableView
    public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
    {
        obj._maybe_initialize_trackable();
+       Dictionary<string, Trackable> children = new();
        // Note: in python the return type of `Trackable._trackable_children` is not fixed.
        // Therefore it uses `convert_to_trackable` to have an extra process.
-       return obj._trackable_children(save_type);
+       foreach(var pair in obj._trackable_children(save_type))
+       {
+           children[pair.Key] = pair.Value;
+       }
+       return children;
    }

    public Trackable Root
@@ -50,6 +55,7 @@ public class TrackableView
    {
        List<Trackable> bfs_sorted = new();
        Queue<Trackable> to_visit = new();
+       to_visit.Enqueue(Root);
        Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new();
        node_paths[this.Root] = new List<TrackableReference>();
        while (!to_visit.empty())
| @@ -0,0 +1,191 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| /// <summary> | |||||
| /// Saves and restores a `Trackable` object and its dependencies. | |||||
| /// </summary> | |||||
| public class TrackableSaver | |||||
| { | |||||
| private ObjectGraphView _graph_view; | |||||
| private Tensor _cached_save_operation; | |||||
| private TrackableObjectGraph _last_save_object_graph; | |||||
| private Tensor? _object_graph_feed_tensor = null; | |||||
| private Tensor? _file_prefix_feed_tensor = null; | |||||
| private Dictionary<Trackable, Trackable>? _object_map = null; | |||||
| private object? _cache = null; | |||||
| public TrackableSaver(ObjectGraphView graph_view) | |||||
| { | |||||
| _graph_view = graph_view; | |||||
| // TODO: cache when not executing eagerly. | |||||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||||
| } | |||||
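    /// <summary>
    /// Gathers the serialized tensors of every reachable trackable and stores the serialized object
    /// graph proto under `Trackable.Constants.OBJECT_GRAPH_PROTO_KEY`.
    /// </summary>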
| private (IDictionary<Trackable, IDictionary<string, object>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||||
| { | |||||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||||
| // TODO: cache. | |||||
| if(object_graph_tensor is null) | |||||
| { | |||||
| // tensorflow python: `with ops.device("/cpu:0"):` | |||||
| object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||||
| } | |||||
| else | |||||
| { | |||||
| feed_additions[object_graph_tensor] = graph_proto.ToString(); | |||||
| } | |||||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||||
        if (!serialized_tensors.ContainsKey(Trackable.None))
        {
            serialized_tensors[Trackable.None] = new Dictionary<string, object>();
        }
        serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||||
| } | |||||
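    /// <summary>
    /// Builds (or reuses) the cached save operation for the current object graph and returns it together
    /// with the feed additions needed when running it in a session.
    /// </summary>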
| private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||||
| { | |||||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); | |||||
| Func<(Tensor, IDictionary<Tensor, string>)> run_save = () => | |||||
| { | |||||
| if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) | |||||
| { | |||||
| var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); | |||||
| var save_op = saver.save(file_prefix, options); | |||||
| // tensorflow python: `with ops.device("/cpu:0"):` | |||||
| using (ops.control_dependencies(new object[] { save_op })) | |||||
| { | |||||
| _cached_save_operation = array_ops.identity(file_prefix); | |||||
| } | |||||
| _last_save_object_graph = graph_proto; | |||||
| } | |||||
| return (_cached_save_operation, feed_additions); | |||||
| }; | |||||
| if (options.experimental_enable_async_checkpoint) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| return run_save(); | |||||
| } | |||||
| private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||||
| { | |||||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); | |||||
| Func<(Tensor, IDictionary<Tensor, string>)> run_save = () => | |||||
| { | |||||
| if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) | |||||
| { | |||||
| var saver = new MultiDeviceSaver(serialized_tensors, registered_savers); | |||||
| var save_op = saver.save(file_prefix, options); | |||||
| // tensorflow python: `with ops.device("/cpu:0"):` | |||||
| using (ops.control_dependencies(new object[] {save_op} )) | |||||
| { | |||||
| _cached_save_operation = array_ops.identity(tf.constant(file_prefix)); | |||||
| } | |||||
| _last_save_object_graph = graph_proto; | |||||
| } | |||||
| return (_cached_save_operation, feed_additions); | |||||
| }; | |||||
| if (options.experimental_enable_async_checkpoint) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| return run_save(); | |||||
| } | |||||
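    /// <summary>
    /// Saves a training checkpoint. Returns the checkpoint prefix as a tensor, or the result of running
    /// the save op in a session when executing in graph mode.
    /// </summary>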
| // TODO: parameter write_done_callback | |||||
| public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, | |||||
| CheckpointOptions? options = null) | |||||
| { | |||||
| if (options is null) | |||||
| { | |||||
| options = new CheckpointOptions(); | |||||
| } | |||||
| Dictionary<Tensor, string> feed_dict = new(); | |||||
| bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); | |||||
| if (checkpoint_number is not null) | |||||
| { | |||||
| file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | |||||
| } | |||||
| Tensor file_prefix_tensor; | |||||
| Tensor object_graph_tensor; | |||||
| if (use_session) | |||||
| { | |||||
| if (_object_graph_feed_tensor is null) | |||||
| { | |||||
| // In python there is `with ops.device("/cpu:0")`. | |||||
| _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); | |||||
| _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING); | |||||
| } | |||||
| object_graph_tensor = _object_graph_feed_tensor; | |||||
| file_prefix_tensor = _file_prefix_feed_tensor; | |||||
| feed_dict[file_prefix_tensor] = file_prefix; | |||||
| } | |||||
| else | |||||
| { | |||||
| // In python there is `with ops.device("/cpu:0")`. | |||||
| file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); | |||||
| object_graph_tensor = null; | |||||
| } | |||||
| var (save_path, new_feed_additions) = | |||||
| save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); | |||||
| if (new_feed_additions is not null) | |||||
| { | |||||
| foreach (var pair in new_feed_additions) | |||||
| { | |||||
| feed_dict.Add(pair.Key, pair.Value); | |||||
| } | |||||
| } | |||||
| if(!use_session) | |||||
| { | |||||
| session = null; | |||||
| } | |||||
| else if (session is null) | |||||
| { | |||||
| session = new Session(); // In python it uses `get_session`. | |||||
| } | |||||
| if (session is not null) | |||||
| { | |||||
| var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); | |||||
| return session.run((Tensor)save_path, s); | |||||
| } | |||||
| else if (use_session) | |||||
| { | |||||
| throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + | |||||
| "in graph mode without a default session. Please use " + | |||||
| "`with tf.Session():` to create a session."); | |||||
| } | |||||
| else | |||||
| { | |||||
| return save_path; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,36 @@ | |||||
| using System; | |||||
| using System.Buffers.Text; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.ApiDef.Types; | |||||
| using static Tensorflow.CostGraphDef.Types; | |||||
| using static Tensorflow.OptimizerOptions.Types; | |||||
| namespace Tensorflow.Checkpoint | |||||
| { | |||||
| /// <summary> | |||||
| /// Saves checkpoints directly from multiple devices. | |||||
    /// Note that this is a low-level utility which stores Tensors in the keys
    /// specified by `SaveableObject`s. Higher-level utilities for object-based
    /// checkpointing are built on top of it.
| /// </summary> | |||||
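    /// <remarks>
    /// A rough usage sketch, assuming `serialized_tensors` and `registered_savers` were produced by
    /// `SaveUtil.serialize_graph_view`:
    /// <code>
    /// var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
    /// var save_op = saver.save(tf.constant("/tmp/ckpt"), new CheckpointOptions());
    /// </code>
    /// </remarks>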
| public class MultiDeviceSaver | |||||
| { | |||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, object>> serialized_tensors, | |||||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||||
| { | |||||
| } | |||||
| public Operation? save(string file_prefix, CheckpointOptions? options= null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
@@ -205,6 +205,16 @@ namespace Tensorflow {
      slotVariables_ = slot;
    }

+   [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+   public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot,
+       pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children)
+   {
+     OnConstruction();
+     slotVariables_ = slot;
+     children_ = children;
+   }
+
    partial void OnConstruction();

    [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
@@ -1,4 +1,10 @@
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow.Functions;
+using Tensorflow.Operations.Activation;
+using static Tensorflow.Binding;
+
 namespace Tensorflow.Train
 {
     public abstract class AutoTrackable : Trackable
     {
| @@ -17,5 +23,48 @@ | |||||
                }
            }
        }
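        /// <summary>
        /// For SavedModel export, collects the function-valued properties of this object in addition to its
        /// checkpoint dependencies; for plain checkpoints it falls back to the base implementation.
        /// </summary>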
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| if(save_type != SaveType.SAVEDMODEL) | |||||
| { | |||||
| return base._trackable_children(save_type, cache); | |||||
| } | |||||
| Dictionary<string, Trackable> functions = new(); | |||||
| // TODO: process of logs. | |||||
| var properties = this.GetType().GetProperties(); | |||||
| foreach ( var property in properties ) | |||||
| { | |||||
| string name = property.Name; | |||||
| object value = property.GetValue(this, null); | |||||
| if(value is Function || value is ConcreteFunction) | |||||
| { | |||||
| functions[name] = (Trackable)value; | |||||
| } | |||||
| } | |||||
| // TODO: process the type `core_types.GenericFunction`. | |||||
| Dictionary<string, Trackable> children = new(); | |||||
| foreach(var pair in CheckpointDependencies) | |||||
| { | |||||
| var name = pair.Name; | |||||
| var child = pair.Refer; | |||||
| if(child is ConcreteFunction) // or Generic function | |||||
| { | |||||
| continue; | |||||
| } | |||||
| if(functions.ContainsKey(name) && functions[name] != child) | |||||
| { | |||||
| throw new ValueError($"Can't save object because it has multiple children with the same " + | |||||
| $"name. Object: {this}, attribute name: {name}, child 1: " + | |||||
| $"{child}, child 2: {functions[name]}"); | |||||
| } | |||||
| children[name] = child; | |||||
| } | |||||
            // Mirror the Python dict-merge semantics: on a duplicate key the function entry wins,
            // instead of letting `ToDictionary` throw on duplicate keys.
            foreach(var pair in functions)
            {
                children[pair.Key] = pair.Value;
            }
            return children;
| } | |||||
    }
}
@@ -28,7 +28,7 @@ namespace Tensorflow
        public string slice_spec => _slice_spec;

        private string _name;
-       public string name => _name;
+       public string name { get => _name; set => _name = value; }

        private TF_DataType _dtype;
        public TF_DataType dtype => _dtype;
@@ -134,35 +134,33 @@ public static partial class SavedModelUtils
        Dictionary<Trackable, Trackable> object_map;
        Dictionary<Tensor, Tensor> tensor_map;
        AssetInfo asset_info;
-       using (var g = exported_graph.as_default())
-       {
-           (object_map, tensor_map, asset_info) = saveable_view.map_resources();
-           // TODO: deal with signatures.
-           if (save_custom_gradients)
-           {
-               // TODO: trace gradient functions.
-           }
-           foreach (var resource_initializer_function in resource_initializers)
-           {
-               // List<Trackable> asset_dependencies = new();
-               // TODO: deal with initializers
-           }
-           // using(ops.control_dependencies(...))
-           var init_op = control_flow_ops.no_op();
-           if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY))
-           {
-               meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name);
-           }
-           else
-           {
-               meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef();
-           }
-           // Lack `CopyFrom` API
-           // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]
-       }
+       exported_graph.as_default();
+       (object_map, tensor_map, asset_info) = saveable_view.map_resources();
+       // TODO: deal with signatures.
+       if (save_custom_gradients)
+       {
+           // TODO: trace gradient functions.
+       }
+       foreach (var resource_initializer_function in resource_initializers)
+       {
+           // List<Trackable> asset_dependencies = new();
+           // TODO: deal with initializers
+       }
+       // using(ops.control_dependencies(...))
+       var init_op = control_flow_ops.no_op();
+       if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY))
+       {
+           meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name);
+       }
+       else
+       {
+           meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef();
+       }
+       // Lack `CopyFrom` API
+       // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]
        foreach (var obj in object_map.Values)
        {
            obj._maybe_initialize_trackable();
@@ -180,11 +178,13 @@ public static partial class SavedModelUtils
        verify_ops(graph_def, namespace_whitelist);

        meta_graph_def.GraphDef = new GraphDef(graph_def);
+       meta_graph_def.MetaInfoDef = new();
        meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING);
        meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION;
        // TODO: add git version.
        meta_graph_def.MetaInfoDef.TensorflowGitVersion = "";
        meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
+       meta_graph_def.MetaInfoDef.StrippedOpList = new();
        meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef));

        meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs);
| @@ -138,5 +138,55 @@ namespace Tensorflow | |||||
            // TODO: implement it.
            return false;
        }
| internal static string convert_to_string(string x) | |||||
| { | |||||
| return tf.compat.as_str(x); | |||||
| } | |||||
| } | |||||
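    /// <summary>
    /// Wraps a trackable that still produces legacy `SaveableObject`s so that it can be saved through the
    /// new `serialize_to_tensors` interface.
    /// </summary>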
    public class SaveableCompatibilityConverter : Trackable
| { | |||||
| private Trackable _obj; | |||||
| private IList<MySaveableObject> _saveables; | |||||
| public SaveableCompatibilityConverter(Trackable obj, IList<MySaveableObject> saveables) | |||||
| { | |||||
            _obj = obj;
            _saveables = saveables;
| } | |||||
| public Trackable Obj => _obj; | |||||
| public IList<MySaveableObject> mySaveables=> _saveables; | |||||
| public override IDictionary<string, object> serialize_to_tensors() | |||||
| { | |||||
| return saveable_objects_to_tensor_dict(_saveables); | |||||
| } | |||||
| /// <summary> | |||||
| /// Converts a list of SaveableObjects to a tensor dictionary. | |||||
| /// </summary> | |||||
| /// <param name="saveables"></param> | |||||
| public static Dictionary<string, object> saveable_objects_to_tensor_dict(IList<MySaveableObject> saveables) | |||||
| { | |||||
| Dictionary<string, object> tensor_dict = new(); | |||||
| foreach (var saveable in saveables) | |||||
| { | |||||
| foreach(var spec in saveable.specs) | |||||
| { | |||||
| var name = saveable_object_util.convert_to_string(spec.name); | |||||
| var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); | |||||
| if (!string.IsNullOrEmpty(slice_spec)) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| tensor_dict[name] = spec.tensor; | |||||
| } | |||||
| } | |||||
| } | |||||
| return tensor_dict; | |||||
| } | |||||
    }
}
@@ -34,18 +34,35 @@ namespace Tensorflow.Train
            public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON";
        }

        protected int _self_update_uid;
-       protected IDictionary<string, Trackable> _unconditional_dependency_names =
-           new Dictionary<string, Trackable>();
+       protected IDictionary<string, Trackable> _unconditional_dependency_names;

-       protected IList<TrackableReference> _unconditional_checkpoint_dependencies = new List<TrackableReference>();
+       protected IList<TrackableReference> _unconditional_checkpoint_dependencies;

        protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
            new Dictionary<string, ResourceVariable>();

+       private static Trackable _none = new Function();
+
+       /// <summary>
+       /// A placeholder for a null key, since C# dictionaries do not allow null keys.
+       /// The `None` instance can be any object that inherits `Trackable`.
+       /// This property is supposed to be used only internally.
+       /// </summary>
+       public static Trackable None
+       {
+           get
+           {
+               return _none;
+           }
+       }
+
        public virtual string ObjectIdentifier
        {
            get => "_generic_user_object";
        }

+       public int UpdateUid { get => _self_update_uid; set => _self_update_uid = value; }
+
+       public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; }
+       public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; }
+
+       public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; }
+
        /// <summary>
        /// Restore-on-create for a variable be saved with this `Checkpointable`.
        /// </summary>
@@ -99,8 +116,9 @@ namespace Tensorflow.Train
        /// </summary>
        public void _maybe_initialize_trackable()
        {
-           // _self_unconditional_checkpoint_dependencies = []
+           if (_unconditional_checkpoint_dependencies is not null)
+           {
+               // Already initialized; do not reset the existing dependencies.
+               return;
+           }
            _self_update_uid = -1;
+           _unconditional_checkpoint_dependencies = new List<TrackableReference>();
+           _unconditional_dependency_names = new Dictionary<string, Trackable>();
        }

        // TODO: cache
| @@ -153,6 +171,20 @@ namespace Tensorflow.Train | |||||
| { | { | ||||
| return _self_saveable_object_factories; | return _self_saveable_object_factories; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Gathers tensors to save to the checkpoint. You should only override `serialize_to_tensors` and `restore_from_tensors` | |||||
| /// if you are defining a custom resource or variable with custom ops. | |||||
| /// Otherwise, please store the state of your trackable in `tf.Variable` objects | |||||
| /// and add them to Trackable object hierarchy using `setattr` (for subclasses | |||||
| /// of `AutoTrackable`) or overriding the `_trackable_children` method. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| /// <exception cref="NotImplementedException"></exception> | |||||
| public virtual IDictionary<string, object> serialize_to_tensors() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| public record class TrackableReference(string Name, Trackable Refer); | public record class TrackableReference(string Name, Trackable Refer); | ||||