diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 9812d3c6..9793798d 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -158,4 +158,13 @@ public static class CheckPointUtils { return objects_ids_and_slot_variables_and_paths(graph_view).Item1; } + + internal static IEnumerable _objects_with_attributes(IEnumerable full_list) + { + return full_list.TakeWhile(x => + { + var saveables = x.gather_saveables_for_checkpoint(); + return saveables is not null && saveables.Count > 0; + }); + } } diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs index 49976280..3f7b1836 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -1,12 +1,13 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; namespace Tensorflow.Checkpoint { - internal class CheckpointReader : IDisposable + public class CheckpointReader : IDisposable { private IntPtr _reader; public Dictionary VariableToDataTypeMap { get; set; } @@ -61,14 +62,14 @@ namespace Tensorflow.Checkpoint return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); } - public Tensor GetTensor(string name) + public unsafe Tensor GetTensor(string name) { Status status = new Status(); var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); status.Check(true); var shape = GetVariableShape(name); var dtype = GetVariableDataType(name); - return new Tensor(tensor, shape, dtype); + return new Tensor(c_api.TF_TensorData(tensor), shape, dtype); } private void ReadAllShapeAndType() diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 3267ae12..72372e41 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -175,9 +175,9 @@ public static class SaveUtilV1 { var name = factory_data.name; var key = factory_data.checkpoint_key; - var maybe_saveable = factory_data.factory; + var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); - // TODO: oneflow python has a process with callable `saveable_factory`. + // TODO: tensorflow python has a process with callable `saveable_factory`. List saveables = new(); if (maybe_saveable.TryGet(out var s)) { @@ -217,7 +217,7 @@ public static class SaveUtilV1 public record class CheckpointFactoryData ( - Maybe factory, + Func> factory, string name, string checkpoint_key ); diff --git a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs index 2132cd1d..8a6858f6 100644 --- a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs @@ -24,6 +24,6 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); [DllImport(TensorFlowLibName)] - internal static extern IntPtr TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); + internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index a10e8953..d5cf2ae4 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -10,6 +10,8 @@ using Tensorflow.Exceptions; using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; using static Tensorflow.Binding; using Tensorflow.Operations; +using Newtonsoft.Json; +using Tensorflow.Training; namespace Tensorflow.Checkpoint; @@ -259,11 +261,48 @@ public class TrackableSaver saveables_cache: null ); - throw new NotImplementedException(); + new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); + + if(_graph_view.AttachedDependencies is not null) + { + foreach(var refer in _graph_view.AttachedDependencies) + { + if(refer.Name == "root") + { + continue; + } + int? proto_id = null; + // Find proto ID of attached dependency (if it is in the proto). + foreach (var proto_refer in object_graph_proto.Nodes[0].Children) + { + if(proto_refer.LocalName == refer.Name) + { + proto_id = proto_refer.NodeId; + break; + } + } + + if (proto_id is null) + { + continue; + } + + // Object has already been restored. This can happen when there's an + // indirect connection from the attached object to the root. + if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) + { + continue; + } + + new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); + } + } + + return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); } } -internal class CheckpointRestoreCoordinator +public class CheckpointRestoreCoordinator { private CheckpointOptions _options; private TrackableObjectGraph _object_graph_proto; @@ -280,6 +319,9 @@ internal class CheckpointRestoreCoordinator private List _restore_ops; private List _all_trackables; private Dictionary _object_by_proto_id; + private Dictionary _restore_ops_by_name; + private Dictionary> _deferred_slot_restorations; + private Dictionary> _unused_attributes; public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) @@ -299,10 +341,12 @@ internal class CheckpointRestoreCoordinator _shape_map = _reader.VariableToShapeMap; _graph_view = graph_view; _restore_ops = new List(); + _restore_ops_by_name = new Dictionary(); _all_trackables = new List(); _matched_proto_ids = new HashSet(); _object_by_proto_id = new Dictionary(); _slot_restorations = new Dictionary>(); + _deferred_slot_restorations = new Dictionary>(); _expect_partial_attr = false; for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) @@ -330,10 +374,18 @@ internal class CheckpointRestoreCoordinator } } + /// + /// Corresponding to `all_python_objects` of tensorflow python + /// public List AllTrackables => _all_trackables; public HashSet MatchedProtoIds => _matched_proto_ids; public Dictionary ObjectByProtoId => _object_by_proto_id; public int RestoreUid => _restore_uid; + public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; + public Dictionary> SlotRestorations => _slot_restorations; + public Dictionary> DeferredSlotRestorations => _deferred_slot_restorations; + public Dictionary RestoreOpsByName => _restore_ops_by_name; + public Dictionary> UnusedAttributes => _unused_attributes; public void new_restore_ops(IEnumerable new_ops) { @@ -341,18 +393,52 @@ internal class CheckpointRestoreCoordinator // skip the callback. } - public List restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null) + public List restore_saveables(Dictionary> tensor_saveables, List positions, object? registered_savers = null) { - throw new NotImplementedException(); + List restore_ops = new(); + foreach(var position in positions) + { + var key = position.ObjectProto.Attributes[0].CheckpointKey; + throw new NotImplementedException(); + } + + Dictionary variable_dict = new(); + foreach(var item in tensor_saveables) + { + if(item.Value.TryGet(out var variable)) + { + variable_dict[item.Key] = variable; + } + else + { + throw new TypeError(); + } + } + + if (tensor_saveables is not null && tensor_saveables.Count > 0) + { + var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict); + var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options); + if (!tf.Context.executing_eagerly()) + { + foreach(var item in new_restore_ops) + { + restore_ops.Add(item.Value); + Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key)); + _restore_ops_by_name[item.Key] = item.Value; + } + } + } + return restore_ops; } } public abstract class LoadStatus { - public abstract void assert_consumed(); - public abstract void assert_existing_objects_matched(); - public abstract void assert_nontrivial_match(); - public abstract void run_restore_ops(Session? session = null); + public abstract LoadStatus assert_consumed(); + public abstract LoadStatus assert_existing_objects_matched(); + public abstract LoadStatus assert_nontrivial_match(); + public abstract LoadStatus run_restore_ops(Session? session = null); public abstract void initialize_or_restore(Session? session = null); public virtual LoadStatus expect_partial() { @@ -371,19 +457,19 @@ public class InitializationOnlyStatus: LoadStatus _object_graph_view = object_graph_view; _root = object_graph_view.Root; } - public override void assert_consumed() + public override LoadStatus assert_consumed() { throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); } - public override void assert_existing_objects_matched() + public override LoadStatus assert_existing_objects_matched() { throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); } - public override void assert_nontrivial_match() + public override LoadStatus assert_nontrivial_match() { throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); } - public override void run_restore_ops(Session? session = null) + public override LoadStatus run_restore_ops(Session? session = null) { throw new AssertionError("No checkpoint specified, so no restore ops are available " + "(save_path=None to Saver.restore)."); @@ -403,10 +489,78 @@ public class InitializationOnlyStatus: LoadStatus } } -public class CheckpointLoadStatus +internal class CheckpointLoadStatus: LoadStatus { - public CheckpointLoadStatus() + private CheckpointRestoreCoordinator _checkpoint; + private Dictionary _feed_dict; + private ObjectGraphView _object_graph_view; + private Trackable _root; + public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary feed_dict, ObjectGraphView graph_view):base() + { + _checkpoint = checkpoint; + _feed_dict = feed_dict; + _object_graph_view = graph_view; + _root = graph_view.Root; + } + + public CheckpointRestoreCoordinator Checkpoint => _checkpoint; + + public override LoadStatus assert_consumed() + { + throw new NotImplementedException(); + } + + public override LoadStatus assert_existing_objects_matched() + { + for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++) + { + var node = _checkpoint.ObjectGraphProto.Nodes[i]; + if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) && + trackable.UpdateUid < _checkpoint.RestoreUid) + { + throw new AssertionError($"Object {node} not assigned a value from checkpoint."); + } + } + foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view)) + { + if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0) + { + continue; + } + _checkpoint.AllTrackables.Add(trackable_object); + } + var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables) + .Except(_checkpoint.ObjectByProtoId.Values); + if (unused_trackables.Any()) + { + var num_unused_trackables = unused_trackables.Count(); + var num_variables_to_show = Math.Min(10, num_unused_trackables); + throw new AssertionError($"Found {num_unused_trackables} Python objects that were " + + $"not bound to checkpointed values, likely due to changes in the " + + $"Python program. Showing {num_variables_to_show} of " + + $"{num_unused_trackables} unmatched objects: " + + $"{{list(unused_python_objects)[:num_variables_to_show]}}"); + } + return this; + } + + public override LoadStatus assert_nontrivial_match() + { + throw new NotImplementedException(); + } + + public override LoadStatus expect_partial() { + throw new NotImplementedException(); + } + public override void initialize_or_restore(Session? session = null) + { + throw new NotImplementedException(); + } + + public override LoadStatus run_restore_ops(Session? session = null) + { + throw new NotImplementedException(); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 09904d68..96e6c8dd 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint // tf python has code `with ops.device(restore_device):` here. tf.device(restore_device); // may be risky. - var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); Dictionary> restored_tensor_dict = new(); int idx = 0; diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs index 2d8bf096..b27396a7 100644 --- a/src/TensorFlowNET.Core/Checkpoint/restore.cs +++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs @@ -1,11 +1,15 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; using System.Text; using Tensorflow.Train; +using Tensorflow.Training; +using static Tensorflow.Binding; namespace Tensorflow.Checkpoint; -internal class CheckpointPosition +public class CheckpointPosition { private CheckpointRestoreCoordinator _checkpoint; private int _proto_id; @@ -18,6 +22,8 @@ internal class CheckpointPosition } public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; + public CheckpointRestoreCoordinator Checkpoint => _checkpoint; + public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; public void restore(Trackable trackable) { @@ -25,7 +31,11 @@ internal class CheckpointPosition { if (bind_project(trackable)) { - + var restore_ops = _restore_descendants(); + if(restore_ops is not null && restore_ops.Count > 0) + { + _checkpoint.new_restore_ops(restore_ops); + } } } } @@ -51,30 +61,271 @@ internal class CheckpointPosition } } - public void gather_ops_or_named_saveables() + public (List, Dictionary>, List, object?) gather_ops_or_named_saveables() { // skip the registered_saver + if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) + { + return (new List(), new Dictionary>(), + new List(), null); + } + + var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); + List existing_restore_ops; + List positions = new(); + Dictionary> named_saveables; + if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) + { + (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); + } + else if(saveable_factories.Count > 0) + { + (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); + } + else + { + throw new NotImplementedException(); + } + return (existing_restore_ops, named_saveables, positions, null); + } + + public CheckpointPosition create_child_position(int node_id) + { + return new CheckpointPosition(_checkpoint, node_id); + } + + public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable, + int slot_variable_id, string slot_name) + { + //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id); + + // TODO(Rinne): implement it. + return (null, null); + } + + /// + /// Creates a saveable using the _serialize_to_tensor method. + /// + /// + private (List, Dictionary>) _create_serialize_to_tensor_saveable( + IDictionary>> saveable_factories) + { + string suffix = SaveableCompat.get_saveable_name(this.Trackable); + suffix = suffix ?? ""; + var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; + + if (!tf.Context.executing_eagerly()) + { + throw new NotImplementedException("The restore under graph mode has not been implemented. " + + "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + + var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); + // skip the cache. + Dictionary> dict = new(); + dict[saveable_name] = saveable; + return (new List(), dict); + } + + private (List, Dictionary>) _create_saveables_by_attribute_name( + IDictionary>> saveable_factories) + { + // TODO(Rinne): implement it. + if(ObjectProto.Attributes is null) + { + return (new List(), new Dictionary>()); + } + + List existing_restore_ops = new(); + HashSet created_compat_names = new(); + Dictionary> named_saveables = new(); + foreach (var serialized_tensor in ObjectProto.Attributes) + { + Operation existing_op; + if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) + { + existing_op = null; + } + else + { + existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; + } + + if(existing_op is not null) + { + existing_restore_ops.Add(existing_op); + continue; + } + + if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) + { + continue; + } + + // TODO(Rinne): deal with cache. + + var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); + if(saveable is null) + { + _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List()).Add(serialized_tensor.Name); + continue; + } + named_saveables[serialized_tensor.CheckpointKey] = saveable; + } + return (existing_restore_ops, named_saveables); + } + + private Maybe _get_saveable_from_factory(IDictionary>> saveable_factories, + TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet created_compat_names) + { + var expected_factory_name = serialized_tensor.Name; + var factory_input_name = serialized_tensor.CheckpointKey; + + if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) + { + foreach(var item in saveable_factories) + { + var factory_name = item.Key; + var factory = item.Value; + if (expected_factory_name.StartsWith(factory_name)) + { + if(matched_factory is not null) + { + throw new ValueError($"Forward compatibility load error: Unable to load " + + "checkpoint saved in future version of TensorFlow. " + + "Please update your version of TensorFlow to the " + + "version in which the checkpoint was saved."); + } + } + matched_factory = factory; + factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; + created_compat_names.Add(factory_name); + } + } + return matched_factory(factory_input_name); + } + + private string _extract_saveable_name(string checkpoint_key) + { + var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; + return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); } /// /// Restore the bound Trackable and dependencies (may be deferred). /// - private void _restore_descendants() + private List _restore_descendants() { Queue<(CheckpointPosition, Trackable)> visit_queue = new(); visit_queue.Enqueue((this, this.Trackable)); + List restore_ops = new(); + Dictionary> tensor_saveables = new(); + List positions = new(); + + CheckpointPosition current_position = null; + while (visit_queue.Count > 0) + { + current_position = visit_queue.Dequeue().Item1; + var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore(); + restore_ops.AddRange(new_restore_ops); + foreach(var item in new_tensor_saveables) + { + tensor_saveables.Add(item.Key, item.Value); + } + positions.AddRange(new_positions); + _queue_children_for_restoration(current_position, visit_queue); + _queue_slot_variables(current_position, visit_queue); + } + restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null)); + return restore_ops; + } + + private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) + { + var trackable = checkpoint_position.Trackable; + foreach(var child in checkpoint_position.ObjectProto.Children) + { + var child_position = checkpoint_position.create_child_position(child.NodeId); + var local_object = trackable._lookup_dependency(child.LocalName); + var child_proto = child_position.ObjectProto; + if(local_object is null) + { + if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any()) + { + trackable.DeferredDependencies.SetDefault(child.LocalName, new List()).Add(child_position); + } + } + else + { + if (child_position.bind_project(local_object)) + { + visit_queue.Enqueue((child_position, local_object)); + } + } + } + } + private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) + { + var trackable = checkpoint_position.Trackable; + var checkpoint = checkpoint_position.Checkpoint; + if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions)) + { + checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id); + foreach (var deferred_slot_restoration in positions) + { + var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position( + trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId, + deferred_slot_restoration.SlotName + ); + if(slot_variable_position is not null) + { + visit_queue.Enqueue((slot_variable_position, slot_variable)); + } + } + } + if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations)) + { + checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id); + foreach (var slot_restoration in restorations) + { + if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object)) + { + throw new NotImplementedException(); + // TODO(Rinne); implement it. + } + else + { + Debug.Assert(trackable is BaseResourceVariable); + Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List()) + .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName)); + } + } + } } - private void _single_restore() + private (List, Dictionary>, List, object?) _single_restore() { var trackable = this.Trackable; trackable._maybe_initialize_trackable(); if(_checkpoint.RestoreUid > trackable.UpdateUid) { - + var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables(); + trackable.UpdateUid = _checkpoint.RestoreUid; + return (restore_ops, tensor_saveables, positions, registered_savers); + } + else + { + return (new List(), new Dictionary>(), + new List(), null); } } } + +public record class DeferredSlotVariableRestoration( + BaseResourceVariable OriginalVariable, + int SlotVariableId, + string SlotName +); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs index cb3ea4d3..2926f8e2 100644 --- a/src/TensorFlowNET.Core/Eager/execute.cs +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -10,7 +10,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Eager { - internal class execute + internal static class execute { public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) { @@ -27,5 +27,9 @@ namespace Tensorflow.Eager return tensors; } + public static bool must_record_gradient() + { + return false; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 956be96b..26a9b5be 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations /// /// Callers must ensure all the named tensors are indeed stored in the checkpoint. /// - public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") + public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + Dictionary attrs = new(); + attrs["dtypes"] = dtypes; + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + "RestoreV2", name, prefix, tensor_names, shape_and_slices + ) + { attrs = attrs }); + return result; + } + catch (Exception) + { + try + { + return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); + } + catch (Exception) + { + + } + } + } var dict = new Dictionary(); dict["prefix"] = prefix; dict["tensor_names"] = tensor_names; @@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations return (tensors); } + public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) + { + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + object[] attrs = new object[] { "dtypes", dtypes }; + Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; + var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); + + if (execute.must_record_gradient()) + { + // TODO(Rinne); record the gradient + } + return result; + } + /// /// Reverses specific dimensions of a tensor. /// diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs index 35c5877f..16e1bac4 100644 --- a/src/TensorFlowNET.Core/Operations/io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -62,6 +62,7 @@ namespace Tensorflow public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { + // Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); return _op.outputs; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 1309a617..2fd0d1d8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -39,6 +39,24 @@ namespace Tensorflow _op = value; } } + public BaseResourceVariable variable + { + get + { + if (_op.TryGet(out var v)) + { + return v; + } + else + { + throw new TypeError("The _op is not a variable."); + } + } + set + { + _op = value; + } + } public SaveSpec[] specs; public string name; public string device; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 1f8d1a01..9595ba11 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -63,7 +63,7 @@ namespace Tensorflow if (!save_options.experimental_skip_checkpoint) { - // TODO: implement it. + _restore_checkpoint(); } foreach(var node in _nodes) { @@ -398,13 +398,27 @@ namespace Tensorflow /// private void _restore_checkpoint() { - var variables_path = SavedModelUtils.get_variables_dir(_export_dir); + var variables_path = SavedModelUtils.get_variables_path(_export_dir); var saver = new TrackableSaver(new ObjectGraphView(get(0))); tf.device("CPU"); saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + LoadStatus load_status; if (_save_options.allow_partial_checkpoint) { + load_status = saver.restore(variables_path, _checkpoint_options).expect_partial(); + load_status.assert_nontrivial_match(); + } + else + { + load_status = saver.restore(variables_path, _checkpoint_options); + load_status.assert_existing_objects_matched(); + } + var ckpt = (load_status as CheckpointLoadStatus).Checkpoint; + if (!tf.Context.executing_eagerly()) + { + throw new NotImplementedException("The checkpoint restore has not supported graph mode. " + + "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); } } diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index cc9be7a2..20831122 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -68,6 +68,34 @@ namespace Tensorflow return saveables.ToArray(); } + public static MySaveableObject[] validate_and_slice_inputs(Dictionary names_to_saveables) + { + var saveables = new List(); + var seen_ops = new List(); + + foreach (var (name, op) in enumerate(names_to_saveables)) + { + foreach (var converted_saveable_object in saveable_objects_for_op(op, name)) + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + return saveables.ToArray(); + } + + public static MySaveableObject[] validate_and_slice_inputs(Dictionary names_to_saveables) + { + var saveables = new List(); + var seen_ops = new List(); + + foreach(var item in names_to_saveables.OrderBy(x => x.Key)) + { + foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) + { + _add_saveable(saveables, seen_ops, converted_saveable_object); + } + } + return saveables.ToArray(); + } + private static void _add_saveable(List saveables, List seen_ops, T saveable) where T : MySaveableObject { if (seen_ops.Contains(saveable.op)) @@ -77,6 +105,15 @@ namespace Tensorflow seen_ops.Add(saveable.op); } + private static void _add_saveable(List saveables, List seen_ops, MySaveableObject saveable) + { + if (seen_ops.Contains(saveable.variable)) + throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}"); + + saveables.Add(saveable); + seen_ops.Add(saveable.variable); + } + /// /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// @@ -136,19 +173,20 @@ namespace Tensorflow { full_name = name + "_" + attr; } - if(factory.TryGet(out var variable)) + var op = factory(full_name); + if(op.TryGet(out var variable)) { - foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) + foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) { - yield return op; + yield return v; } } else { - var saveable = factory.GetValue(); - foreach (var op in saveable_objects_for_op(saveable, saveable.name)) + var saveable = op.GetValue(); + foreach (var v in saveable_objects_for_op(saveable, saveable.name)) { - yield return op; + yield return v; } } } @@ -214,20 +252,19 @@ namespace Tensorflow return names_to_saveables; } - public static IDictionary> saveable_objects_from_trackable(Trackable obj) + public static IDictionary>> saveable_objects_from_trackable(Trackable obj) { // skip the process of type `PythonState` - if (trackable_has_serialize_to_tensor(obj)) + Maybe create_saveable(string name = "") { - var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. var tensor_dict = obj.serialize_to_tensors(); List specs = new(); List local_names = new(); string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; - foreach(var pair in tensor_dict) + foreach (var pair in tensor_dict) { var tensor_name = pair.Key; var maybe_tensor = pair.Value; @@ -235,9 +272,9 @@ namespace Tensorflow string spec_name = name + TrackableUtils.escape_local_name(tensor_name); IDictionary internal_dict; - if(maybe_tensor.TryGet(out var tensor)) + if (maybe_tensor.TryGet(out var tensor)) { - internal_dict= new Dictionary(); + internal_dict = new Dictionary(); internal_dict[""] = tensor; } else @@ -245,13 +282,18 @@ namespace Tensorflow internal_dict = maybe_tensor.GetValue>(); } - foreach(var item in internal_dict) + foreach (var item in internal_dict) { specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); } } - Dictionary> res = new(); - res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); + return new TrackableSaveable(obj, specs, name, local_names, prefix); + } + + if (trackable_has_serialize_to_tensor(obj)) + { + Dictionary>> res = new(); + res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; return res; } else @@ -339,14 +381,21 @@ namespace Tensorflow /// /// /// - public static IDictionary> recreate_saveable_objects( + public static IDictionary>> recreate_saveable_objects( IDictionary saveable_fn_by_name, IEnumerable? temp_session) { if (saveable_fn_by_name.Count > 0) { throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); } - return new Dictionary>(); + var res = new Dictionary>>(); + return res; + } + + public static Maybe create_saveable_object(string name, string key, Func> factory, + bool call_with_mapped_captures = false) + { + return factory(key); } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index a1de569e..7c86a580 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -41,9 +41,10 @@ namespace Tensorflow.Train protected IDictionary _unconditional_dependency_names; protected IList _unconditional_checkpoint_dependencies; + protected Dictionary> _unconditional_deferred_dependencies; - protected IDictionary> _self_saveable_object_factories = - new Dictionary>(); + protected IDictionary>> _self_saveable_object_factories = + new Dictionary>>(); private bool _manual_tracking = true; private static Trackable _none = new AutoTrackable(); @@ -71,7 +72,8 @@ namespace Tensorflow.Train public IList UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } - public IDictionary> SelfSaveableObjectFactories + public Dictionary> DeferredDependencies => _unconditional_deferred_dependencies; + public IDictionary>> SelfSaveableObjectFactories { get { @@ -147,9 +149,11 @@ namespace Tensorflow.Train _self_update_uid = -1; _unconditional_checkpoint_dependencies = new List(); _unconditional_dependency_names = new Dictionary(); + _unconditional_deferred_dependencies = new Dictionary>(); } - public virtual IDictionary _trackable_children(SaveType save_type, IDictionary>? cache) + public virtual IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, + IDictionary>? cache = null) { _maybe_initialize_trackable(); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); @@ -185,10 +189,19 @@ namespace Tensorflow.Train /// public virtual void _handle_deferred_dependencies(string name, Trackable trackable) { - //_maybe_initialize_trackable(); - //trackable._maybe_initialize_trackable(); - - // TODO: complete the implementation. + _maybe_initialize_trackable(); + trackable._maybe_initialize_trackable(); + + if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) + { + _unconditional_deferred_dependencies.Remove(name); + foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) + { + checkpoint_position.restore(trackable); + } + } + + // TODO(Rinne): deal with `_self_name_based_restores` } public virtual Trackable? _lookup_dependency(string name) @@ -236,12 +249,19 @@ namespace Tensorflow.Train return self_tensor_map.Keys.ToList(); } - public virtual IDictionary> gather_saveables_for_checkpoint() + public virtual IDictionary>> gather_saveables_for_checkpoint() { + Maybe create_saveable(string name = "") + { + throw new NotImplementedException(); + //return new TrackableSaveable(this, null, name, null, null); + } if (saveable_object_util.trackable_has_serialize_to_tensor(this)) { // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). - throw new NotImplementedException(); + Dictionary>> res = new(); + res[""] = create_saveable; + return res; } else { diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 8d513191..05c513a8 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -21,9 +21,9 @@ public static class TrackableUtils LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); } } - private static string _ESCAPE_CHAR = "."; - private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; - private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; + internal static string _ESCAPE_CHAR = "."; + internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; + internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; public static string object_path_to_string(IEnumerable node_path_arr) { diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 97203604..9b8cfcb5 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -293,10 +293,10 @@ namespace Tensorflow resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); } - public override IDictionary> gather_saveables_for_checkpoint() + public override IDictionary>> gather_saveables_for_checkpoint() { - var res = new Dictionary>(); - res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + var res = new Dictionary>>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; return res; } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 63a7fe9c..d43b1358 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -21,7 +21,7 @@ public class SequentialModelLoad [TestMethod] public void SimpleModelFromSequential() { - var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); + var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential"); Debug.Assert(model is Model); var m = model as Model;