| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| @@ -23,6 +25,26 @@ namespace Tensorflow | |||
| public class CompatApi | |||
| { | |||
| public CompatV1Api v1 { get; } = new CompatV1Api(); | |||
| internal string as_text(string bytes_or_text, Encoding? encoding = null) | |||
| { | |||
| if(encoding is null) encoding = Encoding.UTF8; | |||
| return bytes_or_text; | |||
| } | |||
| internal string as_text(byte[] bytes_or_text, Encoding? encoding = null) | |||
| { | |||
| if(encoding is null) encoding = Encoding.UTF8; | |||
| return encoding.GetString(bytes_or_text); | |||
| } | |||
| internal string as_str(string bytes_or_text, Encoding? encoding = null) | |||
| { | |||
| return as_text(bytes_or_text, encoding); | |||
| } | |||
| internal string as_str(byte[] bytes_or_text, Encoding? encoding = null) | |||
| { | |||
| return as_text(bytes_or_text, encoding); | |||
| } | |||
| } | |||
| public bool executing_eagerly() | |||
| @@ -0,0 +1,150 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| namespace Tensorflow.Checkpoint; | |||
| public static class CheckPointUtils | |||
| { | |||
| private static string _ESCAPE_CHAR = "."; | |||
| public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>, | |||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||
| Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | |||
| { | |||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| foreach (var pair in node_paths) | |||
| { | |||
| object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); | |||
| } | |||
| Dictionary<Trackable, int> node_ids = new(); | |||
| for (int i = 0; i < trackable_objects.Count; i++) | |||
| { | |||
| node_ids[trackable_objects[i]] = i; | |||
| } | |||
| var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names); | |||
| return (trackable_objects, node_paths, node_ids, slot_variables, object_names); | |||
| } | |||
| public static | |||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||
| serialize_slot_variables(IEnumerable<Trackable> trackable_objects, | |||
| IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names) | |||
| { | |||
| var non_slot_objects = trackable_objects.ToList(); | |||
| Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||
| slot_variables = new(); | |||
| foreach (var trackable in non_slot_objects) | |||
| { | |||
| if (trackable is not Optimizer) | |||
| { | |||
| continue; | |||
| } | |||
| var optim = (Optimizer)trackable; | |||
| var slot_names = optim.get_slot_names(); | |||
| foreach (var slot_name in slot_names) | |||
| { | |||
| for (int original_variable_node_id = 0; | |||
| original_variable_node_id < non_slot_objects.Count; | |||
| original_variable_node_id++) | |||
| { | |||
| var original_variable = non_slot_objects[original_variable_node_id]; | |||
| IVariableV1 slot_variable; | |||
| if (original_variable is not IVariableV1) | |||
| { | |||
| slot_variable = null; | |||
| } | |||
| slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name); | |||
| if(slot_variable is null) continue; | |||
| // There're some problems about the inherits of `Variable` and `Trackable`. | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| return slot_variables; | |||
| } | |||
| public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map) | |||
| { | |||
| if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res)) | |||
| { | |||
| return trackable; | |||
| } | |||
| else | |||
| { | |||
| return possible_res; | |||
| } | |||
| } | |||
| public static string get_full_name(Trackable var) | |||
| { | |||
| // TODO: This state is not correct, the whole framework need to be updated in the future. | |||
| if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) | |||
| { | |||
| return ""; | |||
| } | |||
| // skip the check of attribute `_save_slice_info` . | |||
| // TODO: Need to be revised!!! | |||
| return ((ResourceVariable)(object)var).Name; | |||
| } | |||
| public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) | |||
| { | |||
| HashSet<int> checkpointed_trackables = new(); | |||
| Dictionary<int, HashSet<int>> parents = new(); | |||
| for (int i = 0; i < object_graph_proto.Nodes.Count; i++) | |||
| { | |||
| var object_proto = object_graph_proto.Nodes[i]; | |||
| // skip the process of registered saver. | |||
| if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 || | |||
| object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0) | |||
| { | |||
| checkpointed_trackables.Add(i); | |||
| } | |||
| foreach (var child_proto in object_proto.Children) | |||
| { | |||
| var child = child_proto.NodeId; | |||
| if (!parents.ContainsKey(child)) | |||
| { | |||
| parents[child] = new HashSet<int>(); | |||
| } | |||
| parents[child].Add(i); | |||
| } | |||
| } | |||
| Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable()); | |||
| while (to_visit.Count > 0) | |||
| { | |||
| var trackable = to_visit.Dequeue(); | |||
| if (!parents.ContainsKey(trackable)) continue; | |||
| var current_parents = parents[trackable]; | |||
| foreach (var parent in current_parents) | |||
| { | |||
| checkpointed_trackables.Add(parent); | |||
| if (parents.ContainsKey(parent)) | |||
| { | |||
| to_visit.Enqueue(parent); | |||
| } | |||
| } | |||
| parents.Remove(trackable); | |||
| } | |||
| // TODO: Complete it after supporting checkpoint. | |||
| // for (int i = 0; i < object_graph_proto.Nodes.Count; i++) | |||
| // { | |||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | |||
| // } | |||
| } | |||
| } | |||
| @@ -0,0 +1,5 @@ | |||
| namespace Tensorflow.Checkpoint; | |||
| public record class CheckpointOptions( | |||
| string experimental_io_device = null, | |||
| bool experimental_enable_async_checkpoint = false); | |||
| @@ -0,0 +1,63 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Serilog.Debugging; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Checkpoint; | |||
| public class ObjectGraphView: TrackableView, ICloneable | |||
| { | |||
| protected IEnumerable<TrackableReference>? _attached_dependencies; | |||
| // TODO: attached_dependencies | |||
| public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root) | |||
| { | |||
| _attached_dependencies = attached_dependencies; | |||
| } | |||
| public object Clone() | |||
| { | |||
| // TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__ | |||
| return new ObjectGraphView(Root, _attached_dependencies); | |||
| } | |||
| public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
| { | |||
| List<TrackableReference> res = base.children(obj, save_type) | |||
| .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | |||
| // Check the reference, not value. | |||
| if (obj == Root && _attached_dependencies is not null) | |||
| { | |||
| res.AddRange(_attached_dependencies); | |||
| } | |||
| return res; | |||
| } | |||
| public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
| { | |||
| return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); | |||
| } | |||
| public IEnumerable<TrackableReference>? AttachedDependencies | |||
| { | |||
| get => _attached_dependencies; | |||
| } | |||
| public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
| { | |||
| return base._descendants_with_paths(); | |||
| } | |||
| // TODO: complete the implementation | |||
| public void serialize_object_graph(object? saveables_cache = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // TODO: complete the implementation | |||
| public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,229 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Checkpoint; | |||
| public static class SaveUtilV1 | |||
| { | |||
| public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names, | |||
| IDictionary<Trackable, Trackable>? object_map = null) | |||
| { | |||
| // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, | |||
| // till now only internal registrations are allowed. So, we won't return a saver in this function. | |||
| // The implementation of this function should be updated if tensorflow update it. | |||
| Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new(); | |||
| foreach (var pair in object_names) | |||
| { | |||
| var trackable = pair.Key; | |||
| var object_name = pair.Value; | |||
| var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); | |||
| // skip the registration process. | |||
| List<CheckpointFactoryData> current_list = new(); | |||
| foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save)) | |||
| { | |||
| // treat name as key_suffix. | |||
| var name = name_and_factory.Key; | |||
| var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name); | |||
| current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key)); | |||
| } | |||
| checkpoint_factory_map[trackable] = current_list; | |||
| } | |||
| return (checkpoint_factory_map, null); | |||
| } | |||
| public static (List<MySaveableObject>, object?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||
| IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, | |||
| object? saveables_cache = null) | |||
| { | |||
| Graph target_context; | |||
| if (to_graph is not null) | |||
| { | |||
| using (to_graph.as_default()) | |||
| { | |||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
| object_map, call_with_mapped_captures, saveables_cache); | |||
| // tensorflow python: `with ops.device("/cpu:0")` | |||
| var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| return (named_saveable_objects, registered_savers); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| using (new ops.NullContextManager()) | |||
| { | |||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
| object_map, call_with_mapped_captures, saveables_cache); | |||
| // tensorflow python: `with ops.device("/cpu:0")` | |||
| var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| return (named_saveable_objects, registered_savers); | |||
| } | |||
| } | |||
| } | |||
| public static (List<MySaveableObject>, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, | |||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||
| { | |||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| foreach (var pair in node_paths) | |||
| { | |||
| object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value); | |||
| } | |||
| Dictionary<Trackable, int> node_ids = new(); | |||
| for (int i = 0; i < trackable_objects.Count; i++) | |||
| { | |||
| node_ids[trackable_objects[i]] = i; | |||
| } | |||
| var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names); | |||
| var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables); | |||
| var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph( | |||
| trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures, | |||
| saveables_cache); | |||
| CheckPointUtils.add_checkpoint_values_check(object_graph_proto); | |||
| return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers); | |||
| } | |||
| private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects, | |||
| IDictionary<Trackable, int> node_ids, | |||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||
| slot_variables) | |||
| { | |||
| TrackableObjectGraph object_graph_proto = new(); | |||
| for (int i = 0; i < trackable_objects.Count; i++) | |||
| { | |||
| var trackable = trackable_objects[i]; | |||
| Debug.Assert(node_ids[trackable] == i); | |||
| TrackableObjectGraph.Types.TrackableObject object_proto; | |||
| if (slot_variables.TryGetValue(trackable, out var slots)) | |||
| { | |||
| object_proto = new TrackableObjectGraph.Types.TrackableObject(slots); | |||
| } | |||
| else | |||
| { | |||
| object_proto = new TrackableObjectGraph.Types.TrackableObject(); | |||
| } | |||
| object_graph_proto.Nodes.Add(object_proto); | |||
| foreach (var child in graph_view.list_children(trackable)) | |||
| { | |||
| object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference() | |||
| { NodeId = node_ids[child.Refer], LocalName = child.Name }); | |||
| } | |||
| } | |||
| return object_graph_proto; | |||
| } | |||
| private static (List<MySaveableObject>, object?, object?) add_attributes_to_object_graph(IList<Trackable> trackable_objects, | |||
| TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, | |||
| IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, | |||
| bool call_with_mapped_captures, object? saveables_cache = null) | |||
| { | |||
| int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count); | |||
| for (int i = 0; i < cnt; i++) | |||
| { | |||
| Debug.Assert(node_ids[trackable_objects[i]] == i); | |||
| } | |||
| var (checkpoint_factory_map, unmmaped_registered_savers) = | |||
| get_checkpoint_factories_and_keys(object_names, object_map); | |||
| // skip the process of registered savers | |||
| var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map, | |||
| object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache); | |||
| return (named_saveable_objects, feed_additions, null); | |||
| } | |||
| public static (List<MySaveableObject>, object?) generate_saveable_objects( | |||
| IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map, | |||
| TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids, | |||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||
| { | |||
| List<MySaveableObject> named_saveable_objects = new(); | |||
| foreach (var pair in checkpoint_factory_map) | |||
| { | |||
| var trackable = pair.Key; | |||
| var factory_data_list = pair.Value; | |||
| bool fill_object_proto = object_graph_proto is not null && node_ids is not null; | |||
| TrackableObjectGraph.Types.TrackableObject object_proto = null!; | |||
| if (fill_object_proto) | |||
| { | |||
| object_proto = object_graph_proto.Nodes[node_ids[trackable]]; | |||
| } | |||
| var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map); | |||
| // skip cache | |||
| foreach (var factory_data in factory_data_list) | |||
| { | |||
| var name = factory_data.name; | |||
| var key = factory_data.checkpoint_key; | |||
| var saveable_factory = factory_data.factory; | |||
| // TODO: oneflow python has a process with callable `saveable_factory`. | |||
| var maybe_saveable = saveable_factory; | |||
| IEnumerable<MySaveableObject> savesbles; | |||
| if (maybe_saveable is MySaveableObject) | |||
| { | |||
| savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable }; | |||
| } | |||
| else if (maybe_saveable is Tensor) | |||
| { | |||
| savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Unexpected type."); | |||
| } | |||
| foreach (var saveable in savesbles) | |||
| { | |||
| if (!saveable.name.Contains(key)) | |||
| { | |||
| throw new AssertionError($"The object {trackable} produced a SaveableObject with name " + | |||
| $"'{saveable.name}' for attribute '{name}'. Expected a name" + | |||
| $" containing '{key}'."); | |||
| } | |||
| } | |||
| // skip the process of PythonState | |||
| named_saveable_objects.AddRange(savesbles); | |||
| if(!fill_object_proto) continue; | |||
| // skip the process of TrackableSaveable | |||
| object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() | |||
| { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); | |||
| } | |||
| } | |||
| return (named_saveable_objects, null); | |||
| } | |||
| } | |||
| public record class CheckpointFactoryData | |||
| ( | |||
| object factory, | |||
| string name, | |||
| string checkpoint_key | |||
| ); | |||
| @@ -0,0 +1,109 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| namespace Tensorflow.Checkpoint; | |||
| public class TrackableSaver | |||
| { | |||
| private ObjectGraphView _graph_view; | |||
| private EagerTensor _cached_save_operation; | |||
| private TrackableObjectGraph _last_save_object_graph; | |||
| private Tensor? _object_graph_feed_tensor = null; | |||
| private Tensor? _file_prefix_feed_tensor = null; | |||
| public TrackableSaver(ObjectGraphView graph_view) | |||
| { | |||
| _graph_view = graph_view; | |||
| // TODO: cache when not executing eagerly. | |||
| // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, | |||
| // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` | |||
| } | |||
| private void gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| private (EagerTensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // TODO: parameter write_done_callback | |||
| public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null, | |||
| CheckpointOptions? options = null) | |||
| { | |||
| if (options is null) | |||
| { | |||
| options = new CheckpointOptions(); | |||
| } | |||
| Dictionary<Tensor, string> feed_dict = new(); | |||
| bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); | |||
| if (checkpoint_number is not null) | |||
| { | |||
| file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; | |||
| } | |||
| Tensor file_prefix_tensor; | |||
| Tensor object_graph_tensor; | |||
| if (use_session) | |||
| { | |||
| if (_object_graph_feed_tensor is null) | |||
| { | |||
| // In python there is `with ops.device("/cpu:0")`. | |||
| _object_graph_feed_tensor = constant_op.constant("", dtypes.variant); | |||
| _file_prefix_feed_tensor = constant_op.constant("", dtypes.variant); | |||
| } | |||
| object_graph_tensor = _object_graph_feed_tensor; | |||
| file_prefix_tensor = _file_prefix_feed_tensor; | |||
| feed_dict[file_prefix_tensor] = file_prefix; | |||
| } | |||
| else | |||
| { | |||
| // In python there is `with ops.device("/cpu:0")`. | |||
| file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant); | |||
| object_graph_tensor = null; | |||
| } | |||
| var (save_path, new_feed_additions) = | |||
| save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); | |||
| if (new_feed_additions is not null) | |||
| { | |||
| foreach (var pair in new_feed_additions) | |||
| { | |||
| feed_dict.Add(pair.Key, pair.Value); | |||
| } | |||
| } | |||
| if(!use_session) | |||
| { | |||
| session = null; | |||
| } | |||
| else if (session is null) | |||
| { | |||
| session = new Session(); // In python it uses `get_session`. | |||
| } | |||
| if (session is not null) | |||
| { | |||
| var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray(); | |||
| return session.run((Tensor)save_path, s); | |||
| } | |||
| else if (use_session) | |||
| { | |||
| throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " + | |||
| "in graph mode without a default session. Please use " + | |||
| "`with tf.Session():` to create a session."); | |||
| } | |||
| else | |||
| { | |||
| return save_path; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,75 @@ | |||
| using System; | |||
| using Tensorflow.Train; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| namespace Tensorflow.Checkpoint; | |||
| public class TrackableView | |||
| { | |||
| protected WeakReference<Trackable> _root_ref; | |||
| public TrackableView(Trackable obj) | |||
| { | |||
| _root_ref = new WeakReference<Trackable>(obj); | |||
| } | |||
| public TrackableView(WeakReference<Trackable> obj) | |||
| { | |||
| _root_ref = obj; | |||
| } | |||
| public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
| { | |||
| obj._maybe_initialize_trackable(); | |||
| // Note: in python the return type of `Trackable._trackable_children` is not fixed. | |||
| // Therefore it uses `convert_to_trackable` to have an extra process. | |||
| return obj._trackable_children(save_type); | |||
| } | |||
| public Trackable Root | |||
| { | |||
| get | |||
| { | |||
| if (_root_ref.TryGetTarget(out Trackable res)) | |||
| { | |||
| return res; | |||
| } | |||
| else | |||
| { | |||
| throw new InvalidDataException( | |||
| "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor."); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. | |||
| /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths | |||
| /// </summary> | |||
| protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths() | |||
| { | |||
| List<Trackable> bfs_sorted = new(); | |||
| Queue<Trackable> to_visit = new(); | |||
| Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new(); | |||
| node_paths[this.Root] = new List<TrackableReference>(); | |||
| while (!to_visit.empty()) | |||
| { | |||
| var current_trackable = to_visit.Dequeue(); | |||
| bfs_sorted.Add(current_trackable); | |||
| var children_dict = this.children(current_trackable); | |||
| foreach (var name in children_dict.Keys) | |||
| { | |||
| var dependency = children_dict[name]; | |||
| if (!node_paths.ContainsKey(dependency)) | |||
| { | |||
| var list = new List<TrackableReference>(node_paths[current_trackable]); | |||
| list.Add(new TrackableReference(name, dependency)); | |||
| node_paths[dependency] = list; | |||
| to_visit.Enqueue(dependency); | |||
| } | |||
| } | |||
| } | |||
| return (bfs_sorted, node_paths); | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| namespace Tensorflow.Exceptions; | |||
| public class AssertionError : TensorflowException | |||
| { | |||
| public AssertionError() : base() | |||
| { | |||
| } | |||
| public AssertionError(string message) : base(message) | |||
| { | |||
| } | |||
| } | |||
| @@ -304,7 +304,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||
| public static OpList stripped_op_list_for_graph(GraphDef graph_def) | |||
| { | |||
| var used_ops = ops_used_by_graph_def(graph_def); | |||
| @@ -345,5 +345,66 @@ namespace Tensorflow | |||
| return used_ops.ToArray(); | |||
| } | |||
| private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value) | |||
| { | |||
| foreach (var attr_def in op_def.Attr) | |||
| { | |||
| if (attr_def.Name == attr_name) | |||
| { | |||
| if (attr_def.DefaultValue is null) return false; | |||
| // TODO: add new c_api `EqualAttrValueWrapper` and complete the check. | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def) | |||
| { | |||
| Dictionary<string, FunctionDef> op_name_to_function = new(); | |||
| foreach (var function_def in meta_graph_def.GraphDef.Library.Function) | |||
| { | |||
| op_name_to_function[function_def.Signature.Name] = function_def; | |||
| } | |||
| Action<NodeDef> _strip_node_default_valued_attrs = (node_def) => | |||
| { | |||
| if (op_name_to_function.ContainsKey(node_def.Op)) return; | |||
| var op_def = op_def_registry.GetOpDef(node_def.Op); | |||
| if(op_def is null) return; | |||
| HashSet<string> attrs_to_strip = new(); | |||
| foreach (var attr in node_def.Attr) | |||
| { | |||
| if (is_default_attr_value(op_def, attr.Key, attr.Value)) | |||
| { | |||
| attrs_to_strip.Add(attr.Key); | |||
| } | |||
| } | |||
| foreach (var attr in attrs_to_strip) | |||
| { | |||
| node_def.Attr.Remove(attr); | |||
| } | |||
| }; | |||
| foreach (var node_def in meta_graph_def.GraphDef.Node) | |||
| { | |||
| _strip_node_default_valued_attrs(node_def); | |||
| } | |||
| foreach (var function_def in meta_graph_def.GraphDef.Library.Function) | |||
| { | |||
| foreach (var function_node_def in function_def.NodeDef) | |||
| { | |||
| _strip_node_default_valued_attrs(function_node_def); | |||
| } | |||
| } | |||
| meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; | |||
| } | |||
| } | |||
| } | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Functions | |||
| @@ -10,7 +11,7 @@ namespace Tensorflow.Functions | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| public class ConcreteFunction | |||
| public class ConcreteFunction: Trackable | |||
| { | |||
| FuncGraph func_graph; | |||
| ForwardBackwardCall forward_backward; | |||
| @@ -1,16 +1,23 @@ | |||
| using System; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow | |||
| { | |||
| public class Function | |||
| public class Function: Trackable | |||
| { | |||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | |||
| private IntPtr _handle; | |||
| #pragma warning restore CS0169 // The field 'Function._handle' is never used | |||
| public string Name { get; set; } | |||
| public Function() | |||
| { | |||
| } | |||
| public Function(string name) | |||
| { | |||
| Name = name; | |||
| } | |||
| } | |||
| } | |||
| @@ -9,7 +9,13 @@ namespace Tensorflow.ModelSaving | |||
| /// </summary> | |||
| public class SaveOptions | |||
| { | |||
| bool save_debug_info; | |||
| public bool save_debug_info = false; | |||
| public IList<string>? namespace_white_list { get; set; } = null; | |||
| public IDictionary<string, object>? function_aliases { get; set; } = null; | |||
| public string? experimental_io_device { get; set; } = null; | |||
| // TODO: experimental | |||
| public Object? experimental_variable_polict { get; set; } = null; | |||
| public bool experimental_custom_gradients { get; set; } = true; | |||
| public SaveOptions(bool save_debug_info = false) | |||
| { | |||
| this.save_debug_info = save_debug_info; | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow | |||
| @@ -38,6 +39,11 @@ namespace Tensorflow | |||
| { | |||
| return var is ResourceVariable; | |||
| } | |||
| public static bool is_resource_variable(Trackable var) | |||
| { | |||
| return var is BaseResourceVariable; | |||
| } | |||
| /// <summary> | |||
| /// Creates a variable handle with information to do shape inference. | |||
| @@ -156,7 +156,7 @@ namespace Tensorflow { | |||
| /// Nodes[0] is considered the root node. | |||
| /// </summary> | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||
| public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes { | |||
| get { return nodes_; } | |||
| } | |||
| @@ -286,6 +286,7 @@ namespace Tensorflow { | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SavedObject(SavedObject other) : this() { | |||
| children_ = other.children_.Clone(); | |||
| dependencies_ = other.dependencies_.Clone(); | |||
| slotVariables_ = other.slotVariables_.Clone(); | |||
| saveableObjects_ = other.saveableObjects_.Clone(); | |||
| switch (other.KindCase) { | |||
| @@ -328,6 +329,7 @@ namespace Tensorflow { | |||
| private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec | |||
| = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); | |||
| /// <summary> | |||
| /// Objects which this object depends on: named edges in the dependency | |||
| /// graph. | |||
| @@ -338,6 +340,11 @@ namespace Tensorflow { | |||
| public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { | |||
| get { return children_; } | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies { | |||
| get { return dependencies_; } | |||
| } | |||
| /// <summary>Field number for the "slot_variables" field.</summary> | |||
| public const int SlotVariablesFieldNumber = 3; | |||
| @@ -617,6 +624,7 @@ namespace Tensorflow { | |||
| return; | |||
| } | |||
| children_.Add(other.children_); | |||
| dependencies_.Add(other.dependencies_); | |||
| slotVariables_.Add(other.slotVariables_); | |||
| saveableObjects_.Add(other.saveableObjects_); | |||
| switch (other.KindCase) { | |||
| @@ -198,6 +198,12 @@ namespace Tensorflow { | |||
| public TrackableObject() { | |||
| OnConstruction(); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) { | |||
| OnConstruction(); | |||
| slotVariables_ = slot; | |||
| } | |||
| partial void OnConstruction(); | |||
| @@ -2,5 +2,20 @@ | |||
| { | |||
| public abstract class AutoTrackable : Trackable | |||
| { | |||
| public void _delete_tracking(string name) | |||
| { | |||
| _maybe_initialize_trackable(); | |||
| if (_unconditional_dependency_names.ContainsKey(name)) | |||
| { | |||
| _unconditional_dependency_names.Remove(name); | |||
| for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--) | |||
| { | |||
| if (_unconditional_checkpoint_dependencies[i].Name == name) | |||
| { | |||
| _unconditional_checkpoint_dependencies.RemoveAt(i); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -351,7 +351,7 @@ namespace Tensorflow | |||
| /// <param name="var"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| protected IVariableV1 get_slot(IVariableV1 var, string name) | |||
| internal IVariableV1 get_slot(IVariableV1 var, string name) | |||
| { | |||
| var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; | |||
| if (named_slots == null) | |||
| @@ -360,6 +360,11 @@ namespace Tensorflow | |||
| return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; | |||
| } | |||
| internal IEnumerable<string> get_slot_names() | |||
| { | |||
| return _slots.Keys; | |||
| } | |||
| private string _var_key(IVariableV1 var) | |||
| { | |||
| return $"{var.Op.graph.graph_key}.{var.Op.name}"; | |||
| @@ -48,4 +48,18 @@ namespace Tensorflow | |||
| validate_shape: restored_shapes == null && op.shape.IsFullyDefined); | |||
| } | |||
| } | |||
| public class NoRestoreSaveable: MySaveableObject | |||
| { | |||
| public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor, | |||
| new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name) | |||
| { | |||
| } | |||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||
| { | |||
| return control_flow_ops.no_op(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow; | |||
| public record class AssetInfo | |||
| ( | |||
| List<AssetFileDef> asset_defs, | |||
| Dictionary<object, object> asset_initializers_by_resource, | |||
| Dictionary<AssetInfo, string> asset_filename_map, | |||
| Dictionary<object, object> asset_index | |||
| ); | |||
| @@ -0,0 +1,60 @@ | |||
| using System; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Train; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| namespace Tensorflow; | |||
| public class AugmentedGraphView: ObjectGraphView | |||
| { | |||
| // private object _children_cache; | |||
| // private object _serialization_cache; | |||
| private List<string> _untraces_functions; | |||
| public AugmentedGraphView(Trackable root): base(root) | |||
| { | |||
| _untraces_functions = new(); | |||
| } | |||
| public void set_signature(object signature_map, object wrapped_functions) | |||
| { | |||
| // TODO: cache | |||
| list_children(Root); | |||
| } | |||
| public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
| { | |||
| Dictionary<string, Trackable> children = new(); | |||
| foreach (var pair in base.list_children(obj, save_type)) | |||
| { | |||
| var name = pair.Name; | |||
| var child = pair.Refer; | |||
| children[name] = child; | |||
| } | |||
| if (obj is Function && children.Count == 0) | |||
| { | |||
| _untraces_functions.Add(((Function)obj).Name); | |||
| } | |||
| return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); | |||
| } | |||
| public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
| { | |||
| // TODO: implement it if needed. | |||
| return base.breadth_first_traversal(); | |||
| } | |||
| public List<(string, Trackable)> list_dependencies(Trackable obj) | |||
| { | |||
| // TODO: deal with cache. | |||
| return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); | |||
| } | |||
| public Trackable get_child(Trackable obj, string name) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| namespace Tensorflow; | |||
| public static class Constants | |||
| { | |||
| public static readonly string ASSETS_DIRECTORY = "assets"; | |||
| public static readonly string ASSETS_KEY = "saved_model_assets"; | |||
| public static readonly string DEBUG_DIRECTORY = "debug"; | |||
| public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"; | |||
| public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra"; | |||
| public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb"; | |||
| public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"; | |||
| public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op"; | |||
| public static readonly string MAIN_OP_KEY = "saved_model_main_op"; | |||
| public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb"; | |||
| public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"; | |||
| public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1; | |||
| public static readonly string TRAIN_OP_KEY = "saved_model_train_op"; | |||
| public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"; | |||
| public static readonly string VARIABLES_DIRECTORY = "variables"; | |||
| public static readonly string VARIABLES_FILENAME = "variables"; | |||
| } | |||
| @@ -0,0 +1,17 @@ | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow; | |||
| public class RevivedTypes | |||
| { | |||
| /// <summary> | |||
| /// Create a SavedUserObject from a trackable object. | |||
| /// </summary> | |||
| /// <param name="obj"></param> | |||
| /// <returns></returns> | |||
| public static SavedUserObject? serialize(Trackable obj) | |||
| { | |||
| // TODO: complete the implementation. | |||
| return null; | |||
| } | |||
| } | |||
| @@ -0,0 +1,9 @@ | |||
| using System; | |||
| namespace Tensorflow; | |||
| public enum SaveType | |||
| { | |||
| SAVEDMODEL, | |||
| CHECKPOINT | |||
| } | |||
| @@ -0,0 +1,299 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow; | |||
| public class SaveableView | |||
| { | |||
| private AugmentedGraphView _augmented_graph_view; | |||
| private SaveOptions _options; | |||
| private List<Trackable> _trackable_objects; | |||
| private List<Trackable> _nodes; | |||
| private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths; | |||
| private Dictionary<Trackable, int> _node_ids; | |||
| private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||
| _slot_variables; | |||
| private Dictionary<Trackable, string> _object_names; | |||
| private List<object> _gradient_functions; // to be completed | |||
| private List<RegisteredGradient> _gradient_defs; // to be completed | |||
| private List<ConcreteFunction> _concrete_functions; | |||
| private Dictionary<Tensor, int> _captured_tensor_node_ids; | |||
| private Dictionary<Trackable, IDictionary<string, ConcreteFunction>> _saveable_objects_map; | |||
| private Dictionary<Trackable, string> _obj_to_registered_saver; | |||
| public AugmentedGraphView AugmentedGraphView | |||
| { | |||
| get => _augmented_graph_view; | |||
| } | |||
| public Trackable Root | |||
| { | |||
| get => _nodes[0]; | |||
| } | |||
| public List<Trackable> Nodes | |||
| { | |||
| get => _nodes; | |||
| } | |||
| public Dictionary<Trackable, int> NodeIds | |||
| { | |||
| get => _node_ids; | |||
| } | |||
| public List<RegisteredGradient> GradientDefs | |||
| { | |||
| get => _gradient_defs; | |||
| } | |||
| public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths | |||
| { | |||
| get => _node_paths; | |||
| } | |||
| public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options) | |||
| { | |||
| _augmented_graph_view = augmented_graph_view; | |||
| _options = options; | |||
| (_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) = | |||
| CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view); | |||
| // TODO: deal with untraced functions. | |||
| initialize_save_and_restore_functions(); | |||
| initialize_nodes_and_concrete_functions(); | |||
| _captured_tensor_node_ids = new(); | |||
| } | |||
| private void initialize_save_and_restore_functions() | |||
| { | |||
| // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. | |||
| SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); | |||
| // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. | |||
| _obj_to_registered_saver = new(); | |||
| _saveable_objects_map = new(); | |||
| } | |||
| private void initialize_nodes_and_concrete_functions() | |||
| { | |||
| _nodes = _trackable_objects.ConvertAll(x => x); // deep copy | |||
| _gradient_functions = new(); | |||
| _gradient_defs = new(); | |||
| // TODO: deal with the condition that obj in `_saveable_objects_map`. | |||
| // foreach (var obj in _nodes) | |||
| // { | |||
| // | |||
| // } | |||
| foreach (var obj in _nodes) | |||
| { | |||
| if (obj is ConcreteFunction) | |||
| { | |||
| _concrete_functions.Add((ConcreteFunction)obj); | |||
| } | |||
| } | |||
| } | |||
| public List<ConcreteFunction> get_concrete_resource_initializers() | |||
| { | |||
| // TODO: complete the implementation. | |||
| return new List<ConcreteFunction>(); | |||
| } | |||
| public (Dictionary<Trackable, Trackable>, Dictionary<Tensor, Tensor>, AssetInfo) map_resources() | |||
| { | |||
| Debug.Assert(!tf.Context.executing_eagerly()); | |||
| Dictionary<Trackable, Trackable> object_map = new(); | |||
| Dictionary<Tensor, Tensor> tensor_map = new(); | |||
| AssetInfo assetInfo = new(new List<AssetFileDef>(), new Dictionary<object, object>(), | |||
| new Dictionary<AssetInfo, string>(), new Dictionary<object, object>()); | |||
| foreach (var node_id in dependency_sorted_node_ids()) | |||
| { | |||
| var obj = _nodes[node_id]; | |||
| var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options); | |||
| // TODO: deal with Asset (if obj is Asset) | |||
| foreach (var tensor in tensors) | |||
| { | |||
| _captured_tensor_node_ids[tensor] = node_id; | |||
| } | |||
| } | |||
| return (object_map, tensor_map, assetInfo); | |||
| } | |||
| /// <summary> | |||
| /// Returns topologically sorted nodes, sorted by dependencies. | |||
| /// </summary> | |||
| public List<int> dependency_sorted_node_ids() | |||
| { | |||
| Dictionary<int, IEnumerable<int>> dependency_map = new(); | |||
| foreach (var node in _nodes) | |||
| { | |||
| var node_id = _node_ids[node]; | |||
| List<int> deps = new(); | |||
| // TODO: deal with captured tensor. | |||
| string node_path; | |||
| foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) | |||
| { | |||
| if (!_node_ids.ContainsKey(dep)) | |||
| { | |||
| node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); | |||
| throw new ValueError( | |||
| $"Found an untracked dependency. Object {node_path} depends on {dep}, " + | |||
| $"but this dependency isn't listed as a child. Please track this child by " + | |||
| $"overriding `_trackable_children` or use `._track_trackable`."); | |||
| } | |||
| deps.Add(_node_ids[dep]); | |||
| } | |||
| } | |||
| try | |||
| { | |||
| return TrackableUtils.order_by_dependency(dependency_map); | |||
| } | |||
| catch (TrackableUtils.CyclicDependencyError err) | |||
| { | |||
| List<string> pretty_printed_nodes = new(); | |||
| List<string> pretty_printed_dependencies = new(); | |||
| foreach (var pair in err.LeftOverDependencyMap) | |||
| { | |||
| var x = pair.Key; | |||
| var deps = pair.Value; | |||
| var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]); | |||
| pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})"); | |||
| pretty_printed_dependencies.Add( | |||
| $"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]"); | |||
| } | |||
| throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " + | |||
| $"Saving cannot continue until this cycle is resolved." + | |||
| $"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" + | |||
| $"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}"); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph | |||
| /// </summary> | |||
| /// <param name="asset_index"></param> | |||
| /// <returns></returns> | |||
| public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index, SaveOptions options) | |||
| { | |||
| SavedObjectGraph proto = new(); | |||
| fill_object_graph_proto(proto); | |||
| // TODO: complete the process of concrete functions. | |||
| int cnt = Math.Min(_nodes.Count, proto.Nodes.Count); | |||
| for (int i = 0; i < cnt; i++) | |||
| { | |||
| var obj = _nodes[i]; | |||
| var obj_proto = proto.Nodes[i]; | |||
| write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), | |||
| options); | |||
| } | |||
| return proto; | |||
| } | |||
| private static void write_object_proto(Trackable obj, SavedObject proto, | |||
| IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn, SaveOptions options) | |||
| { | |||
| // skip the process of type Asset | |||
| if (resource_variable_ops.is_resource_variable(obj)) | |||
| { | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| else if (obj is Function) | |||
| { | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| else if (obj is ConcreteFunction) | |||
| { | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| // skip the process of type `_CapturedTensor` and `CapturableResource`. | |||
| else | |||
| { | |||
| var registered_type_proto = RevivedTypes.serialize(obj); | |||
| if (registered_type_proto is null) | |||
| { | |||
| registered_type_proto = new SavedUserObject() | |||
| { | |||
| Identifier = obj.ObjectIdentifier, | |||
| Version = new VersionDef() | |||
| { | |||
| Producer = 1, | |||
| MinConsumer = 1, | |||
| BadConsumers = { } | |||
| } | |||
| }; | |||
| } | |||
| proto.UserObject = new SavedUserObject(registered_type_proto); | |||
| } | |||
| // TODO: try get the registered_name from `registration`. | |||
| } | |||
| public void fill_object_graph_proto(SavedObjectGraph proto) | |||
| { | |||
| for (int node_id = 0; node_id < _nodes.Count; node_id++) | |||
| { | |||
| var node = _nodes[node_id]; | |||
| Debug.Assert(_node_ids[node] == node_id); | |||
| SavedObject object_proto = new(); | |||
| if (_slot_variables.TryGetValue(node, out var value)) | |||
| { | |||
| object_proto.SlotVariables.AddRange(value); | |||
| } | |||
| // skip the check of type `_CapturedTensor` | |||
| foreach (var child in _augmented_graph_view.list_children(node)) | |||
| { | |||
| var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); | |||
| child_proto.NodeId = _node_ids[child.Refer]; | |||
| child_proto.LocalName = child.Name; | |||
| object_proto.Children.Add(child_proto); | |||
| } | |||
| foreach (var pair in _augmented_graph_view.list_dependencies(node)) | |||
| { | |||
| var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference(); | |||
| child_proto.NodeId = _node_ids[pair.Item2]; | |||
| child_proto.LocalName = pair.Item1; | |||
| object_proto.Dependencies.Add(child_proto); | |||
| } | |||
| if (_saveable_objects_map.ContainsKey(node)) | |||
| { | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| else if(_obj_to_registered_saver.ContainsKey(node)) | |||
| { | |||
| // TODO: complete it. | |||
| // We now skip it for the lack of `SavedObject.registered_saver` API. | |||
| throw new NotImplementedException(); | |||
| } | |||
| proto.Nodes.Add(object_proto); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| namespace Tensorflow; | |||
| public static class TagConstants | |||
| { | |||
| public static readonly string SERVING = "serve"; | |||
| public static readonly string TRAINING = "train"; | |||
| public static readonly string EVAL = "eval"; | |||
| public static readonly string GPU = "gpu"; | |||
| public static readonly string TPU = "tpu"; | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow; | |||
| public class BuilderUtils | |||
| { | |||
| public static void copy_assets_to_destination_dir(IDictionary<AssetInfo, string> asset_filename_map, | |||
| string destination_dir, HashSet<string>? saved_files = null) | |||
| { | |||
| if (saved_files is null) saved_files = new HashSet<string>(); | |||
| var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir); | |||
| // TODO: complete the implementation of this function. | |||
| if (asset_filename_map is not null && asset_filename_map.Count > 0) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,256 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Google.Protobuf; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow; | |||
| public static partial class SavedModelUtils | |||
| { | |||
| private static readonly IEnumerable<int> byte_swappable = new List<TF_DataType>() | |||
| { | |||
| dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16, | |||
| dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32, | |||
| dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16, | |||
| TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32 | |||
| }.Select(x => (int)x); | |||
| public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj, | |||
| string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) | |||
| { | |||
| if (options is null) | |||
| { | |||
| options = new SaveOptions(); | |||
| } | |||
| var saved_model = new Tensorflow.SavedModel(); | |||
| var meta_graph_def = new MetaGraphDef(); | |||
| saved_model.MetaGraphs.Add(meta_graph_def); | |||
| var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) = | |||
| _build_meta_graph(obj, signatures, options, meta_graph_def); | |||
| saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION; | |||
| if (!experimental_skip_checkpoint) | |||
| { | |||
| Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); | |||
| CheckpointOptions ckpt_options = new(options.experimental_io_device); | |||
| object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||
| } | |||
| BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| // tensorflow python has a check of `context.async_wait()` here. | |||
| } | |||
| // TODO: deal with `pywrap_saved_model.Save(export_dir)`. | |||
| var saved_model_serialized = saved_model.ToString(); | |||
| // This is a state depending on some py-c APIs. Here we temporarily set it as `true`. | |||
| if (true) | |||
| { | |||
| var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir), | |||
| tf.compat.as_str(Constants.FINGERPRINT_FILENAME)); | |||
| // TODO: add c api and complete the fingerprint def. | |||
| var fingerprint_proto = ""; | |||
| File.WriteAllText(fingerprint_path, fingerprint_proto); | |||
| } | |||
| var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); | |||
| File.WriteAllText(path, saved_model.ToString()); | |||
| if (options.save_debug_info) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| ops.dismantle_graph(exported_graph); | |||
| return (saved_nodes, node_paths); | |||
| } | |||
| private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | |||
| Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||
| IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||
| { | |||
| if (ops.inside_function()) | |||
| { | |||
| throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + | |||
| "Move the call to the outer eagerly-executed context."); | |||
| } | |||
| if (meta_graph_def is null) | |||
| { | |||
| meta_graph_def = new MetaGraphDef(); | |||
| } | |||
| AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | |||
| if (signatures is not null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| // TODO: process of aignatures and wrapped_functions | |||
| SaveableView saveable_view = new SaveableView(augmented_graph_view, options); | |||
| TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); | |||
| var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, | |||
| options.namespace_white_list, options.experimental_custom_gradients); | |||
| if (options.function_aliases is not null) | |||
| { | |||
| var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; | |||
| foreach (var pair in options.function_aliases) | |||
| { | |||
| var alias = pair.Key; | |||
| var func = pair.Value; | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); | |||
| meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); | |||
| return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); | |||
| } | |||
| private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, | |||
| IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist, | |||
| bool save_custom_gradients) | |||
| { | |||
| var resource_initializers = saveable_view.get_concrete_resource_initializers(); | |||
| var exported_graph = new Graph(); | |||
| Dictionary<Trackable, Trackable> object_map; | |||
| Dictionary<Tensor, Tensor> tensor_map; | |||
| AssetInfo asset_info; | |||
| using (var g = exported_graph.as_default()) | |||
| { | |||
| (object_map, tensor_map, asset_info) = saveable_view.map_resources(); | |||
| // TODO: deal with signatures. | |||
| if (save_custom_gradients) | |||
| { | |||
| // TODO: trace gradient functions. | |||
| } | |||
| foreach (var resource_initializer_function in resource_initializers) | |||
| { | |||
| // List<Trackable> asset_dependencies = new(); | |||
| // TODO: deal with initializers | |||
| } | |||
| // using(ops.control_dependencies(...)) | |||
| var init_op = control_flow_ops.no_op(); | |||
| if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY)) | |||
| { | |||
| meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name); | |||
| } | |||
| else | |||
| { | |||
| meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef(); | |||
| } | |||
| // Lack `CopyFrom` API | |||
| // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] | |||
| } | |||
| foreach (var obj in object_map.Values) | |||
| { | |||
| obj._maybe_initialize_trackable(); | |||
| } | |||
| var (named_saveable_objects, registered_savers) = | |||
| SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); | |||
| // TODO: complete the save of checkpoints with `MultiDeviceSaver`. | |||
| saveable_view.dependency_sorted_node_ids(); | |||
| var graph_def = exported_graph.as_graph_def(true); | |||
| graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs); | |||
| verify_ops(graph_def, namespace_whitelist); | |||
| meta_graph_def.GraphDef = new GraphDef(graph_def); | |||
| meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING); | |||
| meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION; | |||
| // TODO: add git version. | |||
| meta_graph_def.MetaInfoDef.TensorflowGitVersion = ""; | |||
| meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; | |||
| meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef)); | |||
| meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs); | |||
| // TODO: deal with signatures here. | |||
| meta_graph.strip_graph_default_valued_attrs(meta_graph_def); | |||
| if (!BitConverter.IsLittleEndian) | |||
| { | |||
| swap_function_tensor_content(meta_graph_def); | |||
| } | |||
| return (asset_info, exported_graph); | |||
| } | |||
| private static void verify_ops(GraphDef graph_def, IEnumerable<string>? namespace_whitelist) | |||
| { | |||
| return; | |||
| // if (namespace_whitelist is null || !namespace_whitelist.Any()) | |||
| // { | |||
| // return; | |||
| // } | |||
| // skip the check for the lack of `meta_graph.ops_used_by_graph_def`. | |||
| } | |||
| public static void swap_function_tensor_content(MetaGraphDef meta_graph_def) | |||
| { | |||
| var functions = meta_graph_def.GraphDef.Library.Function; | |||
| foreach (var function in functions) | |||
| { | |||
| var node_def = function.NodeDef; | |||
| foreach (var node in node_def) | |||
| { | |||
| if (node.Op == "Const") | |||
| { | |||
| var tensor = node.Attr["value"].Tensor; | |||
| byte_swap_tensor_content(tensor); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| public static void byte_swap_tensor_content(TensorProto tensor) | |||
| { | |||
| if (byte_swappable.Contains((int)tensor.Dtype)) | |||
| { | |||
| var tshape = tensor.TensorShape.Dim; | |||
| var tensor_bytes = tensor.TensorContent; | |||
| if (tensor_bytes is not null && !tensor_bytes.IsEmpty) | |||
| { | |||
| long tensor_size = 1; | |||
| foreach (var sz in tshape) | |||
| { | |||
| tensor_size *= sz.Size; | |||
| } | |||
| var chunksize = tensor_bytes.Length / tensor_size; | |||
| List<byte> reversed_bytes = new(); | |||
| for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize) | |||
| { | |||
| var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse(); | |||
| reversed_bytes.AddRange(current); | |||
| } | |||
| tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,58 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow; | |||
| public class SignatureMap: Trackable | |||
| { | |||
| private Dictionary<string, Function> _signatures; | |||
| private Dictionary<string, ConcreteFunction> _concrete_signatures; | |||
| public SignatureMap() | |||
| { | |||
| _signatures = new(); | |||
| } | |||
| public void _add_signature(string name, ConcreteFunction concrete_function) | |||
| { | |||
| _concrete_signatures[name] = concrete_function; | |||
| } | |||
| public void _add_signature(string name, Function concrete_function) | |||
| { | |||
| _signatures[name] = concrete_function; | |||
| } | |||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||
| { | |||
| if (save_type != SaveType.SAVEDMODEL) | |||
| { | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| Dictionary<string, Trackable> res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); | |||
| foreach (var pair in _concrete_signatures) | |||
| { | |||
| res[pair.Key] = pair.Value; | |||
| } | |||
| return res; | |||
| } | |||
| public static SignatureMap create_signature_map(IDictionary<string, ConcreteFunction> signatures) | |||
| { | |||
| var signature_map = new SignatureMap(); | |||
| foreach (var pair in signatures) | |||
| { | |||
| var name = pair.Key; | |||
| var func = pair.Value; | |||
| // TODO: assert the arg_keywords | |||
| signature_map._add_signature(name, func); | |||
| } | |||
| return signature_map; | |||
| } | |||
| } | |||
| @@ -0,0 +1,52 @@ | |||
| using System.IO; | |||
| using System.Security.Cryptography.X509Certificates; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow; | |||
| public static partial class SavedModelUtils | |||
| { | |||
| /// <summary> | |||
| /// Return variables sub-directory, or create one if it doesn't exist. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static string get_or_create_variables_dir(string export_dir) | |||
| { | |||
| var variables_dir = get_variables_dir(export_dir); | |||
| Directory.CreateDirectory(variables_dir); | |||
| return variables_dir; | |||
| } | |||
| /// <summary> | |||
| /// Return variables sub-directory in the SavedModel. | |||
| /// </summary> | |||
| /// <param name="export_dir"></param> | |||
| /// <returns></returns> | |||
| public static string get_variables_dir(string export_dir) | |||
| { | |||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); | |||
| } | |||
| /// <summary> | |||
| /// Return assets sub-directory, or create one if it doesn't exist. | |||
| /// </summary> | |||
| /// <param name="export_dir"></param> | |||
| /// <returns></returns> | |||
| public static string get_or_create_assets_dir(string export_dir) | |||
| { | |||
| var assets_destination_dir = get_assets_dir(export_dir); | |||
| Directory.CreateDirectory(assets_destination_dir); | |||
| return assets_destination_dir; | |||
| } | |||
| /// <summary> | |||
| /// Return path to asset directory in the SavedModel. | |||
| /// </summary> | |||
| /// <param name="export_dir"></param> | |||
| /// <returns></returns> | |||
| public static string get_assets_dir(string export_dir) | |||
| { | |||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); | |||
| } | |||
| } | |||
| @@ -17,12 +17,17 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class saveable_object_util | |||
| public static class saveable_object_util | |||
| { | |||
| public class TrackableSaveable: MySaveableObject | |||
| { | |||
| } | |||
| /// <summary> | |||
| /// Returns the variables and names that will be used for a Saver. | |||
| /// </summary> | |||
| @@ -121,5 +126,17 @@ namespace Tensorflow | |||
| return names_to_saveables; | |||
| } | |||
| public static IDictionary<string, ResourceVariable> saveable_objects_from_trackable(Trackable obj) | |||
| { | |||
| // TODO: complete the implementation. | |||
| return obj.gather_saveables_for_checkpoint(); | |||
| } | |||
| public static bool trackable_has_serialize_to_tensor(Trackable obj) | |||
| { | |||
| // TODO: implement it. | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,14 +14,38 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.ModelSaving; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Train | |||
| { | |||
| public abstract class Trackable | |||
| { | |||
| /// <summary> | |||
| /// Corresponding to tensorflow/python/trackable/constants.py | |||
| /// </summary> | |||
| public static class Constants | |||
| { | |||
| public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"; | |||
| public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE"; | |||
| public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"; | |||
| } | |||
| protected int _self_update_uid; | |||
| protected IDictionary<string, Trackable> _unconditional_dependency_names = | |||
| new Dictionary<string, Trackable>(); | |||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||
| protected IDictionary<string, ResourceVariable> _self_saveable_object_factories = | |||
| new Dictionary<string, ResourceVariable>(); | |||
| public virtual string ObjectIdentifier | |||
| { | |||
| get => "_generic_user_object"; | |||
| } | |||
| /// <summary> | |||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
| /// </summary> | |||
| @@ -73,10 +97,63 @@ namespace Tensorflow.Train | |||
| /// <summary> | |||
| /// Initialize dependency management. | |||
| /// </summary> | |||
| protected void _maybe_initialize_trackable() | |||
| public void _maybe_initialize_trackable() | |||
| { | |||
| // _self_unconditional_checkpoint_dependencies = [] | |||
| _self_update_uid = -1; | |||
| } | |||
| // TODO: cache | |||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||
| { | |||
| _maybe_initialize_trackable(); | |||
| return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | |||
| } | |||
| public static Trackable convert_to_trackable(object obj, object? parent = null) | |||
| { | |||
| if (obj is Trackable) | |||
| { | |||
| return (Trackable)obj; | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| public virtual IDictionary<string, Trackable> deserialization_dependencies(IDictionary<string, Trackable> children) | |||
| { | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| public virtual (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources( | |||
| SaveOptions? save_options) | |||
| { | |||
| return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>()); | |||
| } | |||
| public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null, | |||
| IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null) | |||
| { | |||
| var (self_object_map, self_tensor_map) = map_resources(options); | |||
| foreach (var pair in self_object_map) | |||
| { | |||
| object_map.Add(pair); | |||
| } | |||
| foreach (var pair in self_tensor_map) | |||
| { | |||
| tensor_map.Add(pair); | |||
| } | |||
| return self_tensor_map.Keys.ToList(); | |||
| } | |||
| public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint() | |||
| { | |||
| return _self_saveable_object_factories; | |||
| } | |||
| } | |||
| public record class TrackableReference(string Name, Trackable Refer); | |||
| } | |||
| @@ -0,0 +1,148 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Training; | |||
| public static class TrackableUtils | |||
| { | |||
| public class CyclicDependencyError: System.Exception | |||
| { | |||
| public IDictionary<int, IEnumerable<int>> LeftOverDependencyMap { get; } | |||
| public CyclicDependencyError(IDictionary<int, IEnumerable<int>> leftover_dependency_map): base() | |||
| { | |||
| LeftOverDependencyMap = leftover_dependency_map; | |||
| } | |||
| public CyclicDependencyError(IDictionary<int, List<int>> leftover_dependency_map): base() | |||
| { | |||
| LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | |||
| } | |||
| } | |||
| private static string _ESCAPE_CHAR = "."; | |||
| private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
| private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
| private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||
| public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | |||
| { | |||
| return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); | |||
| } | |||
| public static string escape_local_name(string name) | |||
| { | |||
| return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S"); | |||
| } | |||
| public static string checkpoint_key(string object_path, string local_name) | |||
| { | |||
| var key_suffix = escape_local_name(local_name); | |||
| if (local_name == SERIALIZE_TO_TENSORS_NAME) | |||
| { | |||
| key_suffix = ""; | |||
| } | |||
| return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}"; | |||
| } | |||
| /// <summary> | |||
| /// Topologically sorts the keys of a map so that dependencies appear first. | |||
| /// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm | |||
| /// </summary> | |||
| /// <param name="dependency_map"></param> | |||
| /// <exception cref="ValueError"></exception> | |||
| public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map) | |||
| { | |||
| Dictionary<int, HashSet<int>> reverse_dependency_map = new(); | |||
| foreach (var pair in dependency_map) | |||
| { | |||
| foreach (var dep in pair.Value) | |||
| { | |||
| if (reverse_dependency_map.ContainsKey(dep)) | |||
| { | |||
| reverse_dependency_map[dep].Add(pair.Key); | |||
| } | |||
| else | |||
| { | |||
| reverse_dependency_map[dep] = new HashSet<int>(); | |||
| reverse_dependency_map[dep].Add(pair.Key); | |||
| } | |||
| } | |||
| } | |||
| // Validate that all values in the dependency map are also keys. | |||
| var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys); | |||
| if (unknown_keys.Count() > 0) | |||
| { | |||
| throw new ValueError( | |||
| $"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}"); | |||
| } | |||
| // Generate the list sorted by objects without dependencies -> dependencies. | |||
| // The returned list will reverse this. | |||
| List<int> reversed_dependency_arr = new(); | |||
| Queue<int> to_visit = new(); | |||
| foreach (var x in dependency_map.Keys) | |||
| { | |||
| if (!reverse_dependency_map.ContainsKey(x)) | |||
| { | |||
| to_visit.Enqueue(x); | |||
| } | |||
| } | |||
| while (to_visit.Count > 0) | |||
| { | |||
| var x = to_visit.Dequeue(); | |||
| reversed_dependency_arr.Add(x); | |||
| foreach (var dep in dependency_map[x].Distinct()) | |||
| { | |||
| var edges = reverse_dependency_map[dep]; | |||
| edges.Remove(x); | |||
| if (edges.Count == 0) | |||
| { | |||
| to_visit.Enqueue(dep); | |||
| if (!reverse_dependency_map.Remove(dep)) | |||
| { | |||
| throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (reverse_dependency_map.Count > 0) | |||
| { | |||
| Dictionary<int, List<int>> leftover_dependency_map = new(); | |||
| foreach (var pair in reverse_dependency_map) | |||
| { | |||
| foreach (var x in pair.Value) | |||
| { | |||
| if (leftover_dependency_map.ContainsKey(x)) | |||
| { | |||
| leftover_dependency_map[x].Add(pair.Key); | |||
| } | |||
| else | |||
| { | |||
| leftover_dependency_map[x] = new List<int>() { pair.Key }; | |||
| } | |||
| } | |||
| } | |||
| throw new CyclicDependencyError(leftover_dependency_map); | |||
| } | |||
| reversed_dependency_arr.Reverse(); | |||
| return reversed_dependency_arr; | |||
| } | |||
| public static string pretty_print_node_path(IEnumerable<TrackableReference> paths) | |||
| { | |||
| if (paths.Count() == 0) | |||
| { | |||
| return "root object"; | |||
| } | |||
| else | |||
| { | |||
| return $"root.{string.Join(".", paths.Select(x => x.Name))}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Variables; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -566,5 +566,23 @@ namespace Tensorflow | |||
| else | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public static bool inside_function() | |||
| { | |||
| return get_default_graph().building_function; | |||
| } | |||
| public static void dismantle_graph(Graph graph) | |||
| { | |||
| } | |||
| public class NullContextManager: IDisposable | |||
| { | |||
| public void Dispose() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,31 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Engine; | |||
| public abstract partial class Layer | |||
| { | |||
| public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | |||
| public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
| public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||
| { | |||
| IDictionary<string, Trackable> children; | |||
| if (save_type == SaveType.SAVEDMODEL) | |||
| { | |||
| // TODO: deal with cache. | |||
| children = TrackableSavedModelSaver.trackable_children(cache); | |||
| } | |||
| else | |||
| { | |||
| children = new Dictionary<string, Trackable>(); | |||
| } | |||
| return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| @@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine | |||
| public bool Built => built; | |||
| public bool Trainable => args.Trainable; | |||
| public TF_DataType DType => args.DType; | |||
| public bool AutoCast => args.Autocast; | |||
| public IRegularizer ActivityRegularizer => args.ActivityRegularizer; | |||
| /// <summary> | |||
| /// A stateful layer is a layer whose updates are run during inference too, | |||
| @@ -162,7 +164,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| /// <param name="inputs"></param> | |||
| /// <param name="state"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <param name="training"></param> | |||
| /// <returns></returns> | |||
| protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| @@ -1,5 +1,7 @@ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.ModelSaving; | |||
| namespace Tensorflow.Keras.Engine | |||
| @@ -18,9 +20,18 @@ namespace Tensorflow.Keras.Engine | |||
| bool overwrite = true, | |||
| bool include_optimizer = true, | |||
| string save_format = "tf", | |||
| SaveOptions options = null) | |||
| SaveOptions? options = null, | |||
| IDictionary<string, ConcreteFunction>? signatures = null, | |||
| bool save_traces = true) | |||
| { | |||
| saver.save(this, filepath); | |||
| if (save_format != "pb") | |||
| { | |||
| saver.save(this, filepath); | |||
| } | |||
| else | |||
| { | |||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -35,6 +35,12 @@ namespace Tensorflow.Keras.Engine | |||
| bool _base_model_initialized; | |||
| bool stop_training; | |||
| DataHandler data_handler; | |||
| public OptimizerV2 Optimizer | |||
| { | |||
| get => optimizer; | |||
| set => optimizer = value; | |||
| } | |||
| public Model(ModelArgs args) | |||
| : base(args) | |||
| @@ -194,6 +194,18 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { | |||
| OnConstruction(); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public SavedObject(int nodeId, string nodePath, | |||
| global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata) | |||
| { | |||
| OnConstruction(); | |||
| nodeId_ = nodeId; | |||
| nodePath_ = nodePath; | |||
| identifier_ = identifier; | |||
| metadata_ = metadata; | |||
| version_ = version; | |||
| } | |||
| partial void OnConstruction(); | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| @@ -74,6 +74,13 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf { | |||
| public VersionDef() { | |||
| OnConstruction(); | |||
| } | |||
| [global::System.Diagnostics.DebuggerNonUserCodeAttribute] | |||
| public VersionDef(int producer, int minConsumer) { | |||
| OnConstruction(); | |||
| producer_ = producer; | |||
| minConsumer_ = minConsumer; | |||
| } | |||
| partial void OnConstruction(); | |||
| @@ -0,0 +1,41 @@ | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public static class Constants | |||
| { | |||
| /// <summary> | |||
| /// Namespace used to store all attributes added during serialization. | |||
| /// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an | |||
| /// object loaded from `tf.saved_model.load()`. | |||
| /// </summary> | |||
| public static readonly string KERAS_ATTR = "keras_api"; | |||
| /// <summary> | |||
| /// Keys for the serialization cache. | |||
| /// Maps to the keras serialization dict {Layer --> SerializedAttributes object} | |||
| /// </summary> | |||
| public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes"; | |||
| /// <summary> | |||
| /// Name of Keras metadata file stored in the SavedModel. | |||
| /// </summary> | |||
| public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb"; | |||
| public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer"; | |||
| public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer"; | |||
| public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric"; | |||
| public static readonly string MODEL_IDENTIFIER = "_tf_keras_model"; | |||
| public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network"; | |||
| public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer"; | |||
| public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential"; | |||
| public static readonly IList<string> KERAS_OBJECT_IDENTIFIERS = new List<string>() | |||
| { | |||
| INPUT_LAYER_IDENTIFIER, | |||
| LAYER_IDENTIFIER, | |||
| METRIC_IDENTIFIER, | |||
| MODEL_IDENTIFIER, | |||
| NETWORK_IDENTIFIER, | |||
| RNN_LAYER_IDENTIFIER, | |||
| SEQUENTIAL_IDENTIFIER | |||
| }; | |||
| } | |||
| @@ -0,0 +1,11 @@ | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public class KerasObjectWrapper | |||
| { | |||
| } | |||
| public class KerasObjectWrapper<T> | |||
| { | |||
| public T Item { get; set; } | |||
| } | |||
| @@ -0,0 +1,115 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Google.Protobuf; | |||
| using ICSharpCode.SharpZipLib.Zip; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.IO; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public partial class KerasSavedModelUtils | |||
| { | |||
| public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures, | |||
| SaveOptions? options, bool save_traces = true) | |||
| { | |||
| if (!overwrite && File.Exists(filepath)) | |||
| { | |||
| throw new Exception("The file already exists but is not allowed to overwrite it."); | |||
| } | |||
| if (save_traces) | |||
| { | |||
| if(should_skip_serialization(model)) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| OptimizerV2? orig_optimizer = null; | |||
| if (!include_optimizer) | |||
| { | |||
| orig_optimizer = model.Optimizer; | |||
| model.Optimizer = null; | |||
| model._delete_tracking("optimizer"); | |||
| } | |||
| IList<Trackable> saved_nodes; | |||
| IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths; | |||
| // skip two scopes of python | |||
| using (KerasSavedModelUtils.keras_option_scope(save_traces)) | |||
| { | |||
| (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options); | |||
| } | |||
| var metadata = generate_keras_metadata(saved_nodes, node_paths); | |||
| using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, | |||
| FileAccess.Write)) | |||
| { | |||
| var writer = new StreamWriter(f); | |||
| writer.Write(metadata.ToString()); | |||
| } | |||
| if (!include_optimizer) | |||
| { | |||
| model.Optimizer = orig_optimizer!; | |||
| } | |||
| } | |||
| public static SavedMetadata generate_keras_metadata(IList<Trackable> saved_nodes, | |||
| IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths) | |||
| { | |||
| var metadata = new SavedMetadata(); | |||
| for (int i = 0; i < saved_nodes.Count; i++) | |||
| { | |||
| var node = saved_nodes[i]; | |||
| if (node is not Layer) | |||
| { | |||
| continue; | |||
| } | |||
| Layer layer = (Layer)node; | |||
| var path = node_paths[node]; | |||
| string node_path; | |||
| if (path is null) | |||
| { | |||
| node_path = "root"; | |||
| } | |||
| else | |||
| { | |||
| node_path = $"root.{string.Join(".", path.Select(x => x.Name))}"; | |||
| } | |||
| ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new() | |||
| { | |||
| NodeId = i, | |||
| NodePath = node_path, | |||
| Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef() | |||
| { | |||
| Producer = 2, | |||
| MinConsumer = 1, | |||
| BadConsumers = { } | |||
| }, | |||
| Identifier = layer.ObjectIdentifier, | |||
| Metadata = layer.TrackingMetadata | |||
| }; | |||
| } | |||
| return metadata; | |||
| } | |||
| } | |||
| @@ -0,0 +1,19 @@ | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public partial class KerasSavedModelUtils | |||
| { | |||
| public static bool should_skip_serialization(object layer) | |||
| { | |||
| return false; | |||
| } | |||
| public static IDictionary<string, KerasObjectWrapper> wrap_layer_objects(Layer layer, object serialization_cache) | |||
| { | |||
| // TODO: process the loss | |||
| return null; | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public abstract class SavedModelSaver | |||
| { | |||
| private Trackable _obj; | |||
| public SavedModelSaver(Trackable obj) | |||
| { | |||
| _obj = obj; | |||
| } | |||
| public abstract string ObjectIdentifier { get; } | |||
| public abstract string TrackingMetadata { get; } | |||
| public abstract IDictionary<string, CheckpointableBase> objects_to_serialize( | |||
| IDictionary<string, object> serialization_cache); | |||
| public abstract IDictionary<string, Function> functions_to_serialize( | |||
| IDictionary<string, object> serialization_cache); | |||
| public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | |||
| { | |||
| if (!KerasSavedModelUtils.ShouldHaveTraces) | |||
| { | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| var children = objects_to_serialize(serialization_cache); | |||
| return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) | |||
| .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||
| .ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| @@ -0,0 +1,62 @@ | |||
| using System.Collections.Generic; | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public class LayerSavedModelSaver: SavedModelSaver | |||
| { | |||
| private Layer _obj; | |||
| public LayerSavedModelSaver(Layer obj): base(obj) | |||
| { | |||
| _obj = obj; | |||
| } | |||
| public override string ObjectIdentifier | |||
| { | |||
| get => Constants.LAYER_IDENTIFIER; | |||
| } | |||
| public override IDictionary<string, CheckpointableBase> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||
| { | |||
| throw new System.NotImplementedException(); | |||
| } | |||
| public override IDictionary<string, Function> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||
| { | |||
| throw new System.NotImplementedException(); | |||
| } | |||
| public override string TrackingMetadata | |||
| { | |||
| get | |||
| { | |||
| JObject metadata = new JObject(); | |||
| metadata["name"] = _obj.Name; | |||
| metadata["trainable"] = _obj.Trainable; | |||
| // metadata["expects_training_arg"] = _obj._expects_training_arg; | |||
| // metadata["dtype"] = policy.serialize(_obj._dtype_policy) | |||
| metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); | |||
| // metadata["stateful"] = _obj.stateful; | |||
| // metadata["must_restore_from_config"] = _obj.must_restore_from_config; | |||
| // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | |||
| metadata["autocast"] = _obj.AutoCast; | |||
| metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings | |||
| { | |||
| // Handle conflicts by using values from obj2 | |||
| MergeArrayHandling = MergeArrayHandling.Merge | |||
| }); | |||
| // skip the check of `input_spec` and `build_input_shape` for the lack of members. | |||
| // skip the check of `activity_regularizer` for the type problem. | |||
| return metadata.ToString(); | |||
| } | |||
| } | |||
| public static LayerConfig get_serialized(Layer obj) | |||
| { | |||
| return generic_utils.serialize_keras_object(obj); | |||
| } | |||
| } | |||
| @@ -0,0 +1,33 @@ | |||
| using System; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public partial class KerasSavedModelUtils | |||
| { | |||
| public static bool ShouldHaveTraces { get; internal set; } | |||
| public static SaveOptionsContext keras_option_scope(bool save_traces) | |||
| { | |||
| var res = new SaveOptionsContext(ShouldHaveTraces); | |||
| ShouldHaveTraces = save_traces; | |||
| return res; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Implementation of this class is different with that of python. | |||
| /// But it could be used with `using` the same as `with` of python. | |||
| /// </summary> | |||
| public class SaveOptionsContext: IDisposable | |||
| { | |||
| public bool _old_value; | |||
| public SaveOptionsContext(bool old_value) | |||
| { | |||
| _old_value = true; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| KerasSavedModelUtils.ShouldHaveTraces = _old_value; | |||
| } | |||
| } | |||
| @@ -0,0 +1,60 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow.NumPy; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Optimizers; | |||
| namespace TensorFlowNET.Keras.UnitTest; | |||
| // class MNISTLoader | |||
| // { | |||
| // public MNISTLoader() | |||
| // { | |||
| // var mnist = new MnistModelLoader() | |||
| // | |||
| // } | |||
| // } | |||
| [TestClass] | |||
| public class SaveTest | |||
| { | |||
| [TestMethod] | |||
| public void Test() | |||
| { | |||
| var inputs = new KerasInterface().Input((28, 28, 1)); | |||
| var x = new Flatten(new FlattenArgs()).Apply(inputs); | |||
| x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x); | |||
| x = new LayersApi().Dense(units: 10).Apply(x); | |||
| var outputs = new LayersApi().Softmax(axis: 1).Apply(x); | |||
| var model = new KerasInterface().Model(inputs, outputs); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); | |||
| var data_loader = new MnistModelLoader(); | |||
| var num_epochs = 1; | |||
| var batch_size = 50; | |||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||
| { | |||
| TrainDir = "mnist", | |||
| OneHot = false, | |||
| ValidationSize = 50000, | |||
| }).Result; | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| model.save("", save_format:"pb"); | |||
| } | |||
| } | |||
| @@ -47,7 +47,7 @@ | |||
| <ItemGroup> | |||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" /> | |||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" /> | |||
| <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" /> | |||
| <PackageReference Include="MSTest.TestFramework" Version="2.2.8" /> | |||