| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| @@ -23,6 +25,26 @@ namespace Tensorflow | |||||
| public class CompatApi | public class CompatApi | ||||
| { | { | ||||
| public CompatV1Api v1 { get; } = new CompatV1Api(); | public CompatV1Api v1 { get; } = new CompatV1Api(); | ||||
| internal string as_text(string bytes_or_text, Encoding? encoding = null) | |||||
| { | |||||
| if(encoding is null) encoding = Encoding.UTF8; | |||||
| return bytes_or_text; | |||||
| } | |||||
| internal string as_text(byte[] bytes_or_text, Encoding? encoding = null) | |||||
| { | |||||
| if(encoding is null) encoding = Encoding.UTF8; | |||||
| return encoding.GetString(bytes_or_text); | |||||
| } | |||||
| internal string as_str(string bytes_or_text, Encoding? encoding = null) | |||||
| { | |||||
| return as_text(bytes_or_text, encoding); | |||||
| } | |||||
| internal string as_str(byte[] bytes_or_text, Encoding? encoding = null) | |||||
| { | |||||
| return as_text(bytes_or_text, encoding); | |||||
| } | |||||
| } | } | ||||
| public bool executing_eagerly() | public bool executing_eagerly() | ||||
| @@ -0,0 +1,150 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using pbc = global::Google.Protobuf.Collections; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public static class CheckPointUtils | |||||
| { | |||||
| private static string _ESCAPE_CHAR = "."; | |||||
| public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>, | |||||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||||
| Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | |||||
| { | |||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||||
| Dictionary<Trackable, string> object_names = new(); | |||||
| foreach (var pair in node_paths) | |||||
| { | |||||
| object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); | |||||
| } | |||||
| Dictionary<Trackable, int> node_ids = new(); | |||||
| for (int i = 0; i < trackable_objects.Count; i++) | |||||
| { | |||||
| node_ids[trackable_objects[i]] = i; | |||||
| } | |||||
| var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names); | |||||
| return (trackable_objects, node_paths, node_ids, slot_variables, object_names); | |||||
| } | |||||
| public static | |||||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||||
| serialize_slot_variables(IEnumerable<Trackable> trackable_objects, | |||||
| IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names) | |||||
| { | |||||
| var non_slot_objects = trackable_objects.ToList(); | |||||
| Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||||
| slot_variables = new(); | |||||
| foreach (var trackable in non_slot_objects) | |||||
| { | |||||
| if (trackable is not Optimizer) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| var optim = (Optimizer)trackable; | |||||
| var slot_names = optim.get_slot_names(); | |||||
| foreach (var slot_name in slot_names) | |||||
| { | |||||
| for (int original_variable_node_id = 0; | |||||
| original_variable_node_id < non_slot_objects.Count; | |||||
| original_variable_node_id++) | |||||
| { | |||||
| var original_variable = non_slot_objects[original_variable_node_id]; | |||||
| IVariableV1 slot_variable; | |||||
| if (original_variable is not IVariableV1) | |||||
| { | |||||
| slot_variable = null; | |||||
| } | |||||
| slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name); | |||||
| if(slot_variable is null) continue; | |||||
| // There're some problems about the inherits of `Variable` and `Trackable`. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| return slot_variables; | |||||
| } | |||||
| public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map) | |||||
| { | |||||
| if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res)) | |||||
| { | |||||
| return trackable; | |||||
| } | |||||
| else | |||||
| { | |||||
| return possible_res; | |||||
| } | |||||
| } | |||||
| public static string get_full_name(Trackable var) | |||||
| { | |||||
| // TODO: This state is not correct, the whole framework need to be updated in the future. | |||||
| if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) | |||||
| { | |||||
| return ""; | |||||
| } | |||||
| // skip the check of attribute `_save_slice_info` . | |||||
| // TODO: Need to be revised!!! | |||||
| return ((ResourceVariable)(object)var).Name; | |||||
| } | |||||
| public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) | |||||
| { | |||||
| HashSet<int> checkpointed_trackables = new(); | |||||
| Dictionary<int, HashSet<int>> parents = new(); | |||||
| for (int i = 0; i < object_graph_proto.Nodes.Count; i++) | |||||
| { | |||||
| var object_proto = object_graph_proto.Nodes[i]; | |||||
| // skip the process of registered saver. | |||||
| if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 || | |||||
| object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0) | |||||
| { | |||||
| checkpointed_trackables.Add(i); | |||||
| } | |||||
| foreach (var child_proto in object_proto.Children) | |||||
| { | |||||
| var child = child_proto.NodeId; | |||||
| if (!parents.ContainsKey(child)) | |||||
| { | |||||
| parents[child] = new HashSet<int>(); | |||||
| } | |||||
| parents[child].Add(i); | |||||
| } | |||||
| } | |||||
| Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable()); | |||||
| while (to_visit.Count > 0) | |||||
| { | |||||
| var trackable = to_visit.Dequeue(); | |||||
| if (!parents.ContainsKey(trackable)) continue; | |||||
| var current_parents = parents[trackable]; | |||||
| foreach (var parent in current_parents) | |||||
| { | |||||
| checkpointed_trackables.Add(parent); | |||||
| if (parents.ContainsKey(parent)) | |||||
| { | |||||
| to_visit.Enqueue(parent); | |||||
| } | |||||
| } | |||||
| parents.Remove(trackable); | |||||
| } | |||||
| // TODO: Complete it after supporting checkpoint. | |||||
| // for (int i = 0; i < object_graph_proto.Nodes.Count; i++) | |||||
| // { | |||||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | |||||
| // } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,5 @@ | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public record class CheckpointOptions( | |||||
| string experimental_io_device = null, | |||||
| bool experimental_enable_async_checkpoint = false); | |||||
| @@ -0,0 +1,63 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Serilog.Debugging; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public class ObjectGraphView: TrackableView, ICloneable | |||||
| { | |||||
| protected IEnumerable<TrackableReference>? _attached_dependencies; | |||||
| // TODO: attached_dependencies | |||||
| public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root) | |||||
| { | |||||
| _attached_dependencies = attached_dependencies; | |||||
| } | |||||
| public object Clone() | |||||
| { | |||||
| // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__ | |||||
| return new ObjectGraphView(Root, _attached_dependencies); | |||||
| } | |||||
| public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
| { | |||||
| List<TrackableReference> res = base.children(obj, save_type) | |||||
| .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | |||||
| // Check the reference, not value. | |||||
| if (obj == Root && _attached_dependencies is not null) | |||||
| { | |||||
| res.AddRange(_attached_dependencies); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
| { | |||||
| return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); | |||||
| } | |||||
| public IEnumerable<TrackableReference>? AttachedDependencies | |||||
| { | |||||
| get => _attached_dependencies; | |||||
| } | |||||
| public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||||
| { | |||||
| return base._descendants_with_paths(); | |||||
| } | |||||
| // TODO: complete the implementation | |||||
| public void serialize_object_graph(object? saveables_cache = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // TODO: complete the implementation | |||||
| public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,229 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using Tensorflow.Exceptions; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using pbc = global::Google.Protobuf.Collections; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public static class SaveUtilV1 | |||||
| { | |||||
| public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names, | |||||
| IDictionary<Trackable, Trackable>? object_map = null) | |||||
| { | |||||
| // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, | |||||
| // till now only internal registrations are allowed. So, we won't return a saver in this function. | |||||
| // The implementation of this function should be updated if tensorflow update it. | |||||
| Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new(); | |||||
| foreach (var pair in object_names) | |||||
| { | |||||
| var trackable = pair.Key; | |||||
| var object_name = pair.Value; | |||||
| var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); | |||||
| // skip the registration process. | |||||
| List<CheckpointFactoryData> current_list = new(); | |||||
| foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) | |||||
| { | |||||
| // treat name as key_suffix. | |||||
| var name = name_and_factory.Key; | |||||
| var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); | |||||
| current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); | |||||
| } | |||||
| checkpoint_factory_map[trackable] = current_list; | |||||
| } | |||||
| return (checkpoint_factory_map, null); | |||||
| } | |||||
| public static (List<MySaveableObject>, object?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||||
| IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, | |||||
| object? saveables_cache = null) | |||||
| { | |||||
| Graph target_context; | |||||
| if (to_graph is not null) | |||||
| { | |||||
| using (to_graph.as_default()) | |||||
| { | |||||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||||
| object_map, call_with_mapped_captures, saveables_cache); | |||||
| // tensorflow python: `with ops.device("/cpu:0")` | |||||
| var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||||
| return (named_saveable_objects, registered_savers); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| using (new ops.NullContextManager()) | |||||
| { | |||||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||||
| object_map, call_with_mapped_captures, saveables_cache); | |||||
| // tensorflow python: `with ops.device("/cpu:0")` | |||||
| var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||||
| return (named_saveable_objects, registered_savers); | |||||
| } | |||||
| } | |||||
| } | |||||
| public static (List<MySaveableObject>, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, | |||||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||||
| { | |||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||||
| Dictionary<Trackable, string> object_names = new(); | |||||
| foreach (var pair in node_paths) | |||||
| { | |||||
| object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); | |||||
| } | |||||
| Dictionary<Trackable, int> node_ids = new(); | |||||
| for (int i = 0; i < trackable_objects.Count; i++) | |||||
| { | |||||
| node_ids[trackable_objects[i]] = i; | |||||
| } | |||||
| var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); | |||||
| var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); | |||||
| var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( | |||||
| trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, | |||||
| saveables_cache); | |||||
| CheckPointUtils.add_checkpoint_values_check(object_graph_proto); | |||||
| return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); | |||||
| } | |||||
| private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects, | |||||
| IDictionary<Trackable, int> node_ids, | |||||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||||
| slot_variables) | |||||
| { | |||||
| TrackableObjectGraph object_graph_proto = new(); | |||||
| for (int i = 0; i < trackable_objects.Count; i++) | |||||
| { | |||||
| var trackable = trackable_objects[i]; | |||||
| Debug.Assert(node_ids[trackable] == i); | |||||
| TrackableObjectGraph.Types.TrackableObject object_proto; | |||||
| if (slot_variables.TryGetValue(trackable, out var slots)) | |||||
| { | |||||
| object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); | |||||
| } | |||||
| else | |||||
| { | |||||
| object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||||
| } | |||||
| object_graph_proto.Nodes.Add(object_proto); | |||||
| foreach (var child in graph_view.list_children(trackable)) | |||||
| { | |||||
| object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() | |||||
| { NodeId = node_ids[child.Refer], LocalName = child.Name }); | |||||
| } | |||||
| } | |||||
| return object_graph_proto; | |||||
| } | |||||
| private static (List<MySaveableObject>, object?, object?) add_attributes_to_object_graph(IList<Trackable> trackable_objects, | |||||
| TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, | |||||
| IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, | |||||
| bool call_with_mapped_captures, object? saveables_cache = null) | |||||
| { | |||||
| int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); | |||||
| for (int i = 0; i < cnt; i++) | |||||
| { | |||||
| Debug.Assert(node_ids[trackable_objects[i]] == i); | |||||
| } | |||||
| var (checkpoint_factory_map, unmmaped_registered_savers) = | |||||
| get_checkpoint_factories_and_keys(object_names, object_map); | |||||
| // skip the process of registered savers | |||||
| var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, | |||||
| object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); | |||||
| return (named_saveable_objects, feed_additions, null); | |||||
| } | |||||
| public static (List<MySaveableObject>, object?) generate_saveable_objects( | |||||
| IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map, | |||||
| TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids, | |||||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||||
| { | |||||
| List<MySaveableObject> named_saveable_objects = new(); | |||||
| foreach (var pair in checkpoint_factory_map) | |||||
| { | |||||
| var trackable = pair.Key; | |||||
| var factory_data_list = pair.Value; | |||||
| bool fill_object_proto = object_graph_proto is not null && node_ids is not null; | |||||
| TrackableObjectGraph.Types.TrackableObject object_proto = null!; | |||||
| if (fill_object_proto) | |||||
| { | |||||
| object_proto = object_graph_proto.Nodes[node_ids[trackable]]; | |||||
| } | |||||
| var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); | |||||
| // skip cache | |||||
| foreach (var factory_data in factory_data_list) | |||||
| { | |||||
| var name = factory_data.name; | |||||
| var key = factory_data.checkpoint_key; | |||||
| var saveable_factory = factory_data.factory; | |||||
| // TODO: oneflow python has a process with callable `saveable_factory`. | |||||
| var maybe_saveable = saveable_factory; | |||||
| IEnumerable<MySaveableObject> savesbles; | |||||
| if (maybe_saveable is MySaveableObject) | |||||
| { | |||||
| savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable }; | |||||
| } | |||||
| else if (maybe_saveable is Tensor) | |||||
| { | |||||
| savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new TypeError("Unexpected type."); | |||||
| } | |||||
| foreach (var saveable in savesbles) | |||||
| { | |||||
| if (!saveable.name.Contains(key)) | |||||
| { | |||||
| throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + | |||||
| $"'{saveable.name}' for attribute '{name}'. Expected a name" + | |||||
| $" containing '{key}'."); | |||||
| } | |||||
| } | |||||
| // skip the process of PythonState | |||||
| named_saveable_objects.AddRange(savesbles); | |||||
| if(!fill_object_proto) continue; | |||||
| // skip the process of TrackableSaveable | |||||
| object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() | |||||
| { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); | |||||
| } | |||||
| } | |||||
| return (named_saveable_objects, null); | |||||
| } | |||||
| } | |||||
| public record class CheckpointFactoryData | |||||
| ( | |||||
| object factory, | |||||
| string name, | |||||
| string checkpoint_key | |||||
| ); | |||||
| @@ -0,0 +1,109 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public class TrackableSaver | |||||
| { | |||||
| private ObjectGraphView _graph_view; | |||||
| private EagerTensor _cached_save_operation; | |||||
| private TrackableObjectGraph _last_save_object_graph; | |||||
| private Tensor? _object_graph_feed_tensor = null; | |||||
| private Tensor? _file_prefix_feed_tensor = null; | |||||
| public TrackableSaver(ObjectGraphView graph_view) | |||||
| { | |||||
| _graph_view = graph_view; | |||||
| // TODO: cache when not executing eagerly. | |||||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||||
| } | |||||
| private void gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| private (EagerTensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // TODO: parameter write_done_callback | |||||
| public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, | |||||
| CheckpointOptions? options = null) | |||||
| { | |||||
| if (options is null) | |||||
| { | |||||
| options = new CheckpointOptions(); | |||||
| } | |||||
| Dictionary<Tensor, string> feed_dict = new(); | |||||
| bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); | |||||
| if (checkpoint_number is not null) | |||||
| { | |||||
| file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | |||||
| } | |||||
| Tensor file_prefix_tensor; | |||||
| Tensor object_graph_tensor; | |||||
| if (use_session) | |||||
| { | |||||
| if (_object_graph_feed_tensor is null) | |||||
| { | |||||
| // In python there is `with ops.device("/cpu:0")`. | |||||
| _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); | |||||
| _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); | |||||
| } | |||||
| object_graph_tensor = _object_graph_feed_tensor; | |||||
| file_prefix_tensor = _file_prefix_feed_tensor; | |||||
| feed_dict[file_prefix_tensor] = file_prefix; | |||||
| } | |||||
| else | |||||
| { | |||||
| // In python there is `with ops.device("/cpu:0")`. | |||||
| file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); | |||||
| object_graph_tensor = null; | |||||
| } | |||||
| var (save_path, new_feed_additions) = | |||||
| save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); | |||||
| if (new_feed_additions is not null) | |||||
| { | |||||
| foreach (var pair in new_feed_additions) | |||||
| { | |||||
| feed_dict.Add(pair.Key, pair.Value); | |||||
| } | |||||
| } | |||||
| if(!use_session) | |||||
| { | |||||
| session = null; | |||||
| } | |||||
| else if (session is null) | |||||
| { | |||||
| session = new Session(); // In python it uses `get_session`. | |||||
| } | |||||
| if (session is not null) | |||||
| { | |||||
| var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); | |||||
| return session.run((Tensor)save_path, s); | |||||
| } | |||||
| else if (use_session) | |||||
| { | |||||
| throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + | |||||
| "in graph mode without a default session. Please use " + | |||||
| "`with tf.Session():` to create a session."); | |||||
| } | |||||
| else | |||||
| { | |||||
| return save_path; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,75 @@ | |||||
| using System; | |||||
| using Tensorflow.Train; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| public class TrackableView | |||||
| { | |||||
| protected WeakReference<Trackable> _root_ref; | |||||
| public TrackableView(Trackable obj) | |||||
| { | |||||
| _root_ref = new WeakReference<Trackable>(obj); | |||||
| } | |||||
| public TrackableView(WeakReference<Trackable> obj) | |||||
| { | |||||
| _root_ref = obj; | |||||
| } | |||||
| public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
| { | |||||
| obj._maybe_initialize_trackable(); | |||||
| // Note: in python the return type of `Trackable._trackable_children` is not fixed. | |||||
| // Therefore it uses `convert_to_trackable` to have an extra process. | |||||
| return obj._trackable_children(save_type); | |||||
| } | |||||
| public Trackable Root | |||||
| { | |||||
| get | |||||
| { | |||||
| if (_root_ref.TryGetTarget(out Trackable res)) | |||||
| { | |||||
| return res; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new InvalidDataException( | |||||
| "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor."); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. | |||||
| /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths | |||||
| /// </summary> | |||||
| protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths() | |||||
| { | |||||
| List<Trackable> bfs_sorted = new(); | |||||
| Queue<Trackable> to_visit = new(); | |||||
| Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new(); | |||||
| node_paths[this.Root] = new List<TrackableReference>(); | |||||
| while (!to_visit.empty()) | |||||
| { | |||||
| var current_trackable = to_visit.Dequeue(); | |||||
| bfs_sorted.Add(current_trackable); | |||||
| var children_dict = this.children(current_trackable); | |||||
| foreach (var name in children_dict.Keys) | |||||
| { | |||||
| var dependency = children_dict[name]; | |||||
| if (!node_paths.ContainsKey(dependency)) | |||||
| { | |||||
| var list = new List<TrackableReference>(node_paths[current_trackable]); | |||||
| list.Add(new TrackableReference(name, dependency)); | |||||
| node_paths[dependency] = list; | |||||
| to_visit.Enqueue(dependency); | |||||
| } | |||||
| } | |||||
| } | |||||
| return (bfs_sorted, node_paths); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| namespace Tensorflow.Exceptions; | |||||
| public class AssertionError : TensorflowException | |||||
| { | |||||
| public AssertionError() : base() | |||||
| { | |||||
| } | |||||
| public AssertionError(string message) : base(message) | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -304,7 +304,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
| public static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||||
| { | { | ||||
| var used_ops = ops_used_by_graph_def(graph_def); | var used_ops = ops_used_by_graph_def(graph_def); | ||||
| @@ -345,5 +345,66 @@ namespace Tensorflow | |||||
| return used_ops.ToArray(); | return used_ops.ToArray(); | ||||
| } | } | ||||
| private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) | |||||
| { | |||||
| foreach (var attr_def in op_def.Attr) | |||||
| { | |||||
| if (attr_def.Name == attr_name) | |||||
| { | |||||
| if (attr_def.DefaultValue is null) return false; | |||||
| // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) | |||||
| { | |||||
| Dictionary<string, FunctionDef> op_name_to_function = new(); | |||||
| foreach (var function_def in meta_graph_def.GraphDef.Library.Function) | |||||
| { | |||||
| op_name_to_function[function_def.Signature.Name] = function_def; | |||||
| } | |||||
| Action<NodeDef> _strip_node_default_valued_attrs = (node_def) => | |||||
| { | |||||
| if (op_name_to_function.ContainsKey(node_def.Op)) return; | |||||
| var op_def = op_def_registry.GetOpDef(node_def.Op); | |||||
| if(op_def is null) return; | |||||
| HashSet<string> attrs_to_strip = new(); | |||||
| foreach (var attr in node_def.Attr) | |||||
| { | |||||
| if (is_default_attr_value(op_def, attr.Key, attr.Value)) | |||||
| { | |||||
| attrs_to_strip.Add(attr.Key); | |||||
| } | |||||
| } | |||||
| foreach (var attr in attrs_to_strip) | |||||
| { | |||||
| node_def.Attr.Remove(attr); | |||||
| } | |||||
| }; | |||||
| foreach (var node_def in meta_graph_def.GraphDef.Node) | |||||
| { | |||||
| _strip_node_default_valued_attrs(node_def); | |||||
| } | |||||
| foreach (var function_def in meta_graph_def.GraphDef.Library.Function) | |||||
| { | |||||
| foreach (var function_node_def in function_def.NodeDef) | |||||
| { | |||||
| _strip_node_default_valued_attrs(function_node_def); | |||||
| } | |||||
| } | |||||
| meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Functions | namespace Tensorflow.Functions | ||||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Functions | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| public class ConcreteFunction | |||||
| public class ConcreteFunction: Trackable | |||||
| { | { | ||||
| FuncGraph func_graph; | FuncGraph func_graph; | ||||
| ForwardBackwardCall forward_backward; | ForwardBackwardCall forward_backward; | ||||
| @@ -1,16 +1,23 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class Function | |||||
| public class Function: Trackable | |||||
| { | { | ||||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | #pragma warning disable CS0169 // The field 'Function._handle' is never used | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | #pragma warning restore CS0169 // The field 'Function._handle' is never used | ||||
| public string Name { get; set; } | |||||
| public Function() | public Function() | ||||
| { | { | ||||
| } | } | ||||
| public Function(string name) | |||||
| { | |||||
| Name = name; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -9,7 +9,13 @@ namespace Tensorflow.ModelSaving | |||||
| /// </summary> | /// </summary> | ||||
| public class SaveOptions | public class SaveOptions | ||||
| { | { | ||||
| bool save_debug_info; | |||||
| public bool save_debug_info = false; | |||||
| public IList<string>? namespace_white_list { get; set; } = null; | |||||
| public IDictionary<string, object>? function_aliases { get; set; } = null; | |||||
| public string? experimental_io_device { get; set; } = null; | |||||
| // TODO: experimental | |||||
| public Object? experimental_variable_polict { get; set; } = null; | |||||
| public bool experimental_custom_gradients { get; set; } = true; | |||||
| public SaveOptions(bool save_debug_info = false) | public SaveOptions(bool save_debug_info = false) | ||||
| { | { | ||||
| this.save_debug_info = save_debug_info; | this.save_debug_info = save_debug_info; | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -38,6 +39,11 @@ namespace Tensorflow | |||||
| { | { | ||||
| return var is ResourceVariable; | return var is ResourceVariable; | ||||
| } | } | ||||
| public static bool is_resource_variable(Trackable var) | |||||
| { | |||||
| return var is BaseResourceVariable; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a variable handle with information to do shape inference. | /// Creates a variable handle with information to do shape inference. | ||||
| @@ -156,7 +156,7 @@ namespace Tensorflow { | |||||
| /// Nodes[0] is considered the root node. | /// Nodes[0] is considered the root node. | ||||
| /// </summary> | /// </summary> | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||||
| public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||||
| get { return nodes_; } | get { return nodes_; } | ||||
| } | } | ||||
| @@ -286,6 +286,7 @@ namespace Tensorflow { | |||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| public SavedObject(SavedObject other) : this() { | public SavedObject(SavedObject other) : this() { | ||||
| children_ = other.children_.Clone(); | children_ = other.children_.Clone(); | ||||
| dependencies_ = other.dependencies_.Clone(); | |||||
| slotVariables_ = other.slotVariables_.Clone(); | slotVariables_ = other.slotVariables_.Clone(); | ||||
| saveableObjects_ = other.saveableObjects_.Clone(); | saveableObjects_ = other.saveableObjects_.Clone(); | ||||
| switch (other.KindCase) { | switch (other.KindCase) { | ||||
| @@ -328,6 +329,7 @@ namespace Tensorflow { | |||||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | ||||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | ||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | ||||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Objects which this object depends on: named edges in the dependency | /// Objects which this object depends on: named edges in the dependency | ||||
| /// graph. | /// graph. | ||||
| @@ -338,6 +340,11 @@ namespace Tensorflow { | |||||
| public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | ||||
| get { return children_; } | get { return children_; } | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies { | |||||
| get { return dependencies_; } | |||||
| } | |||||
| /// <summary>Field number for the "slot_variables" field.</summary> | /// <summary>Field number for the "slot_variables" field.</summary> | ||||
| public const int SlotVariablesFieldNumber = 3; | public const int SlotVariablesFieldNumber = 3; | ||||
| @@ -617,6 +624,7 @@ namespace Tensorflow { | |||||
| return; | return; | ||||
| } | } | ||||
| children_.Add(other.children_); | children_.Add(other.children_); | ||||
| dependencies_.Add(other.dependencies_); | |||||
| slotVariables_.Add(other.slotVariables_); | slotVariables_.Add(other.slotVariables_); | ||||
| saveableObjects_.Add(other.saveableObjects_); | saveableObjects_.Add(other.saveableObjects_); | ||||
| switch (other.KindCase) { | switch (other.KindCase) { | ||||
| @@ -198,6 +198,12 @@ namespace Tensorflow { | |||||
| public TrackableObject() { | public TrackableObject() { | ||||
| OnConstruction(); | OnConstruction(); | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) { | |||||
| OnConstruction(); | |||||
| slotVariables_ = slot; | |||||
| } | |||||
| partial void OnConstruction(); | partial void OnConstruction(); | ||||
| @@ -2,5 +2,20 @@ | |||||
| { | { | ||||
| public abstract class AutoTrackable : Trackable | public abstract class AutoTrackable : Trackable | ||||
| { | { | ||||
| public void _delete_tracking(string name) | |||||
| { | |||||
| _maybe_initialize_trackable(); | |||||
| if (_unconditional_dependency_names.ContainsKey(name)) | |||||
| { | |||||
| _unconditional_dependency_names.Remove(name); | |||||
| for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) | |||||
| { | |||||
| if (_unconditional_checkpoint_dependencies[i].Name == name) | |||||
| { | |||||
| _unconditional_checkpoint_dependencies.RemoveAt(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -351,7 +351,7 @@ namespace Tensorflow | |||||
| /// <param name="var"></param> | /// <param name="var"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected IVariableV1 get_slot(IVariableV1 var, string name) | |||||
| internal IVariableV1 get_slot(IVariableV1 var, string name) | |||||
| { | { | ||||
| var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | ||||
| if (named_slots == null) | if (named_slots == null) | ||||
| @@ -360,6 +360,11 @@ namespace Tensorflow | |||||
| return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | ||||
| } | } | ||||
| internal IEnumerable<string> get_slot_names() | |||||
| { | |||||
| return _slots.Keys; | |||||
| } | |||||
| private string _var_key(IVariableV1 var) | private string _var_key(IVariableV1 var) | ||||
| { | { | ||||
| return $"{var.Op.graph.graph_key}.{var.Op.name}"; | return $"{var.Op.graph.graph_key}.{var.Op.name}"; | ||||
| @@ -48,4 +48,18 @@ namespace Tensorflow | |||||
| validate_shape: restored_shapes == null && op.shape.IsFullyDefined); | validate_shape: restored_shapes == null && op.shape.IsFullyDefined); | ||||
| } | } | ||||
| } | } | ||||
| public class NoRestoreSaveable: MySaveableObject | |||||
| { | |||||
| public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor, | |||||
| new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) | |||||
| { | |||||
| } | |||||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||||
| { | |||||
| return control_flow_ops.no_op(); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| @@ -0,0 +1,11 @@ | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow; | |||||
| public record class AssetInfo | |||||
| ( | |||||
| List<AssetFileDef> asset_defs, | |||||
| Dictionary<object, object> asset_initializers_by_resource, | |||||
| Dictionary<AssetInfo, string> asset_filename_map, | |||||
| Dictionary<object, object> asset_index | |||||
| ); | |||||
| @@ -0,0 +1,60 @@ | |||||
| using System; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Train; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Functions; | |||||
| namespace Tensorflow; | |||||
| public class AugmentedGraphView: ObjectGraphView | |||||
| { | |||||
| // private object _children_cache; | |||||
| // private object _serialization_cache; | |||||
| private List<string> _untraces_functions; | |||||
| public AugmentedGraphView(Trackable root): base(root) | |||||
| { | |||||
| _untraces_functions = new(); | |||||
| } | |||||
| public void set_signature(object signature_map, object wrapped_functions) | |||||
| { | |||||
| // TODO: cache | |||||
| list_children(Root); | |||||
| } | |||||
| public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
| { | |||||
| Dictionary<string, Trackable> children = new(); | |||||
| foreach (var pair in base.list_children(obj, save_type)) | |||||
| { | |||||
| var name = pair.Name; | |||||
| var child = pair.Refer; | |||||
| children[name] = child; | |||||
| } | |||||
| if (obj is Function && children.Count == 0) | |||||
| { | |||||
| _untraces_functions.Add(((Function)obj).Name); | |||||
| } | |||||
| return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | |||||
| } | |||||
| public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||||
| { | |||||
| // TODO: implement it if needed. | |||||
| return base.breadth_first_traversal(); | |||||
| } | |||||
| public List<(string, Trackable)> list_dependencies(Trackable obj) | |||||
| { | |||||
| // TODO: deal with cache. | |||||
| return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); | |||||
| } | |||||
| public Trackable get_child(Trackable obj, string name) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,33 @@ | |||||
| namespace Tensorflow; | |||||
| public static class Constants | |||||
| { | |||||
| public static readonly string ASSETS_DIRECTORY = "assets"; | |||||
| public static readonly string ASSETS_KEY = "saved_model_assets"; | |||||
| public static readonly string DEBUG_DIRECTORY = "debug"; | |||||
| public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"; | |||||
| public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra"; | |||||
| public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb"; | |||||
| public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; | |||||
| public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op"; | |||||
| public static readonly string MAIN_OP_KEY = "saved_model_main_op"; | |||||
| public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb"; | |||||
| public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"; | |||||
| public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; | |||||
| public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; | |||||
| public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"; | |||||
| public static readonly string VARIABLES_DIRECTORY = "variables"; | |||||
| public static readonly string VARIABLES_FILENAME = "variables"; | |||||
| } | |||||
| @@ -0,0 +1,17 @@ | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow; | |||||
| public class RevivedTypes | |||||
| { | |||||
| /// <summary> | |||||
| /// Create a SavedUserObject from a trackable object. | |||||
| /// </summary> | |||||
| /// <param name="obj"></param> | |||||
| /// <returns></returns> | |||||
| public static SavedUserObject? serialize(Trackable obj) | |||||
| { | |||||
| // TODO: complete the implementation. | |||||
| return null; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,9 @@ | |||||
| using System; | |||||
| namespace Tensorflow; | |||||
| public enum SaveType | |||||
| { | |||||
| SAVEDMODEL, | |||||
| CHECKPOINT | |||||
| } | |||||
| @@ -0,0 +1,299 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using pbc = global::Google.Protobuf.Collections; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow; | |||||
| public class SaveableView | |||||
| { | |||||
| private AugmentedGraphView _augmented_graph_view; | |||||
| private SaveOptions _options; | |||||
| private List<Trackable> _trackable_objects; | |||||
| private List<Trackable> _nodes; | |||||
| private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths; | |||||
| private Dictionary<Trackable, int> _node_ids; | |||||
| private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||||
| _slot_variables; | |||||
| private Dictionary<Trackable, string> _object_names; | |||||
| private List<object> _gradient_functions; // to be completed | |||||
| private List<RegisteredGradient> _gradient_defs; // to be completed | |||||
| private List<ConcreteFunction> _concrete_functions; | |||||
| private Dictionary<Tensor, int> _captured_tensor_node_ids; | |||||
| private Dictionary<Trackable, IDictionary<string, ConcreteFunction>> _saveable_objects_map; | |||||
| private Dictionary<Trackable, string> _obj_to_registered_saver; | |||||
| public AugmentedGraphView AugmentedGraphView | |||||
| { | |||||
| get => _augmented_graph_view; | |||||
| } | |||||
| public Trackable Root | |||||
| { | |||||
| get => _nodes[0]; | |||||
| } | |||||
| public List<Trackable> Nodes | |||||
| { | |||||
| get => _nodes; | |||||
| } | |||||
| public Dictionary<Trackable, int> NodeIds | |||||
| { | |||||
| get => _node_ids; | |||||
| } | |||||
| public List<RegisteredGradient> GradientDefs | |||||
| { | |||||
| get => _gradient_defs; | |||||
| } | |||||
| public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths | |||||
| { | |||||
| get => _node_paths; | |||||
| } | |||||
| public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) | |||||
| { | |||||
| _augmented_graph_view = augmented_graph_view; | |||||
| _options = options; | |||||
| (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = | |||||
| CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); | |||||
| // TODO: deal with untraced functions. | |||||
| initialize_save_and_restore_functions(); | |||||
| initialize_nodes_and_concrete_functions(); | |||||
| _captured_tensor_node_ids = new(); | |||||
| } | |||||
| private void initialize_save_and_restore_functions() | |||||
| { | |||||
| // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. | |||||
| SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); | |||||
| // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. | |||||
| _obj_to_registered_saver = new(); | |||||
| _saveable_objects_map = new(); | |||||
| } | |||||
| private void initialize_nodes_and_concrete_functions() | |||||
| { | |||||
| _nodes = _trackable_objects.ConvertAll(x => x); // deep copy | |||||
| _gradient_functions = new(); | |||||
| _gradient_defs = new(); | |||||
| // TODO: deal with the condition that obj in `_saveable_objects_map`. | |||||
| // foreach (var obj in _nodes) | |||||
| // { | |||||
| // | |||||
| // } | |||||
| foreach (var obj in _nodes) | |||||
| { | |||||
| if (obj is ConcreteFunction) | |||||
| { | |||||
| _concrete_functions.Add((ConcreteFunction)obj); | |||||
| } | |||||
| } | |||||
| } | |||||
| public List<ConcreteFunction> get_concrete_resource_initializers() | |||||
| { | |||||
| // TODO: complete the implementation. | |||||
| return new List<ConcreteFunction>(); | |||||
| } | |||||
| public (Dictionary<Trackable, Trackable>, Dictionary<Tensor, Tensor>, AssetInfo) map_resources() | |||||
| { | |||||
| Debug.Assert(!tf.Context.executing_eagerly()); | |||||
| Dictionary<Trackable, Trackable> object_map = new(); | |||||
| Dictionary<Tensor, Tensor> tensor_map = new(); | |||||
| AssetInfo assetInfo = new(new List<AssetFileDef>(), new Dictionary<object, object>(), | |||||
| new Dictionary<AssetInfo, string>(), new Dictionary<object, object>()); | |||||
| foreach (var node_id in dependency_sorted_node_ids()) | |||||
| { | |||||
| var obj = _nodes[node_id]; | |||||
| var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options); | |||||
| // TODO: deal with Asset (if obj is Asset) | |||||
| foreach (var tensor in tensors) | |||||
| { | |||||
| _captured_tensor_node_ids[tensor] = node_id; | |||||
| } | |||||
| } | |||||
| return (object_map, tensor_map, assetInfo); | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns topologically sorted nodes, sorted by dependencies. | |||||
| /// </summary> | |||||
| public List<int> dependency_sorted_node_ids() | |||||
| { | |||||
| Dictionary<int, IEnumerable<int>> dependency_map = new(); | |||||
| foreach (var node in _nodes) | |||||
| { | |||||
| var node_id = _node_ids[node]; | |||||
| List<int> deps = new(); | |||||
| // TODO: deal with captured tensor. | |||||
| string node_path; | |||||
| foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) | |||||
| { | |||||
| if (!_node_ids.ContainsKey(dep)) | |||||
| { | |||||
| node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); | |||||
| throw new ValueError( | |||||
| $"Found an untracked dependency. Object {node_path} depends on {dep}, " + | |||||
| $"but this dependency isn't listed as a child. Please track this child by " + | |||||
| $"overriding `_trackable_children` or use `._track_trackable`."); | |||||
| } | |||||
| deps.Add(_node_ids[dep]); | |||||
| } | |||||
| } | |||||
| try | |||||
| { | |||||
| return TrackableUtils.order_by_dependency(dependency_map); | |||||
| } | |||||
| catch (TrackableUtils.CyclicDependencyError err) | |||||
| { | |||||
| List<string> pretty_printed_nodes = new(); | |||||
| List<string> pretty_printed_dependencies = new(); | |||||
| foreach (var pair in err.LeftOverDependencyMap) | |||||
| { | |||||
| var x = pair.Key; | |||||
| var deps = pair.Value; | |||||
| var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); | |||||
| pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); | |||||
| pretty_printed_dependencies.Add( | |||||
| $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); | |||||
| } | |||||
| throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + | |||||
| $"Saving cannot continue until this cycle is resolved." + | |||||
| $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + | |||||
| $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph | |||||
| /// </summary> | |||||
| /// <param name="asset_index"></param> | |||||
| /// <returns></returns> | |||||
| public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index, SaveOptions options) | |||||
| { | |||||
| SavedObjectGraph proto = new(); | |||||
| fill_object_graph_proto(proto); | |||||
| // TODO: complete the process of concrete functions. | |||||
| int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); | |||||
| for (int i = 0; i < cnt; i++) | |||||
| { | |||||
| var obj = _nodes[i]; | |||||
| var obj_proto = proto.Nodes[i]; | |||||
| write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), | |||||
| options); | |||||
| } | |||||
| return proto; | |||||
| } | |||||
| private static void write_object_proto(Trackable obj, SavedObject proto, | |||||
| IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn, SaveOptions options) | |||||
| { | |||||
| // skip the process of type Asset | |||||
| if (resource_variable_ops.is_resource_variable(obj)) | |||||
| { | |||||
| // TODO: complete it. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else if (obj is Function) | |||||
| { | |||||
| // TODO: complete it. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else if (obj is ConcreteFunction) | |||||
| { | |||||
| // TODO: complete it. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // skip the process of type `_CapturedTensor` and `CapturableResource`. | |||||
| else | |||||
| { | |||||
| var registered_type_proto = RevivedTypes.serialize(obj); | |||||
| if (registered_type_proto is null) | |||||
| { | |||||
| registered_type_proto = new SavedUserObject() | |||||
| { | |||||
| Identifier = obj.ObjectIdentifier, | |||||
| Version = new VersionDef() | |||||
| { | |||||
| Producer = 1, | |||||
| MinConsumer = 1, | |||||
| BadConsumers = { } | |||||
| } | |||||
| }; | |||||
| } | |||||
| proto.UserObject = new SavedUserObject(registered_type_proto); | |||||
| } | |||||
| // TODO: try get the registered_name from `registration`. | |||||
| } | |||||
| public void fill_object_graph_proto(SavedObjectGraph proto) | |||||
| { | |||||
| for (int node_id = 0; node_id < _nodes.Count; node_id++) | |||||
| { | |||||
| var node = _nodes[node_id]; | |||||
| Debug.Assert(_node_ids[node] == node_id); | |||||
| SavedObject object_proto = new(); | |||||
| if (_slot_variables.TryGetValue(node, out var value)) | |||||
| { | |||||
| object_proto.SlotVariables.AddRange(value); | |||||
| } | |||||
| // skip the check of type `_CapturedTensor` | |||||
| foreach (var child in _augmented_graph_view.list_children(node)) | |||||
| { | |||||
| var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); | |||||
| child_proto.NodeId = _node_ids[child.Refer]; | |||||
| child_proto.LocalName = child.Name; | |||||
| object_proto.Children.Add(child_proto); | |||||
| } | |||||
| foreach (var pair in _augmented_graph_view.list_dependencies(node)) | |||||
| { | |||||
| var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); | |||||
| child_proto.NodeId = _node_ids[pair.Item2]; | |||||
| child_proto.LocalName = pair.Item1; | |||||
| object_proto.Dependencies.Add(child_proto); | |||||
| } | |||||
| if (_saveable_objects_map.ContainsKey(node)) | |||||
| { | |||||
| // TODO: complete it. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else if(_obj_to_registered_saver.ContainsKey(node)) | |||||
| { | |||||
| // TODO: complete it. | |||||
| // We now skip it for the lack of `SavedObject.registered_saver` API. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| proto.Nodes.Add(object_proto); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,10 @@ | |||||
| namespace Tensorflow; | |||||
| public static class TagConstants | |||||
| { | |||||
| public static readonly string SERVING = "serve"; | |||||
| public static readonly string TRAINING = "train"; | |||||
| public static readonly string EVAL = "eval"; | |||||
| public static readonly string GPU = "gpu"; | |||||
| public static readonly string TPU = "tpu"; | |||||
| } | |||||
| @@ -0,0 +1,22 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow; | |||||
| public class BuilderUtils | |||||
| { | |||||
| public static void copy_assets_to_destination_dir(IDictionary<AssetInfo, string> asset_filename_map, | |||||
| string destination_dir, HashSet<string>? saved_files = null) | |||||
| { | |||||
| if (saved_files is null) saved_files = new HashSet<string>(); | |||||
| var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); | |||||
| // TODO: complete the implementation of this function. | |||||
| if (asset_filename_map is not null && asset_filename_map.Count > 0) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,256 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Google.Protobuf; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Exceptions; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow; | |||||
| public static partial class SavedModelUtils | |||||
| { | |||||
| private static readonly IEnumerable<int> byte_swappable = new List<TF_DataType>() | |||||
| { | |||||
| dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, | |||||
| dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, | |||||
| dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, | |||||
| TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 | |||||
| }.Select(x => (int)x); | |||||
| public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj, | |||||
| string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) | |||||
| { | |||||
| if (options is null) | |||||
| { | |||||
| options = new SaveOptions(); | |||||
| } | |||||
| var saved_model = new Tensorflow.SavedModel(); | |||||
| var meta_graph_def = new MetaGraphDef(); | |||||
| saved_model.MetaGraphs.Add(meta_graph_def); | |||||
| var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = | |||||
| _build_meta_graph(obj, signatures, options, meta_graph_def); | |||||
| saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; | |||||
| if (!experimental_skip_checkpoint) | |||||
| { | |||||
| Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); | |||||
| CheckpointOptions ckpt_options = new(options.experimental_io_device); | |||||
| object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||||
| } | |||||
| BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| // tensorflow python has a check of `context.async_wait()` here. | |||||
| } | |||||
| // TODO: deal with `pywrap_saved_model.Save(export_dir)`. | |||||
| var saved_model_serialized = saved_model.ToString(); | |||||
| // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. | |||||
| if (true) | |||||
| { | |||||
| var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), | |||||
| tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); | |||||
| // TODO: add c api and complete the fingerprint def. | |||||
| var fingerprint_proto = ""; | |||||
| File.WriteAllText(fingerprint_path, fingerprint_proto); | |||||
| } | |||||
| var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); | |||||
| File.WriteAllText(path, saved_model.ToString()); | |||||
| if (options.save_debug_info) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| ops.dismantle_graph(exported_graph); | |||||
| return (saved_nodes, node_paths); | |||||
| } | |||||
| private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | |||||
| Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||||
| IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||||
| { | |||||
| if (ops.inside_function()) | |||||
| { | |||||
| throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + | |||||
| "Move the call to the outer eagerly-executed context."); | |||||
| } | |||||
| if (meta_graph_def is null) | |||||
| { | |||||
| meta_graph_def = new MetaGraphDef(); | |||||
| } | |||||
| AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | |||||
| if (signatures is not null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // TODO: process of aignatures and wrapped_functions | |||||
| SaveableView saveable_view = new SaveableView(augmented_graph_view, options); | |||||
| TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); | |||||
| var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, | |||||
| options.namespace_white_list, options.experimental_custom_gradients); | |||||
| if (options.function_aliases is not null) | |||||
| { | |||||
| var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; | |||||
| foreach (var pair in options.function_aliases) | |||||
| { | |||||
| var alias = pair.Key; | |||||
| var func = pair.Value; | |||||
| // TODO: complete it. | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); | |||||
| meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); | |||||
| return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); | |||||
| } | |||||
| private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, | |||||
| IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist, | |||||
| bool save_custom_gradients) | |||||
| { | |||||
| var resource_initializers = saveable_view.get_concrete_resource_initializers(); | |||||
| var exported_graph = new Graph(); | |||||
| Dictionary<Trackable, Trackable> object_map; | |||||
| Dictionary<Tensor, Tensor> tensor_map; | |||||
| AssetInfo asset_info; | |||||
| using (var g = exported_graph.as_default()) | |||||
| { | |||||
| (object_map, tensor_map, asset_info) = saveable_view.map_resources(); | |||||
| // TODO: deal with signatures. | |||||
| if (save_custom_gradients) | |||||
| { | |||||
| // TODO: trace gradient functions. | |||||
| } | |||||
| foreach (var resource_initializer_function in resource_initializers) | |||||
| { | |||||
| // List<Trackable> asset_dependencies = new(); | |||||
| // TODO: deal with initializers | |||||
| } | |||||
| // using(ops.control_dependencies(...)) | |||||
| var init_op = control_flow_ops.no_op(); | |||||
| if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) | |||||
| { | |||||
| meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); | |||||
| } | |||||
| else | |||||
| { | |||||
| meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); | |||||
| } | |||||
| // Lack `CopyFrom` API | |||||
| // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] | |||||
| } | |||||
| foreach (var obj in object_map.Values) | |||||
| { | |||||
| obj._maybe_initialize_trackable(); | |||||
| } | |||||
| var (named_saveable_objects, registered_savers) = | |||||
| SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); | |||||
| // TODO: complete the save of checkpoints with `MultiDeviceSaver`. | |||||
| saveable_view.dependency_sorted_node_ids(); | |||||
| var graph_def = exported_graph.as_graph_def(true); | |||||
| graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); | |||||
| verify_ops(graph_def, namespace_whitelist); | |||||
| meta_graph_def.GraphDef = new GraphDef(graph_def); | |||||
| meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); | |||||
| meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; | |||||
| // TODO: add git version. | |||||
| meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; | |||||
| meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; | |||||
| meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); | |||||
| meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); | |||||
| // TODO: deal with signatures here. | |||||
| meta_graph.strip_graph_default_valued_attrs(meta_graph_def); | |||||
| if (!BitConverter.IsLittleEndian) | |||||
| { | |||||
| swap_function_tensor_content(meta_graph_def); | |||||
| } | |||||
| return (asset_info, exported_graph); | |||||
| } | |||||
| private static void verify_ops(GraphDef graph_def, IEnumerable<string>? namespace_whitelist) | |||||
| { | |||||
| return; | |||||
| // if (namespace_whitelist is null || !namespace_whitelist.Any()) | |||||
| // { | |||||
| // return; | |||||
| // } | |||||
| // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. | |||||
| } | |||||
| public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) | |||||
| { | |||||
| var functions = meta_graph_def.GraphDef.Library.Function; | |||||
| foreach (var function in functions) | |||||
| { | |||||
| var node_def = function.NodeDef; | |||||
| foreach (var node in node_def) | |||||
| { | |||||
| if (node.Op == "Const") | |||||
| { | |||||
| var tensor = node.Attr["value"].Tensor; | |||||
| byte_swap_tensor_content(tensor); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| public static void byte_swap_tensor_content(TensorProto tensor) | |||||
| { | |||||
| if (byte_swappable.Contains((int)tensor.Dtype)) | |||||
| { | |||||
| var tshape = tensor.TensorShape.Dim; | |||||
| var tensor_bytes = tensor.TensorContent; | |||||
| if (tensor_bytes is not null && !tensor_bytes.IsEmpty) | |||||
| { | |||||
| long tensor_size = 1; | |||||
| foreach (var sz in tshape) | |||||
| { | |||||
| tensor_size *= sz.Size; | |||||
| } | |||||
| var chunksize = tensor_bytes.Length / tensor_size; | |||||
| List<byte> reversed_bytes = new(); | |||||
| for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) | |||||
| { | |||||
| var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); | |||||
| reversed_bytes.AddRange(current); | |||||
| } | |||||
| tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,58 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow; | |||||
| public class SignatureMap: Trackable | |||||
| { | |||||
| private Dictionary<string, Function> _signatures; | |||||
| private Dictionary<string, ConcreteFunction> _concrete_signatures; | |||||
| public SignatureMap() | |||||
| { | |||||
| _signatures = new(); | |||||
| } | |||||
| public void _add_signature(string name, ConcreteFunction concrete_function) | |||||
| { | |||||
| _concrete_signatures[name] = concrete_function; | |||||
| } | |||||
| public void _add_signature(string name, Function concrete_function) | |||||
| { | |||||
| _signatures[name] = concrete_function; | |||||
| } | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| if (save_type != SaveType.SAVEDMODEL) | |||||
| { | |||||
| return new Dictionary<string, Trackable>(); | |||||
| } | |||||
| Dictionary<string, Trackable> res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); | |||||
| foreach (var pair in _concrete_signatures) | |||||
| { | |||||
| res[pair.Key] = pair.Value; | |||||
| } | |||||
| return res; | |||||
| } | |||||
| public static SignatureMap create_signature_map(IDictionary<string, ConcreteFunction> signatures) | |||||
| { | |||||
| var signature_map = new SignatureMap(); | |||||
| foreach (var pair in signatures) | |||||
| { | |||||
| var name = pair.Key; | |||||
| var func = pair.Value; | |||||
| // TODO: assert the arg_keywords | |||||
| signature_map._add_signature(name, func); | |||||
| } | |||||
| return signature_map; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,52 @@ | |||||
| using System.IO; | |||||
| using System.Security.Cryptography.X509Certificates; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow; | |||||
| public static partial class SavedModelUtils | |||||
| { | |||||
| /// <summary> | |||||
| /// Return variables sub-directory, or create one if it doesn't exist. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| public static string get_or_create_variables_dir(string export_dir) | |||||
| { | |||||
| var variables_dir = get_variables_dir(export_dir); | |||||
| Directory.CreateDirectory(variables_dir); | |||||
| return variables_dir; | |||||
| } | |||||
| /// <summary> | |||||
| /// Return variables sub-directory in the SavedModel. | |||||
| /// </summary> | |||||
| /// <param name="export_dir"></param> | |||||
| /// <returns></returns> | |||||
| public static string get_variables_dir(string export_dir) | |||||
| { | |||||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); | |||||
| } | |||||
| /// <summary> | |||||
| /// Return assets sub-directory, or create one if it doesn't exist. | |||||
| /// </summary> | |||||
| /// <param name="export_dir"></param> | |||||
| /// <returns></returns> | |||||
| public static string get_or_create_assets_dir(string export_dir) | |||||
| { | |||||
| var assets_destination_dir = get_assets_dir(export_dir); | |||||
| Directory.CreateDirectory(assets_destination_dir); | |||||
| return assets_destination_dir; | |||||
| } | |||||
| /// <summary> | |||||
| /// Return path to asset directory in the SavedModel. | |||||
| /// </summary> | |||||
| /// <param name="export_dir"></param> | |||||
| /// <returns></returns> | |||||
| public static string get_assets_dir(string export_dir) | |||||
| { | |||||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); | |||||
| } | |||||
| } | |||||
| @@ -17,12 +17,17 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class saveable_object_util | |||||
| public static class saveable_object_util | |||||
| { | { | ||||
| public class TrackableSaveable: MySaveableObject | |||||
| { | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the variables and names that will be used for a Saver. | /// Returns the variables and names that will be used for a Saver. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -121,5 +126,17 @@ namespace Tensorflow | |||||
| return names_to_saveables; | return names_to_saveables; | ||||
| } | } | ||||
| public static IDictionary<string, ResourceVariable> saveable_objects_from_trackable(Trackable obj) | |||||
| { | |||||
| // TODO: complete the implementation. | |||||
| return obj.gather_saveables_for_checkpoint(); | |||||
| } | |||||
| public static bool trackable_has_serialize_to_tensor(Trackable obj) | |||||
| { | |||||
| // TODO: implement it. | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,14 +14,38 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.ModelSaving; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
| { | { | ||||
| public abstract class Trackable | public abstract class Trackable | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Corresponding to tensorflow/python/trackable/constants.py | |||||
| /// </summary> | |||||
| public static class Constants | |||||
| { | |||||
| public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; | |||||
| public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; | |||||
| public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; | |||||
| } | |||||
| protected int _self_update_uid; | protected int _self_update_uid; | ||||
| protected IDictionary<string, Trackable> _unconditional_dependency_names = | |||||
| new Dictionary<string, Trackable>(); | |||||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||||
| protected IDictionary<string, ResourceVariable> _self_saveable_object_factories = | |||||
| new Dictionary<string, ResourceVariable>(); | |||||
| public virtual string ObjectIdentifier | |||||
| { | |||||
| get => "_generic_user_object"; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | /// Restore-on-create for a variable be saved with this `Checkpointable`. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -73,10 +97,63 @@ namespace Tensorflow.Train | |||||
| /// <summary> | /// <summary> | ||||
| /// Initialize dependency management. | /// Initialize dependency management. | ||||
| /// </summary> | /// </summary> | ||||
| protected void _maybe_initialize_trackable() | |||||
| public void _maybe_initialize_trackable() | |||||
| { | { | ||||
| // _self_unconditional_checkpoint_dependencies = [] | // _self_unconditional_checkpoint_dependencies = [] | ||||
| _self_update_uid = -1; | _self_update_uid = -1; | ||||
| } | } | ||||
| // TODO: cache | |||||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| _maybe_initialize_trackable(); | |||||
| return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | |||||
| } | |||||
| public static Trackable convert_to_trackable(object obj, object? parent = null) | |||||
| { | |||||
| if (obj is Trackable) | |||||
| { | |||||
| return (Trackable)obj; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| public virtual IDictionary<string, Trackable> deserialization_dependencies(IDictionary<string, Trackable> children) | |||||
| { | |||||
| return new Dictionary<string, Trackable>(); | |||||
| } | |||||
| public virtual (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources( | |||||
| SaveOptions? save_options) | |||||
| { | |||||
| return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); | |||||
| } | |||||
| public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null, | |||||
| IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null) | |||||
| { | |||||
| var (self_object_map, self_tensor_map) = map_resources(options); | |||||
| foreach (var pair in self_object_map) | |||||
| { | |||||
| object_map.Add(pair); | |||||
| } | |||||
| foreach (var pair in self_tensor_map) | |||||
| { | |||||
| tensor_map.Add(pair); | |||||
| } | |||||
| return self_tensor_map.Keys.ToList(); | |||||
| } | |||||
| public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint() | |||||
| { | |||||
| return _self_saveable_object_factories; | |||||
| } | |||||
| } | } | ||||
| public record class TrackableReference(string Name, Trackable Refer); | |||||
| } | } | ||||
| @@ -0,0 +1,148 @@ | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Exceptions; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Training; | |||||
| public static class TrackableUtils | |||||
| { | |||||
| public class CyclicDependencyError: System.Exception | |||||
| { | |||||
| public IDictionary<int, IEnumerable<int>> LeftOverDependencyMap { get; } | |||||
| public CyclicDependencyError(IDictionary<int, IEnumerable<int>> leftover_dependency_map): base() | |||||
| { | |||||
| LeftOverDependencyMap = leftover_dependency_map; | |||||
| } | |||||
| public CyclicDependencyError(IDictionary<int, List<int>> leftover_dependency_map): base() | |||||
| { | |||||
| LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | |||||
| } | |||||
| } | |||||
| private static string _ESCAPE_CHAR = "."; | |||||
| private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||||
| private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||||
| private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||||
| public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | |||||
| { | |||||
| return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); | |||||
| } | |||||
| public static string escape_local_name(string name) | |||||
| { | |||||
| return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S"); | |||||
| } | |||||
| public static string checkpoint_key(string object_path, string local_name) | |||||
| { | |||||
| var key_suffix = escape_local_name(local_name); | |||||
| if (local_name == SERIALIZE_TO_TENSORS_NAME) | |||||
| { | |||||
| key_suffix = ""; | |||||
| } | |||||
| return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; | |||||
| } | |||||
| /// <summary> | |||||
| /// Topologically sorts the keys of a map so that dependencies appear first. | |||||
| /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm | |||||
| /// </summary> | |||||
| /// <param name="dependency_map"></param> | |||||
| /// <exception cref="ValueError"></exception> | |||||
| public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map) | |||||
| { | |||||
| Dictionary<int, HashSet<int>> reverse_dependency_map = new(); | |||||
| foreach (var pair in dependency_map) | |||||
| { | |||||
| foreach (var dep in pair.Value) | |||||
| { | |||||
| if (reverse_dependency_map.ContainsKey(dep)) | |||||
| { | |||||
| reverse_dependency_map[dep].Add(pair.Key); | |||||
| } | |||||
| else | |||||
| { | |||||
| reverse_dependency_map[dep] = new HashSet<int>(); | |||||
| reverse_dependency_map[dep].Add(pair.Key); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Validate that all values in the dependency map are also keys. | |||||
| var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); | |||||
| if (unknown_keys.Count() > 0) | |||||
| { | |||||
| throw new ValueError( | |||||
| $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); | |||||
| } | |||||
| // Generate the list sorted by objects without dependencies -> dependencies. | |||||
| // The returned list will reverse this. | |||||
| List<int> reversed_dependency_arr = new(); | |||||
| Queue<int> to_visit = new(); | |||||
| foreach (var x in dependency_map.Keys) | |||||
| { | |||||
| if (!reverse_dependency_map.ContainsKey(x)) | |||||
| { | |||||
| to_visit.Enqueue(x); | |||||
| } | |||||
| } | |||||
| while (to_visit.Count > 0) | |||||
| { | |||||
| var x = to_visit.Dequeue(); | |||||
| reversed_dependency_arr.Add(x); | |||||
| foreach (var dep in dependency_map[x].Distinct()) | |||||
| { | |||||
| var edges = reverse_dependency_map[dep]; | |||||
| edges.Remove(x); | |||||
| if (edges.Count == 0) | |||||
| { | |||||
| to_visit.Enqueue(dep); | |||||
| if (!reverse_dependency_map.Remove(dep)) | |||||
| { | |||||
| throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (reverse_dependency_map.Count > 0) | |||||
| { | |||||
| Dictionary<int, List<int>> leftover_dependency_map = new(); | |||||
| foreach (var pair in reverse_dependency_map) | |||||
| { | |||||
| foreach (var x in pair.Value) | |||||
| { | |||||
| if (leftover_dependency_map.ContainsKey(x)) | |||||
| { | |||||
| leftover_dependency_map[x].Add(pair.Key); | |||||
| } | |||||
| else | |||||
| { | |||||
| leftover_dependency_map[x] = new List<int>() { pair.Key }; | |||||
| } | |||||
| } | |||||
| } | |||||
| throw new CyclicDependencyError(leftover_dependency_map); | |||||
| } | |||||
| reversed_dependency_arr.Reverse(); | |||||
| return reversed_dependency_arr; | |||||
| } | |||||
| public static string pretty_print_node_path(IEnumerable<TrackableReference> paths) | |||||
| { | |||||
| if (paths.Count() == 0) | |||||
| { | |||||
| return "root object"; | |||||
| } | |||||
| else | |||||
| { | |||||
| return $"root.{string.Join(".", paths.Select(x => x.Name))}"; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,6 +2,7 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -566,5 +566,23 @@ namespace Tensorflow | |||||
| else | else | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public static bool inside_function() | |||||
| { | |||||
| return get_default_graph().building_function; | |||||
| } | |||||
| public static void dismantle_graph(Graph graph) | |||||
| { | |||||
| } | |||||
| public class NullContextManager: IDisposable | |||||
| { | |||||
| public void Dispose() | |||||
| { | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,31 @@ | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Keras.Engine; | |||||
| public abstract partial class Layer | |||||
| { | |||||
| public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | |||||
| public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||||
| public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| IDictionary<string, Trackable> children; | |||||
| if (save_type == SaveType.SAVEDMODEL) | |||||
| { | |||||
| // TODO: deal with cache. | |||||
| children = TrackableSavedModelSaver.trackable_children(cache); | |||||
| } | |||||
| else | |||||
| { | |||||
| children = new Dictionary<string, Trackable>(); | |||||
| } | |||||
| return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| } | |||||
| @@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine | |||||
| public bool Built => built; | public bool Built => built; | ||||
| public bool Trainable => args.Trainable; | public bool Trainable => args.Trainable; | ||||
| public TF_DataType DType => args.DType; | public TF_DataType DType => args.DType; | ||||
| public bool AutoCast => args.Autocast; | |||||
| public IRegularizer ActivityRegularizer => args.ActivityRegularizer; | |||||
| /// <summary> | /// <summary> | ||||
| /// A stateful layer is a layer whose updates are run during inference too, | /// A stateful layer is a layer whose updates are run during inference too, | ||||
| @@ -162,7 +164,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="inputs"></param> | /// <param name="inputs"></param> | ||||
| /// <param name="state"></param> | /// <param name="state"></param> | ||||
| /// <param name="is_training"></param> | |||||
| /// <param name="training"></param> | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| { | { | ||||
| @@ -1,5 +1,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| using Tensorflow.ModelSaving; | using Tensorflow.ModelSaving; | ||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| @@ -18,9 +20,18 @@ namespace Tensorflow.Keras.Engine | |||||
| bool overwrite = true, | bool overwrite = true, | ||||
| bool include_optimizer = true, | bool include_optimizer = true, | ||||
| string save_format = "tf", | string save_format = "tf", | ||||
| SaveOptions options = null) | |||||
| SaveOptions? options = null, | |||||
| IDictionary<string, ConcreteFunction>? signatures = null, | |||||
| bool save_traces = true) | |||||
| { | { | ||||
| saver.save(this, filepath); | |||||
| if (save_format != "pb") | |||||
| { | |||||
| saver.save(this, filepath); | |||||
| } | |||||
| else | |||||
| { | |||||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -35,6 +35,12 @@ namespace Tensorflow.Keras.Engine | |||||
| bool _base_model_initialized; | bool _base_model_initialized; | ||||
| bool stop_training; | bool stop_training; | ||||
| DataHandler data_handler; | DataHandler data_handler; | ||||
| public OptimizerV2 Optimizer | |||||
| { | |||||
| get => optimizer; | |||||
| set => optimizer = value; | |||||
| } | |||||
| public Model(ModelArgs args) | public Model(ModelArgs args) | ||||
| : base(args) | : base(args) | ||||
| @@ -194,6 +194,18 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { | |||||
| OnConstruction(); | OnConstruction(); | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public SavedObject(int nodeId, string nodePath, | |||||
| global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) | |||||
| { | |||||
| OnConstruction(); | |||||
| nodeId_ = nodeId; | |||||
| nodePath_ = nodePath; | |||||
| identifier_ = identifier; | |||||
| metadata_ = metadata; | |||||
| version_ = version; | |||||
| } | |||||
| partial void OnConstruction(); | partial void OnConstruction(); | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | ||||
| @@ -74,6 +74,13 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { | |||||
| public VersionDef() { | public VersionDef() { | ||||
| OnConstruction(); | OnConstruction(); | ||||
| } | } | ||||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||||
| public VersionDef(int producer, int minConsumer) { | |||||
| OnConstruction(); | |||||
| producer_ = producer; | |||||
| minConsumer_ = minConsumer; | |||||
| } | |||||
| partial void OnConstruction(); | partial void OnConstruction(); | ||||
| @@ -0,0 +1,41 @@ | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public static class Constants | |||||
| { | |||||
| /// <summary> | |||||
| /// Namespace used to store all attributes added during serialization. | |||||
| /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an | |||||
| /// object loaded from `tf.saved_model.load()`. | |||||
| /// </summary> | |||||
| public static readonly string KERAS_ATTR = "keras_api"; | |||||
| /// <summary> | |||||
| /// Keys for the serialization cache. | |||||
| /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} | |||||
| /// </summary> | |||||
| public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; | |||||
| /// <summary> | |||||
| /// Name of Keras metadata file stored in the SavedModel. | |||||
| /// </summary> | |||||
| public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; | |||||
| public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; | |||||
| public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; | |||||
| public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; | |||||
| public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; | |||||
| public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; | |||||
| public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; | |||||
| public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; | |||||
| public static readonly IList<string> KERAS_OBJECT_IDENTIFIERS = new List<string>() | |||||
| { | |||||
| INPUT_LAYER_IDENTIFIER, | |||||
| LAYER_IDENTIFIER, | |||||
| METRIC_IDENTIFIER, | |||||
| MODEL_IDENTIFIER, | |||||
| NETWORK_IDENTIFIER, | |||||
| RNN_LAYER_IDENTIFIER, | |||||
| SEQUENTIAL_IDENTIFIER | |||||
| }; | |||||
| } | |||||
| @@ -0,0 +1,11 @@ | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public class KerasObjectWrapper | |||||
| { | |||||
| } | |||||
| public class KerasObjectWrapper<T> | |||||
| { | |||||
| public T Item { get; set; } | |||||
| } | |||||
| @@ -0,0 +1,115 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using Google.Protobuf; | |||||
| using ICSharpCode.SharpZipLib.Zip; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Exceptions; | |||||
| using Tensorflow.IO; | |||||
| using Tensorflow.Keras.Optimizers; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public partial class KerasSavedModelUtils | |||||
| { | |||||
| public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures, | |||||
| SaveOptions? options, bool save_traces = true) | |||||
| { | |||||
| if (!overwrite && File.Exists(filepath)) | |||||
| { | |||||
| throw new Exception("The file already exists but is not allowed to overwrite it."); | |||||
| } | |||||
| if (save_traces) | |||||
| { | |||||
| if(should_skip_serialization(model)) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| OptimizerV2? orig_optimizer = null; | |||||
| if (!include_optimizer) | |||||
| { | |||||
| orig_optimizer = model.Optimizer; | |||||
| model.Optimizer = null; | |||||
| model._delete_tracking("optimizer"); | |||||
| } | |||||
| IList<Trackable> saved_nodes; | |||||
| IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths; | |||||
| // skip two scopes of python | |||||
| using (KerasSavedModelUtils.keras_option_scope(save_traces)) | |||||
| { | |||||
| (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options); | |||||
| } | |||||
| var metadata = generate_keras_metadata(saved_nodes, node_paths); | |||||
| using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, | |||||
| FileAccess.Write)) | |||||
| { | |||||
| var writer = new StreamWriter(f); | |||||
| writer.Write(metadata.ToString()); | |||||
| } | |||||
| if (!include_optimizer) | |||||
| { | |||||
| model.Optimizer = orig_optimizer!; | |||||
| } | |||||
| } | |||||
| public static SavedMetadata generate_keras_metadata(IList<Trackable> saved_nodes, | |||||
| IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths) | |||||
| { | |||||
| var metadata = new SavedMetadata(); | |||||
| for (int i = 0; i < saved_nodes.Count; i++) | |||||
| { | |||||
| var node = saved_nodes[i]; | |||||
| if (node is not Layer) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| Layer layer = (Layer)node; | |||||
| var path = node_paths[node]; | |||||
| string node_path; | |||||
| if (path is null) | |||||
| { | |||||
| node_path = "root"; | |||||
| } | |||||
| else | |||||
| { | |||||
| node_path = $"root.{string.Join(".", path.Select(x => x.Name))}"; | |||||
| } | |||||
| ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new() | |||||
| { | |||||
| NodeId = i, | |||||
| NodePath = node_path, | |||||
| Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef() | |||||
| { | |||||
| Producer = 2, | |||||
| MinConsumer = 1, | |||||
| BadConsumers = { } | |||||
| }, | |||||
| Identifier = layer.ObjectIdentifier, | |||||
| Metadata = layer.TrackingMetadata | |||||
| }; | |||||
| } | |||||
| return metadata; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,19 @@ | |||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public partial class KerasSavedModelUtils | |||||
| { | |||||
| public static bool should_skip_serialization(object layer) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| public static IDictionary<string, KerasObjectWrapper> wrap_layer_objects(Layer layer, object serialization_cache) | |||||
| { | |||||
| // TODO: process the loss | |||||
| return null; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,40 @@ | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Newtonsoft.Json; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public abstract class SavedModelSaver | |||||
| { | |||||
| private Trackable _obj; | |||||
| public SavedModelSaver(Trackable obj) | |||||
| { | |||||
| _obj = obj; | |||||
| } | |||||
| public abstract string ObjectIdentifier { get; } | |||||
| public abstract string TrackingMetadata { get; } | |||||
| public abstract IDictionary<string, CheckpointableBase> objects_to_serialize( | |||||
| IDictionary<string, object> serialization_cache); | |||||
| public abstract IDictionary<string, Function> functions_to_serialize( | |||||
| IDictionary<string, object> serialization_cache); | |||||
| public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | |||||
| { | |||||
| if (!KerasSavedModelUtils.ShouldHaveTraces) | |||||
| { | |||||
| return new Dictionary<string, Trackable>(); | |||||
| } | |||||
| var children = objects_to_serialize(serialization_cache); | |||||
| return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) | |||||
| .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||||
| .ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,62 @@ | |||||
| using System.Collections.Generic; | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public class LayerSavedModelSaver: SavedModelSaver | |||||
| { | |||||
| private Layer _obj; | |||||
| public LayerSavedModelSaver(Layer obj): base(obj) | |||||
| { | |||||
| _obj = obj; | |||||
| } | |||||
| public override string ObjectIdentifier | |||||
| { | |||||
| get => Constants.LAYER_IDENTIFIER; | |||||
| } | |||||
| public override IDictionary<string, CheckpointableBase> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||||
| { | |||||
| throw new System.NotImplementedException(); | |||||
| } | |||||
| public override IDictionary<string, Function> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||||
| { | |||||
| throw new System.NotImplementedException(); | |||||
| } | |||||
| public override string TrackingMetadata | |||||
| { | |||||
| get | |||||
| { | |||||
| JObject metadata = new JObject(); | |||||
| metadata["name"] = _obj.Name; | |||||
| metadata["trainable"] = _obj.Trainable; | |||||
| // metadata["expects_training_arg"] = _obj._expects_training_arg; | |||||
| // metadata["dtype"] = policy.serialize(_obj._dtype_policy) | |||||
| metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); | |||||
| // metadata["stateful"] = _obj.stateful; | |||||
| // metadata["must_restore_from_config"] = _obj.must_restore_from_config; | |||||
| // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | |||||
| metadata["autocast"] = _obj.AutoCast; | |||||
| metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings | |||||
| { | |||||
| // Handle conflicts by using values from obj2 | |||||
| MergeArrayHandling = MergeArrayHandling.Merge | |||||
| }); | |||||
| // skip the check of `input_spec` and `build_input_shape` for the lack of members. | |||||
| // skip the check of `activity_regularizer` for the type problem. | |||||
| return metadata.ToString(); | |||||
| } | |||||
| } | |||||
| public static LayerConfig get_serialized(Layer obj) | |||||
| { | |||||
| return generic_utils.serialize_keras_object(obj); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,33 @@ | |||||
| using System; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public partial class KerasSavedModelUtils | |||||
| { | |||||
| public static bool ShouldHaveTraces { get; internal set; } | |||||
| public static SaveOptionsContext keras_option_scope(bool save_traces) | |||||
| { | |||||
| var res = new SaveOptionsContext(ShouldHaveTraces); | |||||
| ShouldHaveTraces = save_traces; | |||||
| return res; | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Implementation of this class is different with that of python. | |||||
| /// But it could be used with `using` the same as `with` of python. | |||||
| /// </summary> | |||||
| public class SaveOptionsContext: IDisposable | |||||
| { | |||||
| public bool _old_value; | |||||
| public SaveOptionsContext(bool old_value) | |||||
| { | |||||
| _old_value = true; | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| KerasSavedModelUtils.ShouldHaveTraces = _old_value; | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,60 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Tensorflow.NumPy; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| using Tensorflow; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| using Tensorflow.Keras; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Keras.Losses; | |||||
| using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow.Keras.Optimizers; | |||||
| namespace TensorFlowNET.Keras.UnitTest; | |||||
| // class MNISTLoader | |||||
| // { | |||||
| // public MNISTLoader() | |||||
| // { | |||||
| // var mnist = new MnistModelLoader() | |||||
| // | |||||
| // } | |||||
| // } | |||||
| [TestClass] | |||||
| public class SaveTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void Test() | |||||
| { | |||||
| var inputs = new KerasInterface().Input((28, 28, 1)); | |||||
| var x = new Flatten(new FlattenArgs()).Apply(inputs); | |||||
| x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x); | |||||
| x = new LayersApi().Dense(units: 10).Apply(x); | |||||
| var outputs = new LayersApi().Softmax(axis: 1).Apply(x); | |||||
| var model = new KerasInterface().Model(inputs, outputs); | |||||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); | |||||
| var data_loader = new MnistModelLoader(); | |||||
| var num_epochs = 1; | |||||
| var batch_size = 50; | |||||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
| { | |||||
| TrainDir = "mnist", | |||||
| OneHot = false, | |||||
| ValidationSize = 50000, | |||||
| }).Result; | |||||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
| model.save("", save_format:"pb"); | |||||
| } | |||||
| } | |||||
| @@ -47,7 +47,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" /> | |||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" /> | <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" /> | ||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="2.2.8" /> | <PackageReference Include="MSTest.TestFramework" Version="2.2.8" /> | ||||