diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 8ae2dae8..9812d3c6 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -149,4 +149,13 @@ public static class CheckPointUtils // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); // } } + + /// + /// Traverse the object graph and list all accessible objects. + /// + /// + public static IList list_objects(ObjectGraphView graph_view) + { + return objects_ids_and_slot_variables_and_paths(graph_view).Item1; + } } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 0c2862da..a10e8953 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -6,8 +6,10 @@ using System.Linq; using Tensorflow.Contexts; using Tensorflow.Eager; using Tensorflow.Train; +using Tensorflow.Exceptions; using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; using static Tensorflow.Binding; +using Tensorflow.Operations; namespace Tensorflow.Checkpoint; @@ -21,8 +23,20 @@ public class TrackableSaver private TrackableObjectGraph _last_save_object_graph; private Tensor? _object_graph_feed_tensor = null; private Tensor? _file_prefix_feed_tensor = null; + private Tensor? _file_prefix_placeholder = null; private Dictionary? _object_map = null; private object? _cache = null; + public Tensor? FilePrefixPlaceHolder + { + get + { + return _file_prefix_placeholder; + } + set + { + _file_prefix_placeholder = value; + } + } public TrackableSaver(ObjectGraphView graph_view) { _graph_view = graph_view; @@ -192,4 +206,207 @@ public class TrackableSaver return save_path; } } + + public LoadStatus restore(string? save_path, CheckpointOptions? options = null) + { + if (options is null) + { + options = new CheckpointOptions(); + } + if(save_path is null) + { + return new InitializationOnlyStatus(_graph_view, ops.uid()); + } + + CheckpointReader reader = new CheckpointReader(save_path); + bool graph_building = tf.Context.executing_eagerly(); + Dictionary dtype_map = null; + if (!graph_building) + { + dtype_map = reader.VariableToDataTypeMap; + } + Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY); + + Dictionary file_prefix_feed_dict; + Tensor file_prefix_tensor; + if (graph_building) + { + if(_file_prefix_placeholder is null) + { + tf.device("/cpu:0"); + _file_prefix_placeholder = constant_op.constant("model"); + } + file_prefix_tensor = _file_prefix_placeholder; + file_prefix_feed_dict = new(); + file_prefix_feed_dict[_file_prefix_placeholder] = save_path; + } + else + { + tf.device("/cpu:0"); + file_prefix_tensor = constant_op.constant(save_path); + file_prefix_feed_dict = null; + } + TrackableObjectGraph object_graph_proto = new(); + object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); + CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( + object_graph_proto: object_graph_proto, + save_path: save_path, + save_path_tensor: file_prefix_tensor, + reader: reader, + restore_op_cache: null, + graph_view: _graph_view, + options: options, + saveables_cache: null + ); + + throw new NotImplementedException(); + } +} + +internal class CheckpointRestoreCoordinator +{ + private CheckpointOptions _options; + private TrackableObjectGraph _object_graph_proto; + private int _restore_uid; + private HashSet _matched_proto_ids; + private Tensor _save_path_tensor; + private string _save_path_string; + private CheckpointReader _reader; + private Dictionary _dtype_map; + private Dictionary _shape_map; + private ObjectGraphView _graph_view; + private Dictionary> _slot_restorations; + private bool _expect_partial_attr; + private List _restore_ops; + private List _all_trackables; + private Dictionary _object_by_proto_id; + + public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, + CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) + { + // TODO(Rinne): cache. + _options = options; + _object_graph_proto = object_graph_proto; + _restore_uid = ops.uid(); + _save_path_tensor = save_path_tensor; + _save_path_string = save_path; + _reader = reader; + if(_reader is null) + { + _reader = new CheckpointReader(save_path); + } + _dtype_map = _reader.VariableToDataTypeMap; + _shape_map = _reader.VariableToShapeMap; + _graph_view = graph_view; + _restore_ops = new List(); + _all_trackables = new List(); + _matched_proto_ids = new HashSet(); + _object_by_proto_id = new Dictionary(); + _slot_restorations = new Dictionary>(); + + _expect_partial_attr = false; + for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) + { + var node = _object_graph_proto.Nodes[i]; + foreach(var slot_reference in node.SlotVariables) + { + _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List()) + .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName)); + } + } + + // skip the deleter and cache. + } + + public bool ExpectPartial + { + get + { + return _expect_partial_attr; + } + set + { + _expect_partial_attr = value; + } + } + + public List AllTrackables => _all_trackables; + public HashSet MatchedProtoIds => _matched_proto_ids; + public Dictionary ObjectByProtoId => _object_by_proto_id; + public int RestoreUid => _restore_uid; + + public void new_restore_ops(IEnumerable new_ops) + { + _restore_ops.AddRange(new_ops); + // skip the callback. + } + + public List restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null) + { + throw new NotImplementedException(); + } +} + +public abstract class LoadStatus +{ + public abstract void assert_consumed(); + public abstract void assert_existing_objects_matched(); + public abstract void assert_nontrivial_match(); + public abstract void run_restore_ops(Session? session = null); + public abstract void initialize_or_restore(Session? session = null); + public virtual LoadStatus expect_partial() + { + return this; + } +} + +public class InitializationOnlyStatus: LoadStatus +{ + private int _restore_uid; + private ObjectGraphView _object_graph_view; + private Trackable _root; + public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid) + { + _restore_uid = restore_uid; + _object_graph_view = object_graph_view; + _root = object_graph_view.Root; + } + public override void assert_consumed() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override void assert_existing_objects_matched() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override void assert_nontrivial_match() + { + throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); + } + public override void run_restore_ops(Session? session = null) + { + throw new AssertionError("No checkpoint specified, so no restore ops are available " + + "(save_path=None to Saver.restore)."); + } + public override void initialize_or_restore(Session? session = null) + { + if (tf.Context.executing_eagerly()) + { + return; + } + if(session is null) + { + session = new Session(); + } + var trackable_objects = CheckPointUtils.list_objects(_object_graph_view); + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } +} + +public class CheckpointLoadStatus +{ + public CheckpointLoadStatus() + { + + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs new file mode 100644 index 00000000..2d8bf096 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Train; + +namespace Tensorflow.Checkpoint; + +internal class CheckpointPosition +{ + private CheckpointRestoreCoordinator _checkpoint; + private int _proto_id; + private bool _skip_restore; + public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id) + { + _checkpoint = checkpoint; + _proto_id = proto_id; + _skip_restore = false; + } + + public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; + + public void restore(Trackable trackable) + { + using (ops.init_scope()) + { + if (bind_project(trackable)) + { + + } + } + } + + /// + /// Set a checkpoint<->object correspondence. + /// + /// + /// + public bool bind_project(Trackable trackable) + { + _checkpoint.AllTrackables.Add(trackable); + _checkpoint.MatchedProtoIds.Add(_proto_id); + if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) + { + // skip the `logging.warning`. + return false; + } + else + { + _checkpoint.ObjectByProtoId[_proto_id] = trackable; + return true; + } + } + + public void gather_ops_or_named_saveables() + { + // skip the registered_saver + + + } + + /// + /// Restore the bound Trackable and dependencies (may be deferred). + /// + private void _restore_descendants() + { + Queue<(CheckpointPosition, Trackable)> visit_queue = new(); + visit_queue.Enqueue((this, this.Trackable)); + + } + + private void _single_restore() + { + var trackable = this.Trackable; + trackable._maybe_initialize_trackable(); + if(_checkpoint.RestoreUid > trackable.UpdateUid) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs index 5f08702d..142b8b64 100644 --- a/src/TensorFlowNET.Core/IO/gfile.cs +++ b/src/TensorFlowNET.Core/IO/gfile.cs @@ -16,8 +16,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow.IO { @@ -63,5 +65,15 @@ namespace Tensorflow.IO dirs.AddRange(Directory.GetFiles(dir)); return dirs.ToArray(); } + + public string join(params string[] paths) + { + Debug.Assert(paths.Length >= 1); + if (paths[0].Substring(1).Contains("://")) + { + throw new NotImplementedException("The combination of urls has not been implemented."); + } + return Path.Combine(paths); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs index 4e190605..dfd8735b 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Common public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { - var axis = serializer.Deserialize(reader, typeof(long[])); + var axis = serializer.Deserialize(reader, typeof(int[])); if (axis is null) { throw new ValueError("Cannot deserialize 'null' to `Axis`."); diff --git a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs index 4437ba0a..9ff38129 100644 --- a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs +++ b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Train; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow.ModelSaving { diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 1b1fa003..6ce7a0b0 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,8 +17,8 @@ using System; using System.Linq; using Tensorflow.Framework; -using Tensorflow.ModelSaving; using Tensorflow.Train; +using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs new file mode 100644 index 00000000..df9bdc1b --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public record class LoadOptions + { + public bool allow_partial_checkpoint; + public string experimental_io_device; + public bool experimental_skip_checkpoint; + public VariablePolicy experimental_variable_policy; + + public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null, + bool experimental_skip_checkpoint = false, string experimental_variable_policy = null) + { + this.allow_partial_checkpoint = allow_partial_checkpoint; + this.experimental_io_device = experimental_io_device; + this.experimental_skip_checkpoint = experimental_skip_checkpoint; + this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy); + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs index fe0403c3..60188293 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -1,4 +1,5 @@ -using Tensorflow.Train; +using System; +using Tensorflow.Train; namespace Tensorflow; @@ -14,4 +15,10 @@ public class RevivedTypes // TODO: complete the implementation. return null; } + + public static Tuple> deserialize(object proto) + { + // TODO: complete the implementation. + return null; + } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs similarity index 83% rename from src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs rename to src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs index 45ebd884..d42f5253 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.Text; -namespace Tensorflow.ModelSaving +namespace Tensorflow { /// /// Options for saving to SavedModel. @@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving public bool save_variable_devices() { - return this != VariablePolicy.None; + return this != None; } /// @@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving /// public static VariablePolicy from_obj(object obj) { - if (obj is null) return VariablePolicy.None; + if (obj is null) return None; if (obj is VariablePolicy) return (VariablePolicy)obj; var key = obj.ToString().ToLower(); return key switch { - null => VariablePolicy.None, - "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, - "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, + null => None, + "save_variable_devices" => SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") }; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 1be54287..5752d728 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -5,7 +5,6 @@ using System.Linq; using Tensorflow.Checkpoint; using Tensorflow.Contexts; using Tensorflow.Functions; -using Tensorflow.ModelSaving; using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs new file mode 100644 index 00000000..1f8d1a01 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -0,0 +1,601 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using Tensorflow.Checkpoint; +using Tensorflow.Train; +using Tensorflow.Training; +using pbc = global::Google.Protobuf.Collections; +using static Tensorflow.Binding; +using System.Runtime.CompilerServices; +using Tensorflow.Variables; + +namespace Tensorflow +{ + /// + /// Helper class to load an object-based SavedModel. + /// + public partial class Loader + { + private pbc::RepeatedField _asset_file_def; + private Dictionary> _operation_attributes; + private SavedObjectGraph _proto; + private string _export_dir; + private CheckpointOptions _checkpoint_options; + private LoadOptions _save_options; + private IDictionary)> _node_filters; + private Dictionary? _node_path_to_id; + private List? _filtered_nodes; + private List _ordered_node_ids; + private Dictionary)> _loaded_nodes; + private List _nodes; + private Dictionary> _node_setters; + public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, + CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary)> filters) + { + var meta_graph = saved_model_proto.MetaGraphs[0]; + _asset_file_def = meta_graph.AssetFileDef; + _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); + _proto = object_graph_proto; + _export_dir = export_dir; + // TODO: `this._concrete_functions` and `this._restored_concrete_functions` + _checkpoint_options = ckpt_options; + _save_options = save_options; + + // TODO: `this._pretty_printer` + + _node_filters = filters; + _node_path_to_id = _convert_node_paths_to_ints(); + _loaded_nodes = new Dictionary)>(); + foreach(var filter in filters) + { + _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; + } + + _filtered_nodes = _retrieve_all_filtered_nodes(); + + _ordered_node_ids = _generate_ordered_node_ids(); + + _load_all(); + + + if (!save_options.experimental_skip_checkpoint) + { + // TODO: implement it. + } + foreach(var node in _nodes) + { + // skip the process of `CapturableResource`. + } + } + + /// + /// Maps all string node paths in node_filters to the int node ids. + /// + /// + private Dictionary? _convert_node_paths_to_ints() + { + if( _node_filters is null) + { + return null; + } + Dictionary path_to_int = new(); + foreach(var node_id in _node_filters.Keys) + { + int int_node_id; + var node_path = node_id.Split('.'); + if (node_path[0] != "root") + { + throw new ValueError($"When passing string identifiers to node_filters, the first name" + + $" must be root. Received {node_path[0]}."); + } + int_node_id = 0; + for(int i = 0; i < node_path.Length - 1; i++) + { + var name = node_path[i + 1]; + int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1))); + } + path_to_int[node_id] = int_node_id; + } + return path_to_int; + } + + private int _find_node_child(int node_id, string child_name, string path) + { + foreach(var refer in _proto.Nodes[node_id].Children) + { + if(refer.LocalName == child_name) + { + return refer.NodeId; + } + } + throw new ValueError($"Unable to find node {path}."); + } + + private List? _retrieve_all_filtered_nodes() + { + if(_node_filters is null) + { + return null; + } + + HashSet all_filtered_nodes = new(); + Queue nodes_to_visit = new Queue(_node_filters.Keys); + + while(nodes_to_visit.Count > 0) + { + var node_path = nodes_to_visit.Dequeue(); + var node_id = _node_path_to_id[node_path]; + if (all_filtered_nodes.Contains(node_id)) + { + continue; + } + all_filtered_nodes.Add(node_id); + Trackable node = null; + Action setter = null; + if(_loaded_nodes.TryGetValue(node_id, out var res)) + { + (node, setter) = res; + } + if(node is not null) + { + node._maybe_initialize_trackable(); + } + + foreach(var refer in _proto.Nodes[node_id].Children) + { + Trackable children_object = null; + if(_loaded_nodes.TryGetValue(refer.NodeId, out var result)) + { + children_object = result.Item1; + } + // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict. + if(children_object is null && node is not null) + { + children_object = node._lookup_dependency(refer.LocalName); + if(children_object is TrackableDataStructure) + { + // TODO: set setter as lambda. + + _loaded_nodes[refer.NodeId] = (children_object, setter); + } + } + string child_path = $"{node_path}.{refer.LocalName}"; + _node_path_to_id[child_path] = refer.NodeId; + nodes_to_visit.Enqueue(child_path); + } + } + + if (all_filtered_nodes.Contains(0)) + { + return null; + } + return all_filtered_nodes.ToList(); + } + + /// + /// Orders the node ids so that dependencies appear first. + /// + /// + private List _generate_ordered_node_ids() + { + List unordered_ids; + if(_filtered_nodes is null) + { + unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList(); + } + else + { + unordered_ids = new List(_filtered_nodes); + } + + Dictionary> dependency_map = new(); + foreach(var node_id in unordered_ids) + { + var deps = dependency_map.SetDefault(node_id, new List()); + if (_loaded_nodes.ContainsKey(node_id)) + { + continue; + } + var proto = _proto.Nodes[node_id]; + foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) + { + deps.Add(dep); + if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) + { + // TODO: add info with `_pretty_printer`. + throw new ValueError($"Unable to partially load SavedModel since the specified filter " + + $"does not include all required objects for loading (e.g. " + + $"variables used in functions or deserialization dependencies). " + + $"Please include this path in the filter: {dep}"); + } + } + int? prev_slot = null; + foreach(var slot_variable_proto in proto.SlotVariables) + { + var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; + // The optimizer and original variable must be created before the slot + // variable, since the slot variable is generated using the Optimizer's + // add_slot API. + var slot_deps = dependency_map[slot_variable_node_id]; + slot_deps.Add(node_id); + slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); + + if(prev_slot is not null) + { + slot_deps.Add(prev_slot.Value); + } + prev_slot = slot_variable_node_id; + } + } + try + { + return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable)); + } + catch (TrackableUtils.CyclicDependencyError ex) + { + throw new ValueError("Encountered a cycle in the deserialization dependencies" + + "in the SavedModel. This is extremely unexpected, please" + + "file a bug and make sure you are not manually modifying the SavedModel."); + } + } + + /// + /// Returns a dictionary of all dependencies of an object. + /// + /// + /// + private Dictionary, int> _get_node_dependencies(SavedObject proto) + { + Dictionary, int> dependencies = new(); + foreach(var refer in proto.Dependencies) + { + dependencies[refer.LocalName] = refer.NodeId; + } + if(proto.KindCase == SavedObject.KindOneofCase.Function) + { + var concreete_functions = proto.Function.ConcreteFunctions; + foreach(var fn_name in concreete_functions) + { + foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) + { + dependencies[bound_input] = bound_input; + } + } + } + else if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction) + { + var fn_name = proto.BareConcreteFunction.ConcreteFunctionName; + foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) + { + dependencies[bound_input] = bound_input; + } + } + else if(proto.KindCase == SavedObject.KindOneofCase.Resource) + { + foreach(var child in proto.Children) + { + if(child.LocalName == "_create_resource") + { + dependencies["_create_resource"] = child.NodeId; + } + } + } + return dependencies; + } + + /// + /// Loads all nodes and functions from the SavedModel and their edges. + /// + private void _load_all() + { + _load_nodes(); + _load_edges(); + + _setup_remaining_functions(); + _load_checkpoint_save_and_restore_functions(); + } + + /// + /// Restores the checkpoint-related save/restore functions to all nodes. + /// + private void _load_checkpoint_save_and_restore_functions() + { + foreach(var (node_id, proto) in _iter_all_nodes()) + { + var node = get(node_id); + if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) + { + // Restore Trackable serialize- and restore-from-tensor functions. + Debug.Assert(proto.SaveableObjects.Count == 1); + var saveable_object_proto = proto.SaveableObjects.Values.First(); + var save_fn_id = saveable_object_proto.SaveFunction; + var restore_fn_id = saveable_object_proto.RestoreFunction; + + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + else + { + // Restore legacy SaveableObject functions. + Dictionary saveable_fn_by_name = new(); + foreach(var item in proto.SaveableObjects) + { + var name = item.Key; + var saveable_object_proto = item.Value; + var save_fn_id = saveable_object_proto.SaveFunction; + var restore_fn_id = saveable_object_proto.RestoreFunction; + saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); + } + node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); + } + } + } + + /// + /// Load all saved objects. + /// + private void _load_nodes() + { + // `nodes` maps from node ids to recreated objects + // `node_setters` maps from node ids to setter functions + // (same signature as setattr) for setting children. + var (nodes, node_setters) = _initialize_loaded_nodes(); + + Dictionary + slot_variable_node_ids = new(); + + foreach(var (node_id, proto) in _iter_all_nodes()) + { + foreach(var slot_variable_proto in proto.SlotVariables) + { + var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; + slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto); + } + } + + // Re-create everything. + foreach (var (node_id, proto) in _iter_all_nodes()) + { + if (nodes.ContainsKey(node_id)) + { + continue; + } + else if (slot_variable_node_ids.ContainsKey(node_id)) + { + // Use the public Optimizer interface when creating slot variables. + var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; + var optimizer_object = nodes[optimizer_node_id]; + var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; + + // TODO: implement it. + throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + + " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + else + { + var (node, setter) = _recreate(proto, node_id, nodes); + nodes[node_id] = node; + node_setters[node_id] = setter; + } + } + + if (!nodes.ContainsKey(0)) + { + nodes[0] = _recreate_base_user_object().Item1; + } + _nodes = new List(); + for(int i = 0; i < _proto.Nodes.Count; i++) + { + _nodes.Add(nodes[i]); + } + _node_setters = node_setters; + } + + /// + /// Load state from checkpoint into the deserialized objects. + /// + private void _restore_checkpoint() + { + var variables_path = SavedModelUtils.get_variables_dir(_export_dir); + var saver = new TrackableSaver(new ObjectGraphView(get(0))); + tf.device("CPU"); + saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + if (_save_options.allow_partial_checkpoint) + { + + } + } + + /// + /// Adds edges from objects to other objects and functions. + /// + private void _load_edges() + { + foreach(var (node_id, object_proto) in _iter_all_nodes()) + { + _add_object_graph_edges(object_proto, node_id); + } + + if(_filtered_nodes is not null && _filtered_nodes.Contains(0)) + { + var root = get(0); + foreach(var node_path in _node_filters.Keys) + { + var loaded_node = _nodes[_node_path_to_id[node_path]]; + + var path = node_path.Split('.'); + var current_node = root; + foreach(var name in path.Skip(1).Take(path.Length - 2)) + { + // `hasattr` and `setattr` is used here + throw new NotImplementedException(); + } + // `hasattr` and `setattr` is used here + throw new NotImplementedException(); + } + } + } + + private void _setup_remaining_functions() + { + // TODO: implement it with concrete functions. + } + + public Trackable get(int node_id) + { + return _nodes[node_id]; + } + + public Trackable get(string node_id) + { + return get(_node_path_to_id[node_id]); + } + + /// + /// Adds edges from an object to its children. + /// + /// + /// + private void _add_object_graph_edges(SavedObject proto, int node_id) + { + var obj = _nodes[node_id]; + var setter = _node_setters[node_id]; + + foreach(var refer in proto.Children) + { + setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); + // skip the process of "__call__" + } + } + + private (Dictionary, Dictionary>) _initialize_loaded_nodes() + { + Dictionary nodes = new(); + Dictionary> node_setters = new(); + foreach(var item in _loaded_nodes) + { + var node_id = item.Key; + var (node, setter) = item.Value; + nodes[node_id] = node; + node_setters[node_id] = setter; + } + return (nodes, node_setters); + } + + private IEnumerable<(int, SavedObject)> _iter_all_nodes() + { + foreach(var node_id in _ordered_node_ids) + { + yield return (node_id, _proto.Nodes[node_id]); + } + } + + private (Trackable, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) + { + // skip the registered classes. + + Dictionary, Trackable> dependencies = new(); + foreach(var item in _get_node_dependencies(proto)) + { + dependencies[item.Key] = nodes[item.Value]; + } + + return _recreate_default(proto, node_id, dependencies); + } + + /// + /// Creates a Python object from a SavedObject protocol buffer. + /// + /// + /// + /// + private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, Trackable> dependencies) + { + return proto.KindCase switch + { + SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), + SavedObject.KindOneofCase.Function => throw new NotImplementedException(), + SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), + SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), + SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() + }; + } + + private (Trackable, Action) _recreate_user_object(SavedUserObject? proto, int node_id) + { + // skip the check of proto identifier because of lack of property. + + var looked_up = RevivedTypes.deserialize(proto); + if(looked_up is null) + { + return _recreate_base_user_object(proto, node_id); + } + return (looked_up.Item1, looked_up.Item2); + } + + private (Trackable, Action) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) + { + return (new _UserObject(), setattr); + } + + private (BaseResourceVariable, Action) _recreate_variable(SavedVariable proto) + { + string name = proto.Name; + string dbg_name = !string.IsNullOrEmpty(name) ? name : ""; + + // TODO(Rinne): `validate_synchronization_aggregation_trainable` + + var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable( + proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name); + + var saved_device = proto.Device; + var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device); + + if (load_with_device) + { + tf.device(saved_device); + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + } + else + { + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + } + } + + // TODO: remove this to a common class. + public static Action setattr = (x, y, z) => + { + Debug.Assert(y is string); + var properties = x.GetType().GetProperties(); + foreach(var p in properties) + { + if((string)y == p.Name) + { + p.SetValue(x, z); + return; + } + } + // TODO(Rinne): check if the property has been set successfully. + //throw new ValueError($"Cannot find the property {y} of {x}."); + }; + + public class _UserObject: AutoTrackable + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs new file mode 100644 index 00000000..a92cb550 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs @@ -0,0 +1,122 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using Tensorflow.Checkpoint; +using Tensorflow.Operations; +using Tensorflow.Train; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public partial class Loader + { + public static SavedModel parse_saved_model(string export_dir) + { + var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT); + var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB); + + SavedModel saved_model = new SavedModel(); + if (File.Exists(path_to_pb)) + { + byte[] file_content; + using(var f = new FileStream(path_to_pb, FileMode.Open, FileAccess.Read)) + { + file_content = new byte[f.Length]; + Debug.Assert(f.Length <= int.MaxValue); + f.Read(file_content, 0, (int)f.Length); + } + // TODO: change to stream mode. + saved_model.MergeFrom(file_content); + return saved_model; + } + else if (File.Exists(path_to_pbtxt)) + { + throw new NotImplementedException(); + } + else + { + throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.PathSeparator}" + + $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}"); + } + } + + // TODO: revise the type of `tags` + public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null) + { + return load_partial(export_dir, null, tags, options)["root"]; + } + + public static IDictionary load_partial(string export_dir, IDictionary)>? filters, object? tags = null, LoadOptions? options = null) + { + if (options is null) + { + options = new LoadOptions(); + } + if (tags is not null) + { + throw new NotImplementedException(); + } + var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir); + + Trackable root = null; + Loader loader = null; + if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null) + { + // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)` + var meta_graph_def = saved_model_proto.MetaGraphs[0]; + if (!BitConverter.IsLittleEndian) + { + SavedModelUtils.swap_function_tensor_content(meta_graph_def); + } + + var object_graph_proto = meta_graph_def.ObjectGraphDef; + var ckpt_options = new CheckpointOptions(options.experimental_io_device); + tf_with(ops.init_scope(), x => + { + loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); + root = loader.get(0); + // skip the assignment of `graph_debug_info`. + }); + // skip the assignment of `tensorflow_version` + // skip the assignment of `tensorflow_git_version` + // skip the process of `metrics`. + } + else + { + if(filters is not null && filters.Count > 0) + { + throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" + + " version) cannot be loaded with node filters."); + } + tf_with(ops.init_scope(), x => + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + }); + } + if(filters != null && filters.Count > 0) + { + return filters.Keys.ToDictionary(x => x, x => loader.get(x)); + } + else + { + var res = new Dictionary(); + res["root"] = root; + return res; + } + } + + public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir) + { + var saved_model = parse_saved_model(export_dir); + + // TODO: implement debug info. + + return (saved_model, null); + } + + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index 94760e3d..4313920f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -6,7 +6,6 @@ using System.Text; using Google.Protobuf; using Tensorflow.Checkpoint; using Tensorflow.Functions; -using Tensorflow.ModelSaving; using Tensorflow.Train; using Tensorflow.Exceptions; using static Tensorflow.Binding; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs index 4cfe0b69..47d8cbab 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Text; -using Tensorflow.ModelSaving; namespace Tensorflow.Training.Saving.SavedModel { diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index a6e21e3e..cc9be7a2 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -333,6 +333,21 @@ namespace Tensorflow return restored_ops; }; } + + /// + /// Returns a dict of SaveableObject factories generated from loaded fns. + /// + /// + /// + public static IDictionary> recreate_saveable_objects( + IDictionary saveable_fn_by_name, IEnumerable? temp_session) + { + if (saveable_fn_by_name.Count > 0) + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + return new Dictionary>(); + } } public class SaveableCompatibilityConverter: Trackable diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 132571f2..a1de569e 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -20,8 +20,8 @@ using System.Diagnostics; using System.Linq; using Tensorflow.Checkpoint; using Tensorflow.Keras.Saving.SavedModel; -using Tensorflow.ModelSaving; using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; using static Tensorflow.Binding; namespace Tensorflow.Train @@ -71,6 +71,17 @@ namespace Tensorflow.Train public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } + public IDictionary> SelfSaveableObjectFactories + { + get + { + return _self_saveable_object_factories; + } + set + { + _self_saveable_object_factories = value; + } + } /// /// Restore-on-create for a variable be saved with this `Checkpointable`. @@ -259,4 +270,6 @@ namespace Tensorflow.Train } public record class TrackableReference(string Name, Trackable Refer); + + public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); } diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 390d95c7..8d513191 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Checkpoint; using Tensorflow.Exceptions; using Tensorflow.Train; diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 4005d564..97203604 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -5,9 +5,9 @@ using Tensorflow.Variables; using Tensorflow.Train; using static Tensorflow.Binding; using System.Collections.Generic; -using Tensorflow.ModelSaving; using System.Diagnostics; using Tensorflow.Checkpoint; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow { @@ -19,7 +19,11 @@ namespace Tensorflow protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; - protected string handle_name => _handle_name; + public string handle_name + { + get { return _handle_name; } + set { _handle_name = value; } + } protected string _unique_id; public string UniqueId => _unique_id; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 1645d713..3b1f1e96 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -238,5 +238,23 @@ namespace Tensorflow { return _graph_element.eval(session); } + + public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable( + VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name) + { + if(aggregation is null) + { + aggregation = VariableAggregation.None; + } + if(synchronization is null) + { + synchronization = VariableSynchronization.Auto; + } + if (trainable is null) + { + trainable = synchronization != VariableSynchronization.OnRead; + } + return (synchronization.Value, aggregation.Value, trainable.Value); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs index b0d1b2b6..2ea1b82e 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config) + public static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config) { // Layer instances created during the graph reconstruction process. var created_layers = new Dictionary(); diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 44eaef53..61eae06e 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine Inputs = inputs, Outputs = outputs }) + { + Initialize(inputs, outputs, name); + } + + internal void Initialize(Tensors inputs, Tensors outputs, string name = null) { _input_layers = new List(); _output_layers = new List(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 31b37d68..81f3a7d9 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Linq; @@ -75,7 +76,17 @@ namespace Tensorflow.Keras.Engine public int Id => id; protected string name; protected string base_name; - public string Name => name; + public string Name + { + get + { + return name; + } + set + { + name = value; + } + } protected bool computePreviousMask; protected List updates; @@ -89,6 +100,8 @@ namespace Tensorflow.Keras.Engine List outboundNodes; public List OutboundNodes => outboundNodes; + public JObject SerializedAttributes { get; set; } + ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; public Tensor[] input @@ -117,6 +130,11 @@ namespace Tensorflow.Keras.Engine protected List _self_tracked_trackables; public Layer(LayerArgs args) + { + Initialize(args); + } + + internal virtual void Initialize(LayerArgs args) { this.args = args; // A stateful layer is a layer whose updates are run during inference too, diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index a1e891f9..a3956ccc 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine { using (SharedObjectSavingScope.Enter()) { - KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index dd3e11a2..2a2a3662 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; + + public bool IsGraphNetwork => _is_graph_network; public OptimizerV2 Optimizer { @@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine _init_batch_counters(); } + internal override void Initialize(LayerArgs args) + { + _init_batch_counters(); + base.Initialize(args); + } + void _configure_steps_per_execution(int steps_per_execution) { _steps_per_execution = tf.Variable(steps_per_execution, diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 4d87659b..69665388 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine : base(args.Inputs, args.Outputs, name: args.Name) { this.args = args; - if (args.Layers == null) - args.Layers = new List(); // SupportsMasking = true; _compute_output_and_mask_jointly = true; _auto_track_sub_layers = false; @@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine _created_nodes = new List(); // Add to the model any layers passed to the constructor. - if (args.Layers != null) + if (args.Layers is not null) { - foreach (var layer in args.Layers) - add(layer); + InitLayers(args.Layers); + } + } + + public void InitLayers(IEnumerable layers) + { + foreach(var layer in layers) + { + add(layer); } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 45f64720..9cb5b756 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers { { throw new ValueError("Alpha must be a number greater than 0."); } - _buildInputShape = input_shape; - built = true; + base.build(input_shape); } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index 2fd2caee..981f96f0 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers { } public override void build(Shape input_shape) { - _buildInputShape = input_shape; - built = true; + base.build(input_shape); } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 1ef8d0e5..9b5bc0e6 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers { if ( alpha < 0f ) { throw new ValueError("Alpha must be a number greater than 0."); } - _buildInputShape = input_shape; - built = true; + base.build(input_shape); } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index fc8cab0c..b378ea64 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -1,12 +1,24 @@ using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; using System.Linq; +using System.Reflection; using System.Text.RegularExpressions; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Utils; +using Tensorflow.Train; +using Tensorflow.Training; using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.ApiDef.Types; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving { public class KerasObjectLoader { - SavedMetadata _metadata; - SavedObjectGraph _proto; - Dictionary _node_paths = new Dictionary(); - Dictionary model_layer_dependencies = new Dictionary(); - List _traversed_nodes_from_config = new List(); + private static readonly IDictionary PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; + private SavedMetadata _metadata; + private SavedObjectGraph _proto; + private Dictionary _node_paths = new Dictionary(); + private Dictionary model_layer_ids_dependencies = new Dictionary(); + private Dictionary model_layer_dependencies = new Dictionary(); + private List _traversed_nodes_from_config = new List(); + private Dictionary)> loaded_nodes; + private List _models_to_reconstruct; + public Dictionary)> LoadedNodes => loaded_nodes; + + static KerasObjectLoader() + { + PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; + } public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) { _metadata = metadata; _proto = object_graph_def; _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); + _models_to_reconstruct = new List(); + loaded_nodes = new Dictionary)>(); } /// @@ -42,15 +66,251 @@ namespace Tensorflow.Keras.Saving continue; } - _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); + loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); + } + foreach(var node_metadata in metric_list) + { + try + { + loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, + node_metadata.Metadata); + } + catch(ValueError e) + { + if (compile) + { + throw e; + } + // TODO: add logging.warning. + } + } + } + + public string get_path(int node_id) + { + return _node_paths[node_id]; + } + + /// + /// Finish setting up Keras objects. + /// + /// This function is executed after all objects and functions have been created. + /// Call functions and losses are attached to each layer, and once all layers + /// have been fully set up, graph networks are initialized. + /// + /// Subclassed models that are revived from the SavedModel are treated like + /// layers, and have their call/loss functions attached here. + /// + public void finalize_objects() + { + List layers_revived_from_config = new(); + List layers_revived_from_saved_model = new(); + foreach(var item in loaded_nodes) + { + var node_id = item.Key; + var node = item.Value.Item1; + if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id)) + { + continue; + } + + _unblock_model_reconstruction(node_id, node as Layer); + + if(node is InputLayer or Metric) + { + continue; + } + + // TODO: deal with `RevivedLayer` and `RevivedInputLayer`. + layers_revived_from_config.Add(node as Layer); + } + + _finalize_saved_model_layers(layers_revived_from_saved_model); + _finalize_config_layers(layers_revived_from_config); + + _reconstruct_all_models(); + } + + private void _reconstruct_all_models() + { + HashSet all_initialized_models = new(); + for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--) + { + int model_id = _models_to_reconstruct[i]; + all_initialized_models.Add(model_id); + var (model, layers) = model_layer_dependencies[model_id]; + _reconstruct_model(model_id, model, layers.ToList()); + _finalize_config_layers(new List() { model }); + } + + Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys)); + } + + private void _reconstruct_model(int model_id, Model model, List layers) + { + var config = JsonConvert.DeserializeObject(_metadata.Nodes[model_id].Metadata)["config"]; + + if(model.input is not null && model.input.Length > 0) + { + + } + else if(model is Sequential s) + { + if(layers is null || layers.Count == 0 || layers[0] is not InputLayer) + { + if (config["layers"][0]["class_name"].ToObject() == "InputLayer") + { + layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject())); + } + else if (config["layers"][0]["config"]["batch_input_shape"] is not null) + { + // TODO: implement it + } + } + + // `model.__init__(layers, config["name"])` + s.InitLayers(layers); + s.Name = config["name"].ToObject(); + if(s.input is null || s.input.Length == 0) + { + var first_layer = _get_child_layer_node_ids(model_id)[0]; + var input_specs = _infer_inputs(first_layer); + var input_shapes = _infer_inputs(first_layer, true); + // `model._set_inputs(input_specs)` + + // skip the check of input_specs is Dictionary + if (!s.Built) + { + s.build(input_shapes); + } + } + } + else + { + // skip the parameter `created_layers`. + var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject()); + // skip the `model.__init__` + (model as Functional).Initialize(inputs, outputs, config["name"].ToObject()); + (model as Functional).connect_ancillary_layers(created_layers); } + + _set_network_attributes_from_metadata(model); + _unblock_model_reconstruction(model_id, model); } - void _load_layer(int node_id, string identifier, string metadata_json) + private void _set_network_attributes_from_metadata(Model revived_object) + { + // TODO: implement it. + } + + /// + /// Runs the final steps of loading Keras Layers from config. + /// + /// + private void _finalize_config_layers(List layers) + { + foreach(var layer in layers) + { + if (_is_graph_network(layer)) + { + _restore_layer_unconditional_losses(layer); + } + _restore_layer_activation_loss(layer); + _restore_layer_metrics(layer); + + // TODO(Rinne): deal with RNN. + } + } + + /// + /// Runs the final steps of loading Keras Layers from SavedModel. + /// + /// + private void _finalize_saved_model_layers(List layers) + { + foreach(var layer in layers) + { + // TODO(Rinne): deal with `RevivedNetwork`. + + _restore_layer_unconditional_losses(layer); + _restore_layer_activation_loss(layer); + _restore_layer_metrics(layer); + } + } + + private void _restore_layer_unconditional_losses(Layer layer) + { + // TODO(Rinne): implement it. + } + + private void _restore_layer_activation_loss(Layer layer) + { + // TODO(Rinne): implement it. + } + + private void _restore_layer_metrics(Layer layer) + { + // TODO(Rinne): implement it. + } + + /// + /// Removes layer from blocking model reconstruction. + /// + /// + /// + private void _unblock_model_reconstruction(int layer_id, Layer layer) + { + foreach(var depencency in model_layer_ids_dependencies) + { + var layer_ids = depencency.Value.Item2; + var layers = model_layer_dependencies.SetDefault(depencency.Key, + (depencency.Value.Item1, new Layer[depencency.Value.Item2.Length])).Item2; + if (!layer_ids.Contains(layer_id)) + { + continue; + } + layers[Array.IndexOf(layer_ids, layer_id)] = layer; + if (layers.All(x => x is not null)) + { + _models_to_reconstruct.Add(depencency.Key); + } + } + } + + private (Trackable, Action) _load_layer(int node_id, string identifier, string metadata_json) { metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); var metadata = JsonConvert.DeserializeObject(metadata_json); - _revive_from_config(identifier, metadata, node_id); + + if (loaded_nodes.ContainsKey(node_id)) + { + var (node, setter) = loaded_nodes[node_id]; + + _maybe_add_serialized_attributes(node as Layer, metadata); + var config = metadata.Config; + if(_is_graph_network(node as Layer) && generic_utils.validate_config(config)) + { + Debug.Assert(node is Model); + var child_nodes = _get_child_layer_node_ids(node_id); + model_layer_ids_dependencies[node_id] = (node as Model, child_nodes); + if(child_nodes is null || child_nodes.Length == 0) + { + _models_to_reconstruct.Add(node_id); + } + } + return (node, setter); + } + else + { + var (obj, setter) = _revive_from_config(identifier, metadata, node_id); + if (obj is null) + { + (obj, setter) = _revive_custom_object(identifier, metadata); + } + Debug.Assert(obj is Layer); + _maybe_add_serialized_attributes(obj as Layer, metadata); + return (obj, setter); + } } /// @@ -59,11 +319,32 @@ namespace Tensorflow.Keras.Saving /// /// /// - void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) + private (Trackable, Action) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) { - var obj = _revive_graph_network(identifier, metadata, node_id); - obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); + Trackable obj; + if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + else + { + obj = _revive_graph_network(identifier, metadata, node_id); + obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); + } + + if(obj is null) + { + return (null, null); + } + var setter = _config_node_setter(_revive_setter); _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); + return (obj, setter); + } + + private (Trackable, Action) _revive_custom_object(string identifier, KerasMetaData metadata) + { + // TODO: implement it. + throw new NotImplementedException(); } Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) @@ -71,6 +352,12 @@ namespace Tensorflow.Keras.Saving var config = metadata.Config; var class_name = metadata.ClassName; Model model = null; + + if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") + { + return null; + } + if (class_name == "Sequential") { model = new Sequential(new SequentialArgs @@ -78,34 +365,83 @@ namespace Tensorflow.Keras.Saving Name = config.GetValue("name").ToString() }); } - else if (class_name == "Functional") + else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) { - throw new NotImplementedException(""); + model = model = new Sequential(new SequentialArgs + { + Name = class_name + }); + } + else + { + // TODO: implement it. + throw new NotImplementedException("Not implemented"); } - - if (!metadata.IsGraphNetwork) - return null; // Record this model and its layers. This will later be used to reconstruct // the model. var layers = _get_child_layer_node_ids(node_id); - model_layer_dependencies[node_id] = (model, layers); + model_layer_ids_dependencies[node_id] = (model, layers); + if(layers is null || layers.Length == 0) + { + _models_to_reconstruct.Add(node_id); + } return model; } - Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) + Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) { var config = metadata.Config; var class_name = metadata.ClassName; var shared_object_id = metadata.SharedObjectId; var must_restore_from_config = metadata.MustRestoreFromConfig; - var obj = class_name switch - { - "Resizing" => Resizing.from_config(config), - _ => throw new NotImplementedException("") - }; + + var obj = generic_utils.deserialize_keras_object(class_name, config); + + obj.Name = metadata.Name; + // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` + + var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); - return null; + if (!built) + { + return null; + } + return obj; + } + + private void _revive_setter(object layer, object name, object value) + { + Debug.Assert(name is string); + Debug.Assert(layer is Layer); + if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) + { + if(value is Trackable) + { + (layer as Layer)._track_trackable(value as Trackable, name as string); + } + if((layer as Layer).SerializedAttributes is null) + { + (layer as Layer).SerializedAttributes = new JObject(); + } + (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); + } + else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) + { + (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); + } + else + { + var properties = layer.GetType().GetProperties(); + foreach(var p in properties) + { + if(p.Name == name as string && p.GetValue(layer) is not null) + { + return; + } + } + Loader.setattr(layer, name, value); + } } /// @@ -143,34 +479,186 @@ namespace Tensorflow.Keras.Saving /// /// /// - void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) + void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) { if (_traversed_nodes_from_config.Contains(node_id)) return; var parent_path = _node_paths[node_id]; _traversed_nodes_from_config.Add(node_id); - if (!obj.Built) + obj._maybe_initialize_trackable(); + + if(obj is Layer layer && !layer.Built) { - var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); - var metadata = JsonConvert.DeserializeObject(metadata_json); - _try_build_layer(obj, node_id, metadata.BuildInputShape); + var metadata = JsonConvert.DeserializeObject(_metadata.Nodes[node_id].Metadata); + _try_build_layer(layer, node_id, metadata.BuildInputShape); + } + + + List<(Trackable, int, string)> children = new(); + foreach(var refer in proto.Children) + { + var obj_child = obj._lookup_dependency(refer.LocalName); + children.Add((obj_child, refer.NodeId, refer.LocalName)); + } + + var metric_list_node_id = _search_for_child_node(node_id, new string[] { + Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" + }); + if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) + { + var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); + foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) + { + if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) + { + var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; + children.Add((metric, refer.NodeId, metric_path)); + } + } + } + + foreach(var (obj_child, child_id, child_name) in children) + { + if(obj_child is null) + { + continue; + } + var child_proto = _proto.Nodes[child_id]; + + // skip the check for registered identifier + + Action setter; + if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) + { + setter = _revive_setter; + } + else + { + setter = Loader.setattr; + } + + if (loaded_nodes.ContainsKey(child_id)) + { + // skip the logging.warning + continue; + } + + if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) + { + (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; + } + + if(obj_child is TrackableDataStructure) + { + setter = (x, y, z) => { }; + } + + var child_path = $"{parent_path}.{child_name}"; + _node_paths[child_id] = child_path; + _add_children_recreated_from_config(obj_child, child_proto, child_id); + loaded_nodes[child_id] = (obj_child, setter); } } - bool _try_build_layer(Model obj, int node_id, Shape build_input_shape) + private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) { if (obj.Built) return true; + if(build_input_shape is null) + { + build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); + } + + if(build_input_shape is not null) + { + obj.build(build_input_shape); + // In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. + // On the one hand, C# does not support call a method from specified parent class. + // On the other hand, currently All class derived from Layer call `Layer.Build` or + // move the implementation of `Layer.build` to its own `build` method. + // Therefore we do not call it here. + // However, it's still quite risky once in the future a certain class derived from + // `Layer` does not call `Layer.build`. + + return true; + } + return false; } - bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) + /// + /// Infers input shape of layer from SavedModel functions. + /// + /// + /// + /// + private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) { - if (obj.Built) - return true; + var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); + if(call_fn_id is null) + { + return null; + } + var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; + if(concrete_functions is null) + { + return null; + } + var call_fn_name = concrete_functions[0]; + var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + + private int? _search_for_child_node(int parent_id, IEnumerable path_to_child) + { + if(path_to_child is null || path_to_child.Count() == 0) + { + return parent_id; + } + + foreach(var child in _proto.Nodes[parent_id].Children) + { + if(child.LocalName == path_to_child.First()) + { + return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); + } + } + return null; + } + + private bool _is_graph_network(Layer layer) + { + // TODO: deal with `RevivedLayer` + if(layer is Functional) + { + return (layer as Functional).IsGraphNetwork || layer is Sequential; + } return false; } + + private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) + { + // TODO: deal with `RevivedLayer` + } + + /// + /// Creates edges for nodes that are recreated from config. + /// + /// + private Action _config_node_setter(Action setter) + { + void setattr_wrapper(object obj, object name, object value) + { + Debug.Assert(obj is Trackable); + Debug.Assert(name is string); + if((obj as Trackable)._lookup_dependency(name as string) is null) + { + setter(obj, name, value); + } + } + return setattr_wrapper; + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index c7b7e52f..b4e5a889 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, + public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, SaveOptions? options, bool save_traces = true) { if (!overwrite && File.Exists(filepath)) diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs new file mode 100644 index 00000000..e7cb5b3a --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs @@ -0,0 +1,96 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using ThirdParty.Tensorflow.Python.Keras.Protobuf; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class KerasLoadModelUtils + { + /// + /// Corresponding to keras/saving/save.py/load_model + /// + /// + /// + /// + /// + /// + public static Trackable load_model(string filepath, IDictionary? custom_objects = null, + bool compile = true, LoadOptions? options = null) + { + using (SharedObjectSavingScope.Enter()) + { + using (LoadContext.load_context(options)) + { + if (!File.Exists(filepath) && !Directory.Exists(filepath)) + { + throw new IOException($"No file or directory found at {filepath}."); + } + if (Directory.Exists(filepath)) + { + return load(filepath, compile, options); + } + else + { + throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); + } + } + } + } + + public static Trackable load(string path, bool compile = true, LoadOptions? options = null) + { + SavedMetadata metadata = new SavedMetadata(); + var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; + var object_graph_def = meta_graph_def.ObjectGraphDef; + string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); + if (File.Exists(path_to_metadata_pb)) + { + metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); + } + else + { + throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + } + + if (metadata.Nodes is null || metadata.Nodes.Count == 0) + { + return Loader.load(path, options: options) as Model; + } + + var keras_loader = new KerasObjectLoader(metadata, object_graph_def); + keras_loader.load_layers(compile: compile); + + Dictionary)> nodes_to_load = new(); + nodes_to_load["root"] = (null, null); + foreach(var item in keras_loader.LoadedNodes) + { + nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; + } + var loaded = Loader.load_partial(path, nodes_to_load, options); + + keras_loader.finalize_objects(); + // keras_loader.del_tracking(); + + var model = loaded["root"]; + + if(model is Model && compile) + { + // TODO: implement it. + } + + if (!tf.Context.executing_eagerly()) + { + // TODO: implement it. + } + + return model; + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs new file mode 100644 index 00000000..11b1201d --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs @@ -0,0 +1,69 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using Tensorflow.Training.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + // TODO: remove this class to common project. + public class ContextHandler: IDisposable + { + public Action DisposeCallBack { get; set; } + public void Dispose() + { + DisposeCallBack.Invoke(true); + } + } + public class LoadContext + { + private bool _entered_load_context; + private LoadOptions? _load_options; + private static ThreadLocal _load_context = new(); + private LoadContext() + { + _entered_load_context = false; + _load_options = null; + } + + public void set_load_options(LoadOptions load_options) + { + _load_options = load_options; + _entered_load_context = true; + } + + private void clear_load_options() + { + _load_options = null; + _entered_load_context = false; + } + + private LoadOptions? load_options() + { + return _load_options; + } + + public static ContextHandler load_context(LoadOptions? load_options) + { + if(_load_context.Value is null) + { + _load_context.Value = new LoadContext(); + } + _load_context.Value.set_load_options(load_options); + return new ContextHandler() + { + DisposeCallBack = _ => _load_context.Value.clear_load_options() + }; + } + + public static LoadOptions? get_load_option() + { + return _load_context.Value.load_options(); + } + + public static bool in_load_context() + { + return _load_context.Value._entered_load_context; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 730a33e3..fffa4b8a 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -22,12 +22,16 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Saving; +using Tensorflow.Train; namespace Tensorflow.Keras.Utils { public class generic_utils { + private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; /// /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. /// @@ -51,6 +55,21 @@ namespace Tensorflow.Keras.Utils return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); } + public static Layer deserialize_keras_object(string class_name, JObject config) + { + return class_name switch + { + "Sequential" => new Sequential(config.ToObject()), + "InputLayer" => new InputLayer(config.ToObject()), + "Flatten" => new Flatten(config.ToObject()), + "ELU" => new ELU(config.ToObject()), + "Dense" => new Dense(config.ToObject()), + "Softmax" => new Softmax(config.ToObject()), + _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + + $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") + }; + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => @@ -60,5 +79,15 @@ namespace Tensorflow.Keras.Utils x.ToString(); })).ToLower(); } + + /// + /// Determines whether config appears to be a valid layer config. + /// + /// + /// + public static bool validate_config(JObject config) + { + return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY); + } } } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs new file mode 100644 index 00000000..63a7fe9c --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -0,0 +1,45 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow; +using Tensorflow.Keras.Optimizers; +using static Tensorflow.KerasApi; + +namespace TensorFlowNET.Keras.UnitTest.SaveModel; + +[TestClass] +public class SequentialModelLoad +{ + [TestMethod] + public void SimpleModelFromSequential() + { + var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); + Debug.Assert(model is Model); + var m = model as Model; + + m.summary(); + + m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 50000, + }).Result; + + m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs similarity index 99% rename from test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs rename to test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs index 269b9c05..d24049fb 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs @@ -63,6 +63,8 @@ public class SequentialModelTest keras.layers.Softmax(1) }); + model.summary(); + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader();