From bb8168b5ca9bc78a824d429eb7bd5f4ac9e4fa8d Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 21 Jan 2023 11:07:07 +0800 Subject: [PATCH] Init the serialization of keras pb model. --- src/TensorFlowNET.Core/APIs/tf.compat.cs | 22 ++ .../Checkpoint/CheckPointUtils.cs | 150 +++++++++ .../Checkpoint/CheckpointOptions.cs | 5 + .../Checkpoint/ObjectGraphView.cs | 63 ++++ .../Checkpoint/SaveUtilV1.cs | 229 ++++++++++++++ .../Checkpoint/TrackableSaver.cs | 109 +++++++ .../Checkpoint/TrackableView.cs | 75 +++++ .../Exceptions/AssertionError.cs | 14 + .../Framework/meta_graph.cs | 63 +++- .../Functions/ConcreteFunction.cs | 3 +- src/TensorFlowNET.Core/Functions/Function.cs | 11 +- .../ModelSaving/SaveOptions.cs | 8 +- .../Operations/resource_variable_ops.cs | 6 + .../Protobuf/SavedObjectGraph.cs | 10 +- .../Protobuf/TrackableObjectGraph.cs | 6 + .../Training/AutoTrackable.cs | 15 + src/TensorFlowNET.Core/Training/Optimizer.cs | 7 +- .../Training/Saving/SaveableObject.cs | 14 + .../Training/Saving/SavedModel/AssetInfo.cs | 11 + .../Saving/SavedModel/AugmentedGraphView.cs | 60 ++++ .../Training/Saving/SavedModel/Constants.cs | 33 ++ .../Saving/SavedModel/RevivedTypes.cs | 17 + .../Training/Saving/SavedModel/SaveType.cs | 9 + .../Saving/SavedModel/SaveableView.cs | 299 ++++++++++++++++++ .../Saving/SavedModel/TagConstants.cs | 10 + .../Training/Saving/SavedModel/builder.cs | 22 ++ .../Training/Saving/SavedModel/save.cs | 256 +++++++++++++++ .../SavedModel/signature_serialization.cs | 58 ++++ .../Training/Saving/SavedModel/utils.cs | 52 +++ .../Saving/saveable_object_util.py.cs | 19 +- src/TensorFlowNET.Core/Training/Trackable.cs | 79 ++++- .../Training/TrackableUtils.cs | 148 +++++++++ .../Variables/BaseResourceVariable.cs | 1 + src/TensorFlowNET.Core/ops.cs | 18 ++ .../Engine/Layer.Serialize.cs | 31 ++ src/TensorFlowNET.Keras/Engine/Layer.cs | 4 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 15 +- src/TensorFlowNET.Keras/Engine/Model.cs | 6 + 
.../Protobuf/SavedMetadata.cs | 12 + src/TensorFlowNET.Keras/Protobuf/Versions.cs | 7 + .../Saving/SavedModel/Constants.cs | 41 +++ .../Saving/SavedModel/KerasObjectWrapper.cs | 11 + .../Saving/SavedModel/Save.cs | 115 +++++++ .../Saving/SavedModel/SaveImpl.cs | 19 ++ .../Saving/SavedModel/base_serialization.cs | 40 +++ .../Saving/SavedModel/layer_serialization.cs | 62 ++++ .../Saving/SavedModel/utils.cs | 33 ++ test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 60 ++++ .../Tensorflow.Binding.UnitTest.csproj | 2 +- 49 files changed, 2347 insertions(+), 13 deletions(-) create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/TrackableView.cs create mode 100644 src/TensorFlowNET.Core/Exceptions/AssertionError.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs create mode 100644 src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs create 
mode 100644 src/TensorFlowNET.Core/Training/TrackableUtils.cs create mode 100644 src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/SaveTest.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.compat.cs b/src/TensorFlowNET.Core/APIs/tf.compat.cs index 4d979eb5..5b2b5a10 100644 --- a/src/TensorFlowNET.Core/APIs/tf.compat.cs +++ b/src/TensorFlowNET.Core/APIs/tf.compat.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Text; + namespace Tensorflow { public partial class tensorflow @@ -23,6 +25,26 @@ namespace Tensorflow public class CompatApi { public CompatV1Api v1 { get; } = new CompatV1Api(); + + internal string as_text(string bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return bytes_or_text; + } + internal string as_text(byte[] bytes_or_text, Encoding? encoding = null) + { + if(encoding is null) encoding = Encoding.UTF8; + return encoding.GetString(bytes_or_text); + } + + internal string as_str(string bytes_or_text, Encoding? encoding = null) + { + return as_text(bytes_or_text, encoding); + } + internal string as_str(byte[] bytes_or_text, Encoding? 
encoding = null) + { + return as_text(bytes_or_text, encoding); + } } public bool executing_eagerly() diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs new file mode 100644 index 00000000..70d77155 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -0,0 +1,150 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; + +namespace Tensorflow.Checkpoint; + +public static class CheckPointUtils +{ + private static string _ESCAPE_CHAR = "."; + public static (List, Dictionary>, Dictionary, + IDictionary>, + Dictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names); + return (trackable_objects, node_paths, node_ids, slot_variables, object_names); + } + + public static + IDictionary> + serialize_slot_variables(IEnumerable trackable_objects, + IDictionary node_ids, IDictionary object_names) + { + var non_slot_objects = trackable_objects.ToList(); + Dictionary> + slot_variables = new(); + foreach (var trackable in non_slot_objects) + { + if (trackable is not Optimizer) + { + continue; + } + + var optim = (Optimizer)trackable; + var slot_names = optim.get_slot_names(); + foreach (var slot_name in slot_names) + { + for (int original_variable_node_id = 0; + original_variable_node_id < non_slot_objects.Count; + original_variable_node_id++) + { + var original_variable = 
non_slot_objects[original_variable_node_id]; + IVariableV1 slot_variable; + if (original_variable is not IVariableV1) + { + slot_variable = null; + } + slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name); + if(slot_variable is null) continue; + + // There're some problems about the inherits of `Variable` and `Trackable`. + throw new NotImplementedException(); + } + } + } + + return slot_variables; + } + + public static Trackable get_mapped_trackable(Trackable trackable, IDictionary? object_map) + { + if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res)) + { + return trackable; + } + else + { + return possible_res; + } + } + + public static string get_full_name(Trackable var) + { + // TODO: This state is not correct, the whole framework need to be updated in the future. + if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) + { + return ""; + } + // skip the check of attribute `_save_slice_info` . + + // TODO: Need to be revised!!! + return ((ResourceVariable)(object)var).Name; + } + + public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) + { + HashSet checkpointed_trackables = new(); + Dictionary> parents = new(); + for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + { + var object_proto = object_graph_proto.Nodes[i]; + // skip the process of registered saver. 
+ if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 || + object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0) + { + checkpointed_trackables.Add(i); + } + + foreach (var child_proto in object_proto.Children) + { + var child = child_proto.NodeId; + if (!parents.ContainsKey(child)) + { + parents[child] = new HashSet(); + } + + parents[child].Add(i); + } + } + + Queue to_visit = new(checkpointed_trackables.AsEnumerable()); + while (to_visit.Count > 0) + { + var trackable = to_visit.Dequeue(); + if (!parents.ContainsKey(trackable)) continue; + var current_parents = parents[trackable]; + foreach (var parent in current_parents) + { + checkpointed_trackables.Add(parent); + if (parents.ContainsKey(parent)) + { + to_visit.Enqueue(parent); + } + } + parents.Remove(trackable); + } + + // TODO: Complete it after supporting checkpoint. + // for (int i = 0; i < object_graph_proto.Nodes.Count; i++) + // { + // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); + // } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs new file mode 100644 index 00000000..d8297ea3 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -0,0 +1,5 @@ +namespace Tensorflow.Checkpoint; + +public record class CheckpointOptions( + string experimental_io_device = null, + bool experimental_enable_async_checkpoint = false); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs new file mode 100644 index 00000000..2ad55448 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Serilog.Debugging; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint; + +public class 
ObjectGraphView: TrackableView, ICloneable +{ + protected IEnumerable? _attached_dependencies; + // TODO: attached_dependencies + public ObjectGraphView(Trackable root, IEnumerable? attached_dependencies = null): base(root) + { + _attached_dependencies = attached_dependencies; + } + + public object Clone() + { + // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__ + return new ObjectGraphView(Root, _attached_dependencies); + } + + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + List res = base.children(obj, save_type) + .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + // Check the reference, not value. + if (obj == Root && _attached_dependencies is not null) + { + res.AddRange(_attached_dependencies); + } + + return res; + } + + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); + } + + public IEnumerable? AttachedDependencies + { + get => _attached_dependencies; + } + + public virtual (List, Dictionary>) breadth_first_traversal() + { + return base._descendants_with_paths(); + } + + // TODO: complete the implementation + public void serialize_object_graph(object? saveables_cache = null) + { + throw new NotImplementedException(); + } + + // TODO: complete the implementation + public void frozen_saveable_objects(object? object_map = null, object? 
to_graph = null, object call_with_mapped_captures = null) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs new file mode 100644 index 00000000..7724c6b7 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -0,0 +1,229 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Exceptions; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; + +namespace Tensorflow.Checkpoint; + +public static class SaveUtilV1 +{ + public static (Dictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, + IDictionary? object_map = null) + { + // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, + // till now only internal registrations are allowed. So, we won't return a saver in this function. + // The implementation of this function should be updated if tensorflow update it. + Dictionary> checkpoint_factory_map = new(); + foreach (var pair in object_names) + { + var trackable = pair.Key; + var object_name = pair.Value; + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + + // skip the registration process. + + List current_list = new(); + foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) + { + // treat name as key_suffix. + var name = name_and_factory.Key; + var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); + + current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); + } + + checkpoint_factory_map[trackable] = current_list; + } + + return (checkpoint_factory_map, null); + } + + public static (List, object?) 
frozen_saveables_and_savers(ObjectGraphView graph_view, + IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, + object? saveables_cache = null) + { + + Graph target_context; + if (to_graph is not null) + { + using (to_graph.as_default()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + // tensorflow python: `with ops.device("/cpu:0")` + var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + else + { + using (new ops.NullContextManager()) + { + var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, + object_map, call_with_mapped_captures, saveables_cache); + // tensorflow python: `with ops.device("/cpu:0")` + var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + return (named_saveable_objects, registered_savers); + } + } + } + + public static (List, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, + IDictionary object_map, bool call_with_mapped_captures, object? 
saveables_cache = null) + { + var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); + Dictionary object_names = new(); + foreach (var pair in node_paths) + { + object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); + } + + Dictionary node_ids = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + node_ids[trackable_objects[i]] = i; + } + + var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); + var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); + var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( + trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, + saveables_cache); + + CheckPointUtils.add_checkpoint_values_check(object_graph_proto); + return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); + } + + private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList trackable_objects, + IDictionary node_ids, + IDictionary> + slot_variables) + { + TrackableObjectGraph object_graph_proto = new(); + for (int i = 0; i < trackable_objects.Count; i++) + { + var trackable = trackable_objects[i]; + Debug.Assert(node_ids[trackable] == i); + TrackableObjectGraph.Types.TrackableObject object_proto; + if (slot_variables.TryGetValue(trackable, out var slots)) + { + object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); + } + else + { + object_proto = new TrackableObjectGraph.Types.TrackableObject(); + } + object_graph_proto.Nodes.Add(object_proto); + foreach (var child in graph_view.list_children(trackable)) + { + object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() + { NodeId = node_ids[child.Refer], LocalName = child.Name }); + } + } + + return object_graph_proto; + } + + private static (List, object?, 
object?) add_attributes_to_object_graph(IList trackable_objects, + TrackableObjectGraph object_graph_proto, IDictionary node_ids, + IDictionary object_names, IDictionary object_map, + bool call_with_mapped_captures, object? saveables_cache = null) + { + int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + Debug.Assert(node_ids[trackable_objects[i]] == i); + } + + var (checkpoint_factory_map, unmmaped_registered_savers) = + get_checkpoint_factories_and_keys(object_names, object_map); + + // skip the process of registered savers + + var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, + object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); + return (named_saveable_objects, feed_additions, null); + } + + public static (List, object?) generate_saveable_objects( + IDictionary> checkpoint_factory_map, + TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, + IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) + { + List named_saveable_objects = new(); + foreach (var pair in checkpoint_factory_map) + { + var trackable = pair.Key; + var factory_data_list = pair.Value; + bool fill_object_proto = object_graph_proto is not null && node_ids is not null; + TrackableObjectGraph.Types.TrackableObject object_proto = null!; + if (fill_object_proto) + { + object_proto = object_graph_proto.Nodes[node_ids[trackable]]; + } + + var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); + // skip cache + + foreach (var factory_data in factory_data_list) + { + var name = factory_data.name; + var key = factory_data.checkpoint_key; + var saveable_factory = factory_data.factory; + + // TODO: oneflow python has a process with callable `saveable_factory`. 
+ var maybe_saveable = saveable_factory; + IEnumerable savesbles; + if (maybe_saveable is MySaveableObject) + { + savesbles = new List() { (MySaveableObject)maybe_saveable }; + } + else if (maybe_saveable is Tensor) + { + savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); + } + else + { + throw new TypeError("Unexpected type."); + } + + foreach (var saveable in savesbles) + { + if (!saveable.name.Contains(key)) + { + throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + + $"'{saveable.name}' for attribute '{name}'. Expected a name" + + $" containing '{key}'."); + } + } + + // skip the process of PythonState + + named_saveable_objects.AddRange(savesbles); + + if(!fill_object_proto) continue; + + // skip the process of TrackableSaveable + + object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() + { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); + } + } + + return (named_saveable_objects, null); + } +} + +public record class CheckpointFactoryData +( + object factory, + string name, + string checkpoint_key +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs new file mode 100644 index 00000000..7d101d5e --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; + +namespace Tensorflow.Checkpoint; + +public class TrackableSaver +{ + private ObjectGraphView _graph_view; + private EagerTensor _cached_save_operation; + private TrackableObjectGraph _last_save_object_graph; + private Tensor? _object_graph_feed_tensor = null; + private Tensor? 
_file_prefix_feed_tensor = null; + public TrackableSaver(ObjectGraphView graph_view) + { + _graph_view = graph_view; + + // TODO: cache when not executing eagerly. + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` + + } + + private void gather_serialized_tensors(Tensor? object_graph_tensor = null) + { + throw new NotImplementedException(); + } + + private (EagerTensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + { + throw new NotImplementedException(); + } + + // TODO: parameter write_done_callback + public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, + CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + + Dictionary feed_dict = new(); + bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + if (checkpoint_number is not null) + { + file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; + } + + Tensor file_prefix_tensor; + Tensor object_graph_tensor; + if (use_session) + { + if (_object_graph_feed_tensor is null) + { + // In python there is `with ops.device("/cpu:0")`. + _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); + _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); + } + + object_graph_tensor = _object_graph_feed_tensor; + file_prefix_tensor = _file_prefix_feed_tensor; + feed_dict[file_prefix_tensor] = file_prefix; + } + else + { + // In python there is `with ops.device("/cpu:0")`. 
+ file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); + object_graph_tensor = null; + } + + var (save_path, new_feed_additions) = + save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + + if (new_feed_additions is not null) + { + foreach (var pair in new_feed_additions) + { + feed_dict.Add(pair.Key, pair.Value); + } + } + if(!use_session) + { + session = null; + } + else if (session is null) + { + session = new Session(); // In python it uses `get_session`. + } + + if (session is not null) + { + var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); + return session.run((Tensor)save_path, s); + } + else if (use_session) + { + throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + + "in graph mode without a default session. Please use " + + "`with tf.Session():` to create a session."); + } + else + { + return save_path; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs new file mode 100644 index 00000000..ed1f3ec4 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -0,0 +1,75 @@ +using System; +using Tensorflow.Train; +using System.Collections.Generic; +using System.IO; + +namespace Tensorflow.Checkpoint; + +public class TrackableView +{ + protected WeakReference _root_ref; + public TrackableView(Trackable obj) + { + _root_ref = new WeakReference(obj); + } + + public TrackableView(WeakReference obj) + { + _root_ref = obj; + } + + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + obj._maybe_initialize_trackable(); + // Note: in python the return type of `Trackable._trackable_children` is not fixed. + // Therefore it uses `convert_to_trackable` to have an extra process. 
+ return obj._trackable_children(save_type); + } + + public Trackable Root + { + get + { + if (_root_ref.TryGetTarget(out Trackable res)) + { + return res; + } + else + { + throw new InvalidDataException( + "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor."); + } + } + } + + /// + /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. + /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths + /// + protected (List, Dictionary>) _descendants_with_paths() + { + List bfs_sorted = new(); + Queue to_visit = new(); + Dictionary> node_paths = new(); + node_paths[this.Root] = new List(); + while (!to_visit.empty()) + { + var current_trackable = to_visit.Dequeue(); + bfs_sorted.Add(current_trackable); + var children_dict = this.children(current_trackable); + foreach (var name in children_dict.Keys) + { + var dependency = children_dict[name]; + if (!node_paths.ContainsKey(dependency)) + { + var list = new List(node_paths[current_trackable]); + list.Add(new TrackableReference(name, dependency)); + node_paths[dependency] = list; + to_visit.Enqueue(dependency); + } + } + } + + return (bfs_sorted, node_paths); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs new file mode 100644 index 00000000..84ec24cb --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -0,0 +1,14 @@ +namespace Tensorflow.Exceptions; + +public class AssertionError : TensorflowException +{ + public AssertionError() : base() + { + + } + + public AssertionError(string message) : base(message) + { + + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 6ce3bf3c..cce13b55 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ 
b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -304,7 +304,7 @@ namespace Tensorflow } } - private static OpList stripped_op_list_for_graph(GraphDef graph_def) + public static OpList stripped_op_list_for_graph(GraphDef graph_def) { var used_ops = ops_used_by_graph_def(graph_def); @@ -345,5 +345,66 @@ namespace Tensorflow return used_ops.ToArray(); } + + private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) + { + foreach (var attr_def in op_def.Attr) + { + if (attr_def.Name == attr_name) + { + if (attr_def.DefaultValue is null) return false; + // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. + return true; + } + } + + return false; + } + + public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) + { + Dictionary op_name_to_function = new(); + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + op_name_to_function[function_def.Signature.Name] = function_def; + } + + Action _strip_node_default_valued_attrs = (node_def) => + { + if (op_name_to_function.ContainsKey(node_def.Op)) return; + + var op_def = op_def_registry.GetOpDef(node_def.Op); + if(op_def is null) return; + + HashSet attrs_to_strip = new(); + foreach (var attr in node_def.Attr) + { + if (is_default_attr_value(op_def, attr.Key, attr.Value)) + { + attrs_to_strip.Add(attr.Key); + } + } + + foreach (var attr in attrs_to_strip) + { + node_def.Attr.Remove(attr); + } + }; + + foreach (var node_def in meta_graph_def.GraphDef.Node) + { + _strip_node_default_valued_attrs(node_def); + } + + foreach (var function_def in meta_graph_def.GraphDef.Library.Function) + { + foreach (var function_node_def in function_def.NodeDef) + { + _strip_node_default_valued_attrs(function_node_def); + } + } + + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + } } } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index c52d0b5f..bac9cedb 
100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Graphs; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -10,7 +11,7 @@ namespace Tensorflow.Functions /// /// /// - public class ConcreteFunction + public class ConcreteFunction: Trackable { FuncGraph func_graph; ForwardBackwardCall forward_backward; diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index d57097ae..056d15f4 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -1,16 +1,23 @@ using System; +using Tensorflow.Train; namespace Tensorflow { - public class Function + public class Function: Trackable { #pragma warning disable CS0169 // The field 'Function._handle' is never used private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - + + public string Name { get; set; } public Function() { } + + public Function(string name) + { + Name = name; + } } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index e25537d8..fce42850 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -9,7 +9,13 @@ namespace Tensorflow.ModelSaving /// public class SaveOptions { - bool save_debug_info; + public bool save_debug_info = false; + public IList? namespace_white_list { get; set; } = null; + public IDictionary? function_aliases { get; set; } = null; + public string? experimental_io_device { get; set; } = null; + // TODO: experimental + public Object? 
experimental_variable_polict { get; set; } = null; + public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index ee751acf..d5a32c10 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.Train; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -38,6 +39,11 @@ namespace Tensorflow { return var is ResourceVariable; } + + public static bool is_resource_variable(Trackable var) + { + return var is BaseResourceVariable; + } /// /// Creates a variable handle with information to do shape inference. diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index 9d3e854a..f2597574 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -156,7 +156,7 @@ namespace Tensorflow { /// Nodes[0] is considered the root node. 
/// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public pbc::RepeatedField Nodes { + public pbc::RepeatedField Nodes { get { return nodes_; } } @@ -286,6 +286,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public SavedObject(SavedObject other) : this() { children_ = other.children_.Clone(); + dependencies_ = other.dependencies_.Clone(); slotVariables_ = other.slotVariables_.Clone(); saveableObjects_ = other.saveableObjects_.Clone(); switch (other.KindCase) { @@ -328,6 +329,7 @@ namespace Tensorflow { private static readonly pb::FieldCodec _repeated_children_codec = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); /// /// Objects which this object depends on: named edges in the dependency /// graph. @@ -338,6 +340,11 @@ namespace Tensorflow { public pbc::RepeatedField Children { get { return children_; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Dependencies { + get { return dependencies_; } + } /// Field number for the "slot_variables" field. 
public const int SlotVariablesFieldNumber = 3; @@ -617,6 +624,7 @@ namespace Tensorflow { return; } children_.Add(other.children_); + dependencies_.Add(other.dependencies_); slotVariables_.Add(other.slotVariables_); saveableObjects_.Add(other.saveableObjects_); switch (other.KindCase) { diff --git a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs index 3aa747c2..93413667 100644 --- a/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs @@ -198,6 +198,12 @@ namespace Tensorflow { public TrackableObject() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public TrackableObject(pbc::RepeatedField slot) { + OnConstruction(); + slotVariables_ = slot; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index d2198e37..d8f6314b 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -2,5 +2,20 @@ { public abstract class AutoTrackable : Trackable { + public void _delete_tracking(string name) + { + _maybe_initialize_trackable(); + if (_unconditional_dependency_names.ContainsKey(name)) + { + _unconditional_dependency_names.Remove(name); + for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) + { + if (_unconditional_checkpoint_dependencies[i].Name == name) + { + _unconditional_checkpoint_dependencies.RemoveAt(i); + } + } + } + } } } diff --git a/src/TensorFlowNET.Core/Training/Optimizer.cs b/src/TensorFlowNET.Core/Training/Optimizer.cs index f985c656..e656fe96 100644 --- a/src/TensorFlowNET.Core/Training/Optimizer.cs +++ b/src/TensorFlowNET.Core/Training/Optimizer.cs @@ -351,7 +351,7 @@ namespace Tensorflow /// /// /// - protected IVariableV1 get_slot(IVariableV1 var, string name) + internal IVariableV1 get_slot(IVariableV1 var, 
string name) { var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; if (named_slots == null) @@ -360,6 +360,11 @@ namespace Tensorflow return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; } + internal IEnumerable get_slot_names() + { + return _slots.Keys; + } + private string _var_key(IVariableV1 var) { return $"{var.Op.graph.graph_key}.{var.Op.name}"; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index c86075f8..6239030b 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -48,4 +48,18 @@ namespace Tensorflow validate_shape: restored_shapes == null && op.shape.IsFullyDefined); } } + + public class NoRestoreSaveable: MySaveableObject + { + public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor, + new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) + { + + } + + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) + { + return control_flow_ops.no_op(); + } + } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs new file mode 100644 index 00000000..24c8f2f0 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -0,0 +1,11 @@ +using System.Collections.Generic; + +namespace Tensorflow; + +public record class AssetInfo +( + List asset_defs, + Dictionary asset_initializers_by_resource, + Dictionary asset_filename_map, + Dictionary asset_index +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs new file mode 100644 index 00000000..6723206c --- /dev/null +++ 
b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -0,0 +1,60 @@ +using System; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Functions; + +namespace Tensorflow; + +public class AugmentedGraphView: ObjectGraphView +{ + // private object _children_cache; + // private object _serialization_cache; + private List _untraces_functions; + public AugmentedGraphView(Trackable root): base(root) + { + _untraces_functions = new(); + } + + public void set_signature(object signature_map, object wrapped_functions) + { + // TODO: cache + list_children(Root); + } + + public override List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + { + Dictionary children = new(); + foreach (var pair in base.list_children(obj, save_type)) + { + var name = pair.Name; + var child = pair.Refer; + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + + return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + } + + public override (List, Dictionary>) breadth_first_traversal() + { + // TODO: implement it if needed. + return base.breadth_first_traversal(); + } + + public List<(string, Trackable)> list_dependencies(Trackable obj) + { + // TODO: deal with cache. 
+ return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); + } + + public Trackable get_child(Trackable obj, string name) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs new file mode 100644 index 00000000..cb7abada --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs @@ -0,0 +1,33 @@ +namespace Tensorflow; + +public static class Constants +{ + public static readonly string ASSETS_DIRECTORY = "assets"; + public static readonly string ASSETS_KEY = "saved_model_assets"; + + public static readonly string DEBUG_DIRECTORY = "debug"; + + public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"; + + public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra"; + + public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb"; + + public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; + + public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op"; + + public static readonly string MAIN_OP_KEY = "saved_model_main_op"; + + public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb"; + public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"; + + public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; + + public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; + + public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"; + + public static readonly string VARIABLES_DIRECTORY = "variables"; + public static readonly string VARIABLES_FILENAME = "variables"; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs new file mode 100644 index 00000000..fa9d6e50 --- /dev/null +++ 
b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -0,0 +1,17 @@ +using Tensorflow.Train; + +namespace Tensorflow; + +public class RevivedTypes +{ + /// + /// Create a SavedUserObject from a trackable object. + /// + /// + /// + public static SavedUserObject? serialize(Trackable obj) + { + // TODO: complete the implementation. + return null; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs new file mode 100644 index 00000000..b973fd41 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -0,0 +1,9 @@ +using System; + +namespace Tensorflow; + +public enum SaveType +{ + SAVEDMODEL, + CHECKPOINT +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs new file mode 100644 index 00000000..6a241f0e --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -0,0 +1,299 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class SaveableView +{ + private AugmentedGraphView _augmented_graph_view; + private SaveOptions _options; + private List _trackable_objects; + private List _nodes; + private Dictionary> _node_paths; + private Dictionary _node_ids; + private IDictionary> + _slot_variables; + private Dictionary _object_names; + private List _gradient_functions; // to be completed + private List _gradient_defs; // to be completed + private List _concrete_functions; + private Dictionary _captured_tensor_node_ids; + private 
Dictionary> _saveable_objects_map; + private Dictionary _obj_to_registered_saver; + + public AugmentedGraphView AugmentedGraphView + { + get => _augmented_graph_view; + } + + public Trackable Root + { + get => _nodes[0]; + } + public List Nodes + { + get => _nodes; + } + public Dictionary NodeIds + { + get => _node_ids; + } + public List GradientDefs + { + get => _gradient_defs; + } + public Dictionary> NodePaths + { + get => _node_paths; + } + public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) + { + _augmented_graph_view = augmented_graph_view; + _options = options; + + (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = + CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); + + // TODO: deal with untraced functions. + + initialize_save_and_restore_functions(); + initialize_nodes_and_concrete_functions(); + + _captured_tensor_node_ids = new(); + } + + private void initialize_save_and_restore_functions() + { + // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. + SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. + _obj_to_registered_saver = new(); + _saveable_objects_map = new(); + } + + private void initialize_nodes_and_concrete_functions() + { + _nodes = _trackable_objects.ConvertAll(x => x); // deep copy + _gradient_functions = new(); + _gradient_defs = new(); + + // TODO: deal with the condition that obj in `_saveable_objects_map`. + // foreach (var obj in _nodes) + // { + // + // } + + foreach (var obj in _nodes) + { + if (obj is ConcreteFunction) + { + _concrete_functions.Add((ConcreteFunction)obj); + } + } + } + + public List get_concrete_resource_initializers() + { + // TODO: complete the implementation. 
+ return new List(); + } + + public (Dictionary, Dictionary, AssetInfo) map_resources() + { + Debug.Assert(!tf.Context.executing_eagerly()); + + Dictionary object_map = new(); + Dictionary tensor_map = new(); + + AssetInfo assetInfo = new(new List(), new Dictionary(), + new Dictionary(), new Dictionary()); + + foreach (var node_id in dependency_sorted_node_ids()) + { + var obj = _nodes[node_id]; + var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options); + // TODO: deal with Asset (if obj is Asset) + foreach (var tensor in tensors) + { + _captured_tensor_node_ids[tensor] = node_id; + } + } + + return (object_map, tensor_map, assetInfo); + } + + /// + /// Returns topologically sorted nodes, sorted by dependencies. + /// + public List dependency_sorted_node_ids() + { + Dictionary> dependency_map = new(); + foreach (var node in _nodes) + { + var node_id = _node_ids[node]; + List deps = new(); + + // TODO: deal with captured tensor. + + string node_path; + foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) + { + if (!_node_ids.ContainsKey(dep)) + { + node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + throw new ValueError( + $"Found an untracked dependency. Object {node_path} depends on {dep}, " + + $"but this dependency isn't listed as a child. 
Please track this child by " + + $"overriding `_trackable_children` or use `._track_trackable`."); + } + deps.Add(_node_ids[dep]); + } + } + + try + { + return TrackableUtils.order_by_dependency(dependency_map); + } + catch (TrackableUtils.CyclicDependencyError err) + { + List pretty_printed_nodes = new(); + List pretty_printed_dependencies = new(); + + foreach (var pair in err.LeftOverDependencyMap) + { + var x = pair.Key; + var deps = pair.Value; + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); + pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); + pretty_printed_dependencies.Add( + $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); + } + + throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + + $"Saving cannot continue until this cycle is resolved." + + $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + + $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); + } + } + + /// + /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph + /// + /// + /// + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index, SaveOptions options) + { + SavedObjectGraph proto = new(); + fill_object_graph_proto(proto); + + // TODO: complete the process of concrete functions. 
+ + int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); + for (int i = 0; i < cnt; i++) + { + var obj = _nodes[i]; + var obj_proto = proto.Nodes[i]; + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), + options); + } + + return proto; + } + + private static void write_object_proto(Trackable obj, SavedObject proto, + IDictionary asset_file_def_index, Func> list_children_fn, SaveOptions options) + { + // skip the process of type Asset + if (resource_variable_ops.is_resource_variable(obj)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is Function) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if (obj is ConcreteFunction) + { + // TODO: complete it. + throw new NotImplementedException(); + } + // skip the process of type `_CapturedTensor` and `CapturableResource`. + else + { + var registered_type_proto = RevivedTypes.serialize(obj); + if (registered_type_proto is null) + { + registered_type_proto = new SavedUserObject() + { + Identifier = obj.ObjectIdentifier, + Version = new VersionDef() + { + Producer = 1, + MinConsumer = 1, + BadConsumers = { } + } + }; + } + + proto.UserObject = new SavedUserObject(registered_type_proto); + } + + // TODO: try get the registered_name from `registration`. 
+ } + + public void fill_object_graph_proto(SavedObjectGraph proto) + { + for (int node_id = 0; node_id < _nodes.Count; node_id++) + { + var node = _nodes[node_id]; + Debug.Assert(_node_ids[node] == node_id); + SavedObject object_proto = new(); + if (_slot_variables.TryGetValue(node, out var value)) + { + object_proto.SlotVariables.AddRange(value); + } + // skip the check of type `_CapturedTensor` + foreach (var child in _augmented_graph_view.list_children(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[child.Refer]; + child_proto.LocalName = child.Name; + object_proto.Children.Add(child_proto); + } + + foreach (var pair in _augmented_graph_view.list_dependencies(node)) + { + var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); + child_proto.NodeId = _node_ids[pair.Item2]; + child_proto.LocalName = pair.Item1; + object_proto.Dependencies.Add(child_proto); + } + + if (_saveable_objects_map.ContainsKey(node)) + { + // TODO: complete it. + throw new NotImplementedException(); + } + else if(_obj_to_registered_saver.ContainsKey(node)) + { + // TODO: complete it. + // We now skip it for the lack of `SavedObject.registered_saver` API. 
+ throw new NotImplementedException(); + } + + proto.Nodes.Add(object_proto); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs new file mode 100644 index 00000000..9a066eed --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -0,0 +1,10 @@ +namespace Tensorflow; + +public static class TagConstants +{ + public static readonly string SERVING = "serve"; + public static readonly string TRAINING = "train"; + public static readonly string EVAL = "eval"; + public static readonly string GPU = "gpu"; + public static readonly string TPU = "tpu"; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs new file mode 100644 index 00000000..bcd3ae05 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public class BuilderUtils +{ + public static void copy_assets_to_destination_dir(IDictionary asset_filename_map, + string destination_dir, HashSet? saved_files = null) + { + if (saved_files is null) saved_files = new HashSet(); + + var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); + + // TODO: complete the implementation of this function. 
+ if (asset_filename_map is not null && asset_filename_map.Count > 0) + { + throw new NotImplementedException(); + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs new file mode 100644 index 00000000..69235605 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -0,0 +1,256 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using Google.Protobuf; +using Tensorflow.Checkpoint; +using Tensorflow.Functions; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + private static readonly IEnumerable byte_swappable = new List() + { + dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, + dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, + dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, + TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 + }.Select(x => (int)x); + + public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, + string export_dir, IDictionary? signatures, SaveOptions? 
options = null, bool experimental_skip_checkpoint = false) + { + if (options is null) + { + options = new SaveOptions(); + } + + var saved_model = new Tensorflow.SavedModel(); + var meta_graph_def = new MetaGraphDef(); + saved_model.MetaGraphs.Add(meta_graph_def); + + var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = + _build_meta_graph(obj, signatures, options, meta_graph_def); + saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; + + if (!experimental_skip_checkpoint) + { + Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); + CheckpointOptions ckpt_options = new(options.experimental_io_device); + object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + } + BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); + + if (tf.Context.executing_eagerly()) + { + // tensorflow python has a check of `context.async_wait()` here. + } + + // TODO: deal with `pywrap_saved_model.Save(export_dir)`. + + var saved_model_serialized = saved_model.ToString(); + + // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. + if (true) + { + var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), + tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); + // TODO: add c api and complete the fingerprint def. + var fingerprint_proto = ""; + File.WriteAllText(fingerprint_path, fingerprint_proto); + } + + var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); + File.WriteAllText(path, saved_model.ToString()); + + if (options.save_debug_info) + { + throw new NotImplementedException(); + } + + ops.dismantle_graph(exported_graph); + + return (saved_nodes, node_paths); + } + + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, + Dictionary>) _build_meta_graph(Trackable obj, + IDictionary? signatures, SaveOptions options, MetaGraphDef? 
meta_graph_def = null) + { + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + + "Move the call to the outer eagerly-executed context."); + } + + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } + + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is not null) + { + throw new NotImplementedException(); + } + + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) + { + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } + } + + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } + + private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, + IDictionary signatures, IEnumerable namespace_whitelist, + bool save_custom_gradients) + { + var resource_initializers = saveable_view.get_concrete_resource_initializers(); + var exported_graph = new Graph(); + + Dictionary object_map; + Dictionary tensor_map; + AssetInfo asset_info; + using (var g = exported_graph.as_default()) + { + (object_map, tensor_map, asset_info) = saveable_view.map_resources(); + // TODO: deal with signatures. 
+ if (save_custom_gradients) + { + // TODO: trace gradient functions. + } + + foreach (var resource_initializer_function in resource_initializers) + { + // List asset_dependencies = new(); + // TODO: deal with initializers + } + + // using(ops.control_dependencies(...)) + var init_op = control_flow_ops.no_op(); + if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); + } + else + { + meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); + } + // Lack `CopyFrom` API + // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + } + + foreach (var obj in object_map.Values) + { + obj._maybe_initialize_trackable(); + } + + var (named_saveable_objects, registered_savers) = + SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); + + // TODO: complete the save of checkpoints with `MultiDeviceSaver`. + + saveable_view.dependency_sorted_node_ids(); + + var graph_def = exported_graph.as_graph_def(true); + graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); + verify_ops(graph_def, namespace_whitelist); + + meta_graph_def.GraphDef = new GraphDef(graph_def); + meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); + meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; + // TODO: add git version. + meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; + meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; + meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); + meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); + + // TODO: deal with signatures here. 
+ + meta_graph.strip_graph_default_valued_attrs(meta_graph_def); + + if (!BitConverter.IsLittleEndian) + { + swap_function_tensor_content(meta_graph_def); + } + + return (asset_info, exported_graph); + } + + private static void verify_ops(GraphDef graph_def, IEnumerable? namespace_whitelist) + { + return; + // if (namespace_whitelist is null || !namespace_whitelist.Any()) + // { + // return; + // } + + // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. + } + + public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) + { + var functions = meta_graph_def.GraphDef.Library.Function; + foreach (var function in functions) + { + var node_def = function.NodeDef; + foreach (var node in node_def) + { + if (node.Op == "Const") + { + var tensor = node.Attr["value"].Tensor; + byte_swap_tensor_content(tensor); + } + } + } + } + + public static void byte_swap_tensor_content(TensorProto tensor) + { + if (byte_swappable.Contains((int)tensor.Dtype)) + { + var tshape = tensor.TensorShape.Dim; + var tensor_bytes = tensor.TensorContent; + if (tensor_bytes is not null && !tensor_bytes.IsEmpty) + { + long tensor_size = 1; + foreach (var sz in tshape) + { + tensor_size *= sz.Size; + } + + var chunksize = tensor_bytes.Length / tensor_size; + List reversed_bytes = new(); + for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) + { + var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); + reversed_bytes.AddRange(current); + } + tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs new file mode 100644 index 00000000..21272941 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Linq; 
+using Tensorflow.Functions; +using Tensorflow.Train; + +namespace Tensorflow; + +public class SignatureMap: Trackable +{ + private Dictionary _signatures; + private Dictionary _concrete_signatures; + + public SignatureMap() + { + _signatures = new(); + } + + public void _add_signature(string name, ConcreteFunction concrete_function) + { + _concrete_signatures[name] = concrete_function; + } + + public void _add_signature(string name, Function concrete_function) + { + _signatures[name] = concrete_function; + } + + public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + { + if (save_type != SaveType.SAVEDMODEL) + { + return new Dictionary(); + } + + Dictionary res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); + foreach (var pair in _concrete_signatures) + { + res[pair.Key] = pair.Value; + } + + return res; + } + + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + // TODO: assert the arg_keywords + signature_map._add_signature(name, func); + } + + return signature_map; + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs new file mode 100644 index 00000000..723419f6 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -0,0 +1,52 @@ +using System.IO; +using System.Security.Cryptography.X509Certificates; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static partial class SavedModelUtils +{ + /// + /// Return variables sub-directory, or create one if it doesn't exist. 
+ /// + /// + public static string get_or_create_variables_dir(string export_dir) + { + var variables_dir = get_variables_dir(export_dir); + Directory.CreateDirectory(variables_dir); + return variables_dir; + } + + /// + /// Return variables sub-directory in the SavedModel. + /// + /// + /// + public static string get_variables_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); + } + + /// + /// Return assets sub-directory, or create one if it doesn't exist. + /// + /// + /// + public static string get_or_create_assets_dir(string export_dir) + { + var assets_destination_dir = get_assets_dir(export_dir); + Directory.CreateDirectory(assets_destination_dir); + return assets_destination_dir; + } + + /// + /// Return path to asset directory in the SavedModel. + /// + /// + /// + public static string get_assets_dir(string export_dir) + { + return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 3a664788..98cdb274 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -17,12 +17,17 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow { - public class saveable_object_util + public static class saveable_object_util { + public class TrackableSaveable: MySaveableObject + { + + } /// /// Returns the variables and names that will be used for a Saver. /// @@ -121,5 +126,17 @@ namespace Tensorflow return names_to_saveables; } + + public static IDictionary saveable_objects_from_trackable(Trackable obj) + { + // TODO: complete the implementation. 
+ return obj.gather_saveables_for_checkpoint(); + } + + public static bool trackable_has_serialize_to_tensor(Trackable obj) + { + // TODO: implement it. + return false; + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 79d6dca9..dce0be2a 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -14,14 +14,38 @@ limitations under the License. ******************************************************************************/ +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.ModelSaving; using static Tensorflow.Binding; namespace Tensorflow.Train { public abstract class Trackable { + /// + /// Corresponding to tensorflow/python/trackable/constants.py + /// + public static class Constants + { + public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; + public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; + public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; + } protected int _self_update_uid; + protected IDictionary _unconditional_dependency_names = + new Dictionary(); + + protected IList _unconditional_checkpoint_dependencies = new List(); + protected IDictionary _self_saveable_object_factories = + new Dictionary(); + public virtual string ObjectIdentifier + { + get => "_generic_user_object"; + } + /// /// Restore-on-create for a variable be saved with this `Checkpointable`. /// @@ -73,10 +97,63 @@ namespace Tensorflow.Train /// /// Initialize dependency management. /// - protected void _maybe_initialize_trackable() + public void _maybe_initialize_trackable() { // _self_unconditional_checkpoint_dependencies = [] _self_update_uid = -1; } + + // TODO: cache + public virtual IDictionary _trackable_children(SaveType save_type, IDictionary? 
cache = null)
+        {
+            _maybe_initialize_trackable();
+            return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
+        }
+
+        public static Trackable convert_to_trackable(object obj, object? parent = null)
+        {
+            if (obj is Trackable)
+            {
+                return (Trackable)obj;
+            }
+            else
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        public virtual IDictionary deserialization_dependencies(IDictionary children)
+        {
+            return new Dictionary();
+        }
+
+        public virtual (IDictionary, IDictionary) map_resources(
+            SaveOptions? save_options)
+        {
+            return (new Dictionary(), new Dictionary());
+        }
+        /// <summary>Copies the results of map_resources into the supplied maps (when given) and returns the mapped tensor keys.</summary>
+        public virtual List export_to_saved_model_graph(IDictionary? object_map = null,
+            IDictionary? tensor_map = null, SaveOptions? options = null)
+        {
+            var (self_object_map, self_tensor_map) = map_resources(options);
+            foreach (var pair in self_object_map)
+            {
+                object_map?.Add(pair); // object_map is optional (defaults to null): skip instead of throwing
+            }
+            foreach (var pair in self_tensor_map)
+            {
+                tensor_map?.Add(pair); // tensor_map is optional (defaults to null): skip instead of throwing
+            }
+
+            return self_tensor_map.Keys.ToList();
+        }
+
+        public virtual IDictionary gather_saveables_for_checkpoint()
+        {
+            return _self_saveable_object_factories;
+        }
     }
+
+    public record class TrackableReference(string Name, Trackable Refer);
 }
diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs
new file mode 100644
index 00000000..99020702
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs
@@ -0,0 +1,148 @@
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow.Exceptions;
+using Tensorflow.Train;
+
+namespace Tensorflow.Training;
+
+public static class TrackableUtils
+{
+    public class CyclicDependencyError: System.Exception
+    {
+        public IDictionary> LeftOverDependencyMap { get; }
+        public CyclicDependencyError(IDictionary> leftover_dependency_map): base()
+        {
+            LeftOverDependencyMap = leftover_dependency_map;
+        }
+        public CyclicDependencyError(IDictionary> leftover_dependency_map): base()
+        {
+            LeftOverDependencyMap =
leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); + } + } + private static string _ESCAPE_CHAR = "."; + private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; + private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; + private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + public static string object_path_to_string(IEnumerable node_path_arr) + { + return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); + } + + public static string escape_local_name(string name) + { + return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S"); + } + + public static string checkpoint_key(string object_path, string local_name) + { + var key_suffix = escape_local_name(local_name); + if (local_name == SERIALIZE_TO_TENSORS_NAME) + { + key_suffix = ""; + } + + return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; + } + + /// + /// Topologically sorts the keys of a map so that dependencies appear first. + /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + /// + /// + /// + public static List order_by_dependency(IDictionary> dependency_map) + { + Dictionary> reverse_dependency_map = new(); + foreach (var pair in dependency_map) + { + foreach (var dep in pair.Value) + { + if (reverse_dependency_map.ContainsKey(dep)) + { + reverse_dependency_map[dep].Add(pair.Key); + } + else + { + reverse_dependency_map[dep] = new HashSet(); + reverse_dependency_map[dep].Add(pair.Key); + } + } + } + + // Validate that all values in the dependency map are also keys. 
+ var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); + if (unknown_keys.Count() > 0) + { + throw new ValueError( + $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); + } + + // Generate the list sorted by objects without dependencies -> dependencies. + // The returned list will reverse this. + List reversed_dependency_arr = new(); + + Queue to_visit = new(); + foreach (var x in dependency_map.Keys) + { + if (!reverse_dependency_map.ContainsKey(x)) + { + to_visit.Enqueue(x); + } + } + + while (to_visit.Count > 0) + { + var x = to_visit.Dequeue(); + reversed_dependency_arr.Add(x); + foreach (var dep in dependency_map[x].Distinct()) + { + var edges = reverse_dependency_map[dep]; + edges.Remove(x); + if (edges.Count == 0) + { + to_visit.Enqueue(dep); + if (!reverse_dependency_map.Remove(dep)) + { + throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); + } + } + } + } + + if (reverse_dependency_map.Count > 0) + { + Dictionary> leftover_dependency_map = new(); + foreach (var pair in reverse_dependency_map) + { + foreach (var x in pair.Value) + { + if (leftover_dependency_map.ContainsKey(x)) + { + leftover_dependency_map[x].Add(pair.Key); + } + else + { + leftover_dependency_map[x] = new List() { pair.Key }; + } + } + } + + throw new CyclicDependencyError(leftover_dependency_map); + } + + reversed_dependency_arr.Reverse(); + return reversed_dependency_arr; + } + + public static string pretty_print_node_path(IEnumerable paths) + { + if (paths.Count() == 0) + { + return "root object"; + } + else + { + return $"root.{string.Join(".", paths.Select(x => x.Name))}"; + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index b270ec57..0a050d0f 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ 
b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -2,6 +2,7 @@ using System; using Tensorflow.Eager; using Tensorflow.Variables; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 95e8db57..bf5ae7be 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -566,5 +566,23 @@ namespace Tensorflow else throw new NotImplementedException(""); } + + public static bool inside_function() + { + return get_default_graph().building_function; + } + + public static void dismantle_graph(Graph graph) + { + + } + + public class NullContextManager: IDisposable + { + public void Dispose() + { + + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs new file mode 100644 index 00000000..1675fba1 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -0,0 +1,31 @@ +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Engine; + +public abstract partial class Layer +{ + public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + + public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + + public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + { + IDictionary children; + if (save_type == SaveType.SAVEDMODEL) + { + // TODO: deal with cache. 
+ children = TrackableSavedModelSaver.trackable_children(cache); + } + else + { + children = new Dictionary(); + } + + return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index ba40b1a2..e95e55d6 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine public bool Built => built; public bool Trainable => args.Trainable; public TF_DataType DType => args.DType; + public bool AutoCast => args.Autocast; + public IRegularizer ActivityRegularizer => args.ActivityRegularizer; /// /// A stateful layer is a layer whose updates are run during inference too, @@ -162,7 +164,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - /// + /// /// protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index c287309d..59f74cd2 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,5 +1,7 @@ using System.Collections.Generic; +using Tensorflow.Functions; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; namespace Tensorflow.Keras.Engine @@ -18,9 +20,18 @@ namespace Tensorflow.Keras.Engine bool overwrite = true, bool include_optimizer = true, string save_format = "tf", - SaveOptions options = null) + SaveOptions? options = null, + IDictionary? 
signatures = null, + bool save_traces = true) { - saver.save(this, filepath); + if (save_format != "pb") + { + saver.save(this, filepath); + } + else + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 162d06c5..835f6041 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -35,6 +35,12 @@ namespace Tensorflow.Keras.Engine bool _base_model_initialized; bool stop_training; DataHandler data_handler; + + public OptimizerV2 Optimizer + { + get => optimizer; + set => optimizer = value; + } public Model(ModelArgs args) : base(args) diff --git a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs index 61cec646..f29f2dec 100644 --- a/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs +++ b/src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs @@ -194,6 +194,18 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { OnConstruction(); } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(int nodeId, string nodePath, + global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) + { + OnConstruction(); + nodeId_ = nodeId; + nodePath_ = nodePath; + identifier_ = identifier; + metadata_ = metadata; + version_ = version; + } + partial void OnConstruction(); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Keras/Protobuf/Versions.cs b/src/TensorFlowNET.Keras/Protobuf/Versions.cs index 40405a5a..ff9a23c6 100644 --- a/src/TensorFlowNET.Keras/Protobuf/Versions.cs +++ b/src/TensorFlowNET.Keras/Protobuf/Versions.cs @@ -74,6 +74,13 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { public VersionDef() { OnConstruction(); } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public 
VersionDef(int producer, int minConsumer) { + OnConstruction(); + producer_ = producer; + minConsumer_ = minConsumer; + } partial void OnConstruction(); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs new file mode 100644 index 00000000..ea6853fd --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public static class Constants +{ + /// + /// Namespace used to store all attributes added during serialization. + /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an + /// object loaded from `tf.saved_model.load()`. + /// + public static readonly string KERAS_ATTR = "keras_api"; + /// + /// Keys for the serialization cache. + /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} + /// + public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; + /// + /// Name of Keras metadata file stored in the SavedModel. 
+ /// + public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; + + public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; + public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; + public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; + public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; + public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; + public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; + public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; + + public static readonly IList KERAS_OBJECT_IDENTIFIERS = new List() + { + INPUT_LAYER_IDENTIFIER, + LAYER_IDENTIFIER, + METRIC_IDENTIFIER, + MODEL_IDENTIFIER, + NETWORK_IDENTIFIER, + RNN_LAYER_IDENTIFIER, + SEQUENTIAL_IDENTIFIER + }; +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs new file mode 100644 index 00000000..a5f315bb --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs @@ -0,0 +1,11 @@ +namespace Tensorflow.Keras.Saving.SavedModel; + +public class KerasObjectWrapper +{ + +} + +public class KerasObjectWrapper +{ + public T Item { get; set; } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs new file mode 100644 index 00000000..76453ca0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Google.Protobuf; +using ICSharpCode.SharpZipLib.Zip; +using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Functions; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.ModelSaving; +using Tensorflow.Train; +using Tensorflow.Exceptions; +using 
Tensorflow.IO;
+using Tensorflow.Keras.Optimizers;
+using ThirdParty.Tensorflow.Python.Keras.Protobuf;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public partial class KerasSavedModelUtils
+{
+    public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary? signatures,
+        SaveOptions? options, bool save_traces = true)
+    {
+        if (!overwrite && (File.Exists(filepath) || Directory.Exists(filepath))) // the export target is a directory, so check both
+        {
+            throw new Exception("The file already exists but is not allowed to overwrite it.");
+        }
+
+        if (save_traces)
+        {
+            if(should_skip_serialization(model))
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        OptimizerV2? orig_optimizer = null;
+        if (!include_optimizer)
+        {
+            orig_optimizer = model.Optimizer;
+            model.Optimizer = null;
+            model._delete_tracking("optimizer");
+        }
+
+        IList saved_nodes;
+        IDictionary> node_paths;
+        // skip two scopes of python
+        using (KerasSavedModelUtils.keras_option_scope(save_traces))
+        {
+            (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options);
+        }
+
+        var metadata = generate_keras_metadata(saved_nodes, node_paths);
+        using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.Create,
+                   FileAccess.Write))
+        {
+            // Write the binary protobuf encoding directly; the previous StreamWriter was never flushed and wrote text.
+            metadata.WriteTo(f);
+        }
+
+        if (!include_optimizer)
+        {
+            model.Optimizer = orig_optimizer!;
+        }
+    }
+    /// <summary>Builds the SavedMetadata message with one SavedObject entry per saved node that is a Layer.</summary>
+    public static SavedMetadata generate_keras_metadata(IList saved_nodes,
+        IDictionary> node_paths)
+    {
+        var metadata = new SavedMetadata();
+        for (int i = 0; i < saved_nodes.Count; i++)
+        {
+            var node = saved_nodes[i];
+            if (node is not Layer)
+            {
+                continue;
+            }
+
+            Layer layer = (Layer)node;
+
+            var path = node_paths[node];
+            string node_path;
+            if (path is null)
+            {
+                node_path = "root";
+            }
+            else
+            {
+                node_path = $"root.{string.Join(".", path.Select(x => x.Name))}";
+            }
+
+            ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new()
+            {
+                NodeId = i,
+                NodePath = node_path,
+                Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef()
+                {
+                    Producer = 2,
+                    MinConsumer = 1,
+                    BadConsumers = { }
+                },
+                Identifier = layer.ObjectIdentifier,
+                Metadata = layer.TrackingMetadata
+            };
+            metadata.Nodes.Add(saved_object); // previously the entry was built and then dropped, leaving the metadata empty
+        }
+
+        return metadata;
+    }
+
+
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
new file mode 100644
index 00000000..ba0bcc66
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
@@ -0,0 +1,19 @@
+using System.Collections.Generic;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public partial class KerasSavedModelUtils
+{
+    public static bool should_skip_serialization(object layer)
+    {
+        return false;
+    }
+
+    public static IDictionary wrap_layer_objects(Layer layer, object serialization_cache)
+    {
+        // TODO: process the loss
+
+        return null;
+    }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
new file mode 100644
index 00000000..36111a18
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
@@ -0,0 +1,40 @@
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow.Keras.Engine;
+using Newtonsoft.Json;
+using Tensorflow.Train;
+
+namespace Tensorflow.Keras.Saving.SavedModel;
+
+public abstract class SavedModelSaver
+{
+    private Trackable _obj;
+    public SavedModelSaver(Trackable obj)
+    {
+        _obj = obj;
+    }
+
+    public abstract string ObjectIdentifier { get; }
+    public abstract string TrackingMetadata { get; }
+
+    public abstract IDictionary objects_to_serialize(
+        IDictionary serialization_cache);
+
+    public abstract IDictionary functions_to_serialize(
+        IDictionary serialization_cache);
+
+    public
IDictionary trackable_children(IDictionary? serialization_cache) + { + if (!KerasSavedModelUtils.ShouldHaveTraces) + { + return new Dictionary(); + } + + var children = objects_to_serialize(serialization_cache); + + return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) + .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + .ToDictionary(x => x.Key, x => x.Value); + } + +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs new file mode 100644 index 00000000..ade8ae73 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -0,0 +1,62 @@ +using System.Collections.Generic; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public class LayerSavedModelSaver: SavedModelSaver +{ + private Layer _obj; + public LayerSavedModelSaver(Layer obj): base(obj) + { + _obj = obj; + } + public override string ObjectIdentifier + { + get => Constants.LAYER_IDENTIFIER; + } + + public override IDictionary objects_to_serialize(IDictionary serialization_cache) + { + throw new System.NotImplementedException(); + } + + public override IDictionary functions_to_serialize(IDictionary serialization_cache) + { + throw new System.NotImplementedException(); + } + + public override string TrackingMetadata + { + get + { + JObject metadata = new JObject(); + metadata["name"] = _obj.Name; + metadata["trainable"] = _obj.Trainable; + // metadata["expects_training_arg"] = _obj._expects_training_arg; + // metadata["dtype"] = policy.serialize(_obj._dtype_policy) + metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); + // metadata["stateful"] = _obj.stateful; + // metadata["must_restore_from_config"] = _obj.must_restore_from_config; + 
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; + metadata["autocast"] = _obj.AutoCast; + + metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings + { + // Handle conflicts by using values from obj2 + MergeArrayHandling = MergeArrayHandling.Merge + }); + // skip the check of `input_spec` and `build_input_shape` for the lack of members. + // skip the check of `activity_regularizer` for the type problem. + return metadata.ToString(); + } + } + + public static LayerConfig get_serialized(Layer obj) + { + return generic_utils.serialize_keras_object(obj); + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs new file mode 100644 index 00000000..30e89582 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -0,0 +1,33 @@ +using System; + +namespace Tensorflow.Keras.Saving.SavedModel; + +public partial class KerasSavedModelUtils +{ + public static bool ShouldHaveTraces { get; internal set; } + + public static SaveOptionsContext keras_option_scope(bool save_traces) + { + var res = new SaveOptionsContext(ShouldHaveTraces); + ShouldHaveTraces = save_traces; + return res; + } +} + +/// +/// Implementation of this class is different with that of python. +/// But it could be used with `using` the same as `with` of python. 
+///
+public class SaveOptionsContext: IDisposable
+{
+    public bool _old_value; // value of KerasSavedModelUtils.ShouldHaveTraces to put back on Dispose
+    public SaveOptionsContext(bool old_value)
+    {
+        _old_value = old_value; // was hard-coded to true, so Dispose never restored the caller's previous setting
+    }
+
+    public void Dispose()
+    {
+        KerasSavedModelUtils.ShouldHaveTraces = _old_value;
+    }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs
new file mode 100644
index 00000000..9d1b3088
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs
@@ -0,0 +1,60 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using Tensorflow.Keras;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
+using Tensorflow.Keras.Optimizers;
+
+namespace TensorFlowNET.Keras.UnitTest;
+
+// class MNISTLoader
+// {
+//     public MNISTLoader()
+//     {
+//         var mnist = new MnistModelLoader()
+//
+//     }
+// }
+
+[TestClass]
+public class SaveTest
+{
+    [TestMethod]
+    public void Test()
+    {
+        var inputs = new KerasInterface().Input((28, 28, 1));
+        var x = new Flatten(new FlattenArgs()).Apply(inputs);
+        x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
+        x = new LayersApi().Dense(units: 10).Apply(x);
+        var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
+        var model = new KerasInterface().Model(inputs, outputs);
+
+        model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"});
+
+        var data_loader = new MnistModelLoader();
+        var num_epochs = 1;
+        var batch_size = 50;
+
+        var dataset = data_loader.LoadAsync(new ModelLoadSetting
+        {
+            TrainDir = "mnist",
+            OneHot = false,
+            ValidationSize = 50000,
+        }).Result;
+
+        
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("", save_format:"pb"); + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 36ff4a3d..56c212d0 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -47,7 +47,7 @@ - +