| @@ -158,4 +158,13 @@ public static class CheckPointUtils | |||||
| { | { | ||||
| return objects_ids_and_slot_variables_and_paths(graph_view).Item1; | return objects_ids_and_slot_variables_and_paths(graph_view).Item1; | ||||
| } | } | ||||
/// <summary>
/// Filters <paramref name="full_list"/> down to the trackables that report at least one
/// saveable from <c>gather_saveables_for_checkpoint()</c>.
/// </summary>
/// <param name="full_list">Candidate trackable objects.</param>
/// <returns>A lazily evaluated sequence containing only the trackables that have saveables.</returns>
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
    // `Where`, not `TakeWhile`: an object without saveables must be skipped on its own,
    // it must not hide every object that comes after it in the list.
    return full_list.Where(x =>
    {
        var saveables = x.gather_saveables_for_checkpoint();
        return saveables is not null && saveables.Count > 0;
    });
}
| } | } | ||||
| @@ -1,12 +1,13 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
| { | { | ||||
| internal class CheckpointReader : IDisposable | |||||
| public class CheckpointReader : IDisposable | |||||
| { | { | ||||
| private IntPtr _reader; | private IntPtr _reader; | ||||
| public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | ||||
| @@ -61,14 +62,14 @@ namespace Tensorflow.Checkpoint | |||||
| return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); | return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); | ||||
| } | } | ||||
/// <summary>
/// Reads the tensor stored under <paramref name="name"/> from the checkpoint reader.
/// </summary>
/// <param name="name">Name of the variable to read.</param>
/// <returns>A tensor built from the native tensor data plus the variable's recorded shape and dtype.</returns>
public unsafe Tensor GetTensor(string name)
{
    var status = new Status();
    var handle = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
    status.Check(true);

    var variableShape = GetVariableShape(name);
    var variableDtype = GetVariableDataType(name);
    var data = c_api.TF_TensorData(handle);
    return new Tensor(data, variableShape, variableDtype);
}
| private void ReadAllShapeAndType() | private void ReadAllShapeAndType() | ||||
| @@ -175,9 +175,9 @@ public static class SaveUtilV1 | |||||
| { | { | ||||
| var name = factory_data.name; | var name = factory_data.name; | ||||
| var key = factory_data.checkpoint_key; | var key = factory_data.checkpoint_key; | ||||
| var maybe_saveable = factory_data.factory; | |||||
| var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); | |||||
| // TODO: oneflow python has a process with callable `saveable_factory`. | |||||
| // TODO: tensorflow python has a process with callable `saveable_factory`. | |||||
| List<MySaveableObject> saveables = new(); | List<MySaveableObject> saveables = new(); | ||||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | ||||
| { | { | ||||
| @@ -217,7 +217,7 @@ public static class SaveUtilV1 | |||||
/// <summary>
/// Bundles a saveable factory with the name and checkpoint key it belongs to.
/// </summary>
/// <param name="factory">Callable that takes a name and yields either a variable or a saveable object.</param>
/// <param name="name">Name associated with the factory.</param>
/// <param name="checkpoint_key">Checkpoint key associated with the factory.</param>
public record class CheckpointFactoryData(
    Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
    string name,
    string checkpoint_key);
| @@ -24,6 +24,6 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); | internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| internal static extern IntPtr TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); | |||||
| internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,6 +10,8 @@ using Tensorflow.Exceptions; | |||||
| using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Newtonsoft.Json; | |||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
| @@ -259,11 +261,48 @@ public class TrackableSaver | |||||
| saveables_cache: null | saveables_cache: null | ||||
| ); | ); | ||||
| throw new NotImplementedException(); | |||||
| new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); | |||||
| if(_graph_view.AttachedDependencies is not null) | |||||
| { | |||||
| foreach(var refer in _graph_view.AttachedDependencies) | |||||
| { | |||||
| if(refer.Name == "root") | |||||
| { | |||||
| continue; | |||||
| } | |||||
| int? proto_id = null; | |||||
| // Find proto ID of attached dependency (if it is in the proto). | |||||
| foreach (var proto_refer in object_graph_proto.Nodes[0].Children) | |||||
| { | |||||
| if(proto_refer.LocalName == refer.Name) | |||||
| { | |||||
| proto_id = proto_refer.NodeId; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (proto_id is null) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| // Object has already been restored. This can happen when there's an | |||||
| // indirect connection from the attached object to the root. | |||||
| if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); | |||||
| } | |||||
| } | |||||
| return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); | |||||
| } | } | ||||
| } | } | ||||
| internal class CheckpointRestoreCoordinator | |||||
| public class CheckpointRestoreCoordinator | |||||
| { | { | ||||
| private CheckpointOptions _options; | private CheckpointOptions _options; | ||||
| private TrackableObjectGraph _object_graph_proto; | private TrackableObjectGraph _object_graph_proto; | ||||
| @@ -280,6 +319,9 @@ internal class CheckpointRestoreCoordinator | |||||
| private List<Operation> _restore_ops; | private List<Operation> _restore_ops; | ||||
| private List<Trackable> _all_trackables; | private List<Trackable> _all_trackables; | ||||
| private Dictionary<int, Trackable> _object_by_proto_id; | private Dictionary<int, Trackable> _object_by_proto_id; | ||||
| private Dictionary<string, Operation> _restore_ops_by_name; | |||||
| private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations; | |||||
| private Dictionary<int, IList<string>> _unused_attributes; | |||||
| public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, | public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, | ||||
| CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) | CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) | ||||
| @@ -299,10 +341,12 @@ internal class CheckpointRestoreCoordinator | |||||
| _shape_map = _reader.VariableToShapeMap; | _shape_map = _reader.VariableToShapeMap; | ||||
| _graph_view = graph_view; | _graph_view = graph_view; | ||||
| _restore_ops = new List<Operation>(); | _restore_ops = new List<Operation>(); | ||||
| _restore_ops_by_name = new Dictionary<string, Operation>(); | |||||
| _all_trackables = new List<Trackable>(); | _all_trackables = new List<Trackable>(); | ||||
| _matched_proto_ids = new HashSet<int>(); | _matched_proto_ids = new HashSet<int>(); | ||||
| _object_by_proto_id = new Dictionary<int, Trackable>(); | _object_by_proto_id = new Dictionary<int, Trackable>(); | ||||
| _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>(); | _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>(); | ||||
| _deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>(); | |||||
| _expect_partial_attr = false; | _expect_partial_attr = false; | ||||
| for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) | for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) | ||||
| @@ -330,10 +374,18 @@ internal class CheckpointRestoreCoordinator | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Corresponding to `all_python_objects` of tensorflow python
/// </summary>
public List<Trackable> AllTrackables => _all_trackables;
public HashSet<int> MatchedProtoIds => _matched_proto_ids;
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
public int RestoreUid => _restore_uid;
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations;
public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations;
public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name;
/// <summary>
/// Names of serialized tensors that could not be matched, keyed by proto id.
/// Created on first access so callers never observe null when the backing field
/// was not assigned during construction.
/// </summary>
public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes ??= new Dictionary<int, IList<string>>();
| public void new_restore_ops(IEnumerable<Operation> new_ops) | public void new_restore_ops(IEnumerable<Operation> new_ops) | ||||
| { | { | ||||
| @@ -341,18 +393,52 @@ internal class CheckpointRestoreCoordinator | |||||
| // skip the callback. | // skip the callback. | ||||
| } | } | ||||
/// <summary>
/// Creates the restore operations for the gathered saveables and records them by name.
/// </summary>
/// <param name="tensor_saveables">
/// Saveables keyed by name. Every value must hold a <c>BaseResourceVariable</c>, otherwise a
/// <c>TypeError</c> is thrown. May be null or empty, in which case nothing is restored.
/// </param>
/// <param name="positions">
/// Positions that need a restore path which is not implemented yet; any entry results in a
/// <see cref="NotImplementedException"/>.
/// </param>
/// <param name="registered_savers">Not used by this implementation.</param>
/// <returns>
/// The restore operations that were created. The list stays empty when executing eagerly
/// or when there are no saveables.
/// </returns>
public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
{
    List<Operation> restore_ops = new();

    foreach(var position in positions)
    {
        var key = position.ObjectProto.Attributes[0].CheckpointKey;
        throw new NotImplementedException();
    }

    // Guard before the first use of `tensor_saveables`; checking for null only after
    // enumerating it would be too late.
    if (tensor_saveables is null || tensor_saveables.Count == 0)
    {
        return restore_ops;
    }

    Dictionary<string, BaseResourceVariable> variable_dict = new();
    foreach(var item in tensor_saveables)
    {
        if(item.Value.TryGet<BaseResourceVariable>(out var variable))
        {
            variable_dict[item.Key] = variable;
        }
        else
        {
            throw new TypeError();
        }
    }

    var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict);
    var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options);
    if (!tf.Context.executing_eagerly())
    {
        foreach(var item in new_restore_ops)
        {
            restore_ops.Add(item.Value);
            Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key));
            _restore_ops_by_name[item.Key] = item.Value;
        }
    }
    return restore_ops;
}
| } | } | ||||
| public abstract class LoadStatus | public abstract class LoadStatus | ||||
| { | { | ||||
| public abstract void assert_consumed(); | |||||
| public abstract void assert_existing_objects_matched(); | |||||
| public abstract void assert_nontrivial_match(); | |||||
| public abstract void run_restore_ops(Session? session = null); | |||||
| public abstract LoadStatus assert_consumed(); | |||||
| public abstract LoadStatus assert_existing_objects_matched(); | |||||
| public abstract LoadStatus assert_nontrivial_match(); | |||||
| public abstract LoadStatus run_restore_ops(Session? session = null); | |||||
| public abstract void initialize_or_restore(Session? session = null); | public abstract void initialize_or_restore(Session? session = null); | ||||
| public virtual LoadStatus expect_partial() | public virtual LoadStatus expect_partial() | ||||
| { | { | ||||
| @@ -371,19 +457,19 @@ public class InitializationOnlyStatus: LoadStatus | |||||
| _object_graph_view = object_graph_view; | _object_graph_view = object_graph_view; | ||||
| _root = object_graph_view.Root; | _root = object_graph_view.Root; | ||||
| } | } | ||||
/// <summary>
/// Always fails: this status carries no checkpoint, so nothing can have been consumed.
/// </summary>
public override LoadStatus assert_consumed()
    => throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
/// <summary>
/// Always fails: this status carries no checkpoint, so no object can have been matched.
/// </summary>
public override LoadStatus assert_existing_objects_matched()
    => throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
/// <summary>
/// Always fails: this status carries no checkpoint, so there is no match to verify.
/// </summary>
public override LoadStatus assert_nontrivial_match()
    => throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
| public override void run_restore_ops(Session? session = null) | |||||
| public override LoadStatus run_restore_ops(Session? session = null) | |||||
| { | { | ||||
| throw new AssertionError("No checkpoint specified, so no restore ops are available " | throw new AssertionError("No checkpoint specified, so no restore ops are available " | ||||
| + "(save_path=None to Saver.restore)."); | + "(save_path=None to Saver.restore)."); | ||||
| @@ -403,10 +489,78 @@ public class InitializationOnlyStatus: LoadStatus | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Load status backed by a <see cref="CheckpointRestoreCoordinator"/>, the feed dictionary
/// used for the restore and the object graph view whose root was restored.
/// </summary>
internal class CheckpointLoadStatus: LoadStatus
{
    private CheckpointRestoreCoordinator _checkpoint;
    private Dictionary<Tensor, string> _feed_dict;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view):base()
    {
        _checkpoint = checkpoint;
        _feed_dict = feed_dict;
        _object_graph_view = graph_view;
        _root = graph_view.Root;
    }

    /// <summary>The coordinator that drives this restore.</summary>
    public CheckpointRestoreCoordinator Checkpoint => _checkpoint;

    public override LoadStatus assert_consumed()
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Verifies that every trackable reachable from the graph view was matched by the checkpoint.
    /// </summary>
    /// <returns>This status, to allow call chaining.</returns>
    /// <exception cref="AssertionError">
    /// Thrown when a trackable bound to a proto node has an <c>UpdateUid</c> lower than the
    /// coordinator's <c>RestoreUid</c>, or when a trackable that has saveables is not bound
    /// to any proto id.
    /// </exception>
    public override LoadStatus assert_existing_objects_matched()
    {
        for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++)
        {
            var node = _checkpoint.ObjectGraphProto.Nodes[i];
            if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) &&
                trackable.UpdateUid < _checkpoint.RestoreUid)
            {
                throw new AssertionError($"Object {node} not assigned a value from checkpoint.");
            }
        }
        foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view))
        {
            // Data structures without children are not recorded.
            if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0)
            {
                continue;
            }
            _checkpoint.AllTrackables.Add(trackable_object);
        }

        // Materialise once: the query is otherwise re-evaluated for every Any()/Count() call.
        var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables)
            .Except(_checkpoint.ObjectByProtoId.Values)
            .ToList();
        if (unused_trackables.Count > 0)
        {
            var num_unused_trackables = unused_trackables.Count;
            var num_variables_to_show = Math.Min(10, num_unused_trackables);
            // List the actual unmatched objects in the message.
            var shown_objects = string.Join(", ", unused_trackables.Take(num_variables_to_show));
            throw new AssertionError($"Found {num_unused_trackables} Python objects that were " +
                $"not bound to checkpointed values, likely due to changes in the " +
                $"Python program. Showing {num_variables_to_show} of " +
                $"{num_unused_trackables} unmatched objects: " +
                $"[{shown_objects}]");
        }
        return this;
    }

    public override LoadStatus assert_nontrivial_match()
    {
        throw new NotImplementedException();
    }

    public override LoadStatus expect_partial()
    {
        throw new NotImplementedException();
    }

    public override void initialize_or_restore(Session? session = null)
    {
        throw new NotImplementedException();
    }

    public override LoadStatus run_restore_ops(Session? session = null)
    {
        throw new NotImplementedException();
    }
}
| @@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint | |||||
| // tf python has code `with ops.device(restore_device):` here. | // tf python has code `with ops.device(restore_device):` here. | ||||
| tf.device(restore_device); // may be risky. | tf.device(restore_device); // may be risky. | ||||
| var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||||
| var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||||
| Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | ||||
| int idx = 0; | int idx = 0; | ||||
| @@ -1,11 +1,15 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
| internal class CheckpointPosition | |||||
| public class CheckpointPosition | |||||
| { | { | ||||
| private CheckpointRestoreCoordinator _checkpoint; | private CheckpointRestoreCoordinator _checkpoint; | ||||
| private int _proto_id; | private int _proto_id; | ||||
| @@ -18,6 +22,8 @@ internal class CheckpointPosition | |||||
| } | } | ||||
| public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; | public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; | ||||
| public CheckpointRestoreCoordinator Checkpoint => _checkpoint; | |||||
| public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; | |||||
| public void restore(Trackable trackable) | public void restore(Trackable trackable) | ||||
| { | { | ||||
| @@ -25,7 +31,11 @@ internal class CheckpointPosition | |||||
| { | { | ||||
| if (bind_project(trackable)) | if (bind_project(trackable)) | ||||
| { | { | ||||
| var restore_ops = _restore_descendants(); | |||||
| if(restore_ops is not null && restore_ops.Count > 0) | |||||
| { | |||||
| _checkpoint.new_restore_ops(restore_ops); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -51,30 +61,271 @@ internal class CheckpointPosition | |||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Collects the already existing restore ops and the named saveables for the trackable
/// bound to this position.
/// </summary>
/// <returns>
/// A tuple of (existing restore ops, named saveables, positions, registered savers). The
/// positions list is always empty and the registered savers entry is always null here.
/// </returns>
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
{
    // skip the registered_saver
    var attributes = ObjectProto.Attributes;
    if (attributes is null || attributes.Count == 0)
    {
        return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
            new List<CheckpointPosition>(), null);
    }

    var factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable);
    if (factories.Count == 0)
    {
        throw new NotImplementedException();
    }

    // A lone SERIALIZE_TO_TENSORS entry selects the serialize-to-tensor path;
    // anything else is resolved per attribute name.
    bool serialize_to_tensors_only = factories.Keys.Count == 1
        && factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
    var (ops, saveables_by_name) = serialize_to_tensors_only
        ? _create_serialize_to_tensor_saveable(factories)
        : _create_saveables_by_attribute_name(factories);

    return (ops, saveables_by_name, new List<CheckpointPosition>(), null);
}
/// <summary>
/// Creates a position for <paramref name="node_id"/> that shares this position's checkpoint coordinator.
/// </summary>
public CheckpointPosition create_child_position(int node_id)
    => new CheckpointPosition(_checkpoint, node_id);
/// <summary>
/// Placeholder for creating the position of an optimizer slot variable.
/// Currently always yields a pair of nulls.
/// </summary>
public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable,
    int slot_variable_id, string slot_name)
{
    //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id);
    // TODO(Rinne): implement it.
    CheckpointPosition position = null;
    BaseResourceVariable slot_variable = null;
    return (position, slot_variable);
}
/// <summary>
/// Creates a saveable using the _serialize_to_tensor method.
/// </summary>
/// <param name="saveable_factories">
/// Factories keyed by name; the entry under <c>TrackableUtils.SERIALIZE_TO_TENSORS_NAME</c> is used.
/// </param>
/// <returns>An empty list of restore ops and a dictionary holding the single created saveable.</returns>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
    IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
    string suffix = SaveableCompat.get_saveable_name(this.Trackable) ?? "";
    var first_key = ObjectProto.Attributes[0].CheckpointKey;
    var saveable_name = _extract_saveable_name(first_key) + suffix;

    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("The restore under graph mode has not been implemented. " +
            "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }

    // skip the cache.
    var factory = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME];
    var result = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>
    {
        [saveable_name] = factory(saveable_name)
    };
    return (new List<Operation>(), result);
}
/// <summary>
/// Walks the serialized tensors of this position's proto and, for each one, either reuses a
/// restore op that already exists or creates a saveable from the matching factory.
/// </summary>
/// <param name="saveable_factories">Factories keyed by attribute name.</param>
/// <returns>The reused restore ops and the newly created saveables keyed by checkpoint key.</returns>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
    IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
    // TODO(Rinne): implement it.
    var reused_ops = new List<Operation>();
    var saveables_by_key = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
    if (ObjectProto.Attributes is null)
    {
        return (reused_ops, saveables_by_key);
    }

    var compat_names = new HashSet<string>();
    foreach (var attribute in ObjectProto.Attributes)
    {
        // Restore ops are only looked up outside eager execution.
        if (!tf.Context.executing_eagerly()
            && _checkpoint.RestoreOpsByName.TryGetValue(attribute.CheckpointKey, out var cached_op)
            && cached_op is not null)
        {
            reused_ops.Add(cached_op);
            continue;
        }

        // Already covered by a factory that was matched through its name prefix.
        bool covered_by_compat = compat_names.Any(x => attribute.Name.StartsWith(x));
        if (covered_by_compat)
        {
            continue;
        }

        // TODO(Rinne): deal with cache.
        var created = _get_saveable_from_factory(saveable_factories, attribute, compat_names);
        if (created is not null)
        {
            saveables_by_key[attribute.CheckpointKey] = created;
        }
        else
        {
            _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(attribute.Name);
        }
    }
    return (reused_ops, saveables_by_key);
}
/// <summary>
/// Finds the factory responsible for <paramref name="serialized_tensor"/> and invokes it.
/// </summary>
/// <param name="saveable_factories">Factories keyed by name.</param>
/// <param name="serialized_tensor">The serialized tensor whose name and checkpoint key drive the lookup.</param>
/// <param name="created_compat_names">
/// Receives the name of a factory that was matched through a name prefix rather than exactly.
/// </param>
/// <returns>The saveable produced by the matched factory, or null when no factory matches.</returns>
/// <exception cref="ValueError">Thrown when more than one factory name is a prefix of the tensor name.</exception>
private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories,
    TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names)
{
    var expected_factory_name = serialized_tensor.Name;
    var factory_input_name = serialized_tensor.CheckpointKey;

    if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory))
    {
        foreach(var item in saveable_factories)
        {
            var factory_name = item.Key;
            var factory = item.Value;
            // Only factories whose name is a prefix of the expected name are candidates.
            if (!expected_factory_name.StartsWith(factory_name, StringComparison.Ordinal))
            {
                continue;
            }
            if(matched_factory is not null)
            {
                throw new ValueError($"Forward compatibility load error: Unable to load " +
                    "checkpoint saved in future version of TensorFlow. " +
                    "Please update your version of TensorFlow to the " +
                    "version in which the checkpoint was saved.");
            }
            matched_factory = factory;
            factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name;
            created_compat_names.Add(factory_name);
        }
    }

    // No exact or prefix match: report "nothing found" to the caller instead of
    // invoking a null delegate.
    if (matched_factory is null)
    {
        return null;
    }
    return matched_factory(factory_input_name);
}
/// <summary>
/// Returns the part of <paramref name="checkpoint_key"/> up to and including the
/// <c>TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"</c> marker.
/// </summary>
/// <param name="checkpoint_key">Checkpoint key that contains the attributes marker.</param>
/// <returns>The key prefix ending with the marker.</returns>
/// <exception cref="ValueError">Thrown when the marker does not occur in the key.</exception>
private string _extract_saveable_name(string checkpoint_key)
{
    var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/";
    // Ordinal search: checkpoint keys are identifiers, not linguistic text.
    var marker_index = checkpoint_key.IndexOf(search_key, StringComparison.Ordinal);
    if (marker_index < 0)
    {
        // A result of -1 would otherwise silently yield a wrong prefix.
        throw new ValueError($"Checkpoint key '{checkpoint_key}' does not contain '{search_key}'.");
    }
    return checkpoint_key.Substring(0, marker_index + search_key.Length);
}
/// <summary>
/// Restore the bound Trackable and dependencies (may be deferred).
/// </summary>
/// <returns>
/// The restore ops gathered while visiting this position and everything queued from it,
/// followed by the ops returned from the coordinator's <c>restore_saveables</c>.
/// </returns>
private List<Operation> _restore_descendants()
{
    Queue<(CheckpointPosition, Trackable)> visit_queue = new();
    visit_queue.Enqueue((this, this.Trackable));
    List<Operation> restore_ops = new();
    Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
    List<CheckpointPosition> positions = new();

    CheckpointPosition current_position = null;
    while (visit_queue.Count > 0)
    {
        current_position = visit_queue.Dequeue().Item1;
        var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore();
        restore_ops.AddRange(new_restore_ops);
        foreach(var item in new_tensor_saveables)
        {
            // Merge with overwrite semantics; `Add` would throw if a key shows up twice.
            tensor_saveables[item.Key] = item.Value;
        }
        positions.AddRange(new_positions);
        _queue_children_for_restoration(current_position, visit_queue);
        _queue_slot_variables(current_position, visit_queue);
    }
    // The queue starts with this position, so `current_position` is set once the loop ends.
    restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null));
    return restore_ops;
}
/// <summary>
/// Visits the children recorded in the proto of <paramref name="checkpoint_position"/>:
/// children that resolve to a local object are bound and queued, the others are stored as
/// deferred dependencies when their proto carries children, attributes or slot variables.
/// </summary>
/// <param name="checkpoint_position">Position whose children are inspected.</param>
/// <param name="visit_queue">Queue that receives the successfully bound children.</param>
private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
    var parent = checkpoint_position.Trackable;
    foreach (var child_reference in checkpoint_position.ObjectProto.Children)
    {
        var position_of_child = checkpoint_position.create_child_position(child_reference.NodeId);
        var dependency = parent._lookup_dependency(child_reference.LocalName);
        var proto_of_child = position_of_child.ObjectProto;

        if (dependency is not null)
        {
            if (position_of_child.bind_project(dependency))
            {
                visit_queue.Enqueue((position_of_child, dependency));
            }
            continue;
        }

        // No local object yet: remember the position only if there is something to restore later.
        bool has_content = proto_of_child.Children.Any()
            || proto_of_child.Attributes.Any()
            || proto_of_child.SlotVariables.Any();
        if (has_content)
        {
            parent.DeferredDependencies.SetDefault(child_reference.LocalName, new List<CheckpointPosition>()).Add(position_of_child);
        }
    }
}
/// <summary>
/// Processes the slot-variable bookkeeping stored for the proto id of
/// <paramref name="checkpoint_position"/>: deferred slot restorations are turned into
/// queued positions when one can be created, and pending slot restorations are re-registered
/// as deferred under their optimizer id.
/// </summary>
/// <param name="checkpoint_position">Position whose slot variables are handled.</param>
/// <param name="visit_queue">Queue that receives newly created slot variable positions.</param>
private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
    var bound_object = checkpoint_position.Trackable;
    var coordinator = checkpoint_position.Checkpoint;
    var proto_id = checkpoint_position._proto_id;

    if (coordinator.DeferredSlotRestorations.TryGetValue(proto_id, out var deferred_entries))
    {
        coordinator.DeferredSlotRestorations.Remove(proto_id);
        foreach (var pending in deferred_entries)
        {
            var (slot_position, slot_variable) = checkpoint_position.create_slot_variable_position(
                bound_object as Optimizer, pending.OriginalVariable, pending.SlotVariableId,
                pending.SlotName
            );
            if (slot_position is null)
            {
                continue;
            }
            visit_queue.Enqueue((slot_position, slot_variable));
        }
    }

    if (!coordinator.SlotRestorations.TryGetValue(proto_id, out var pending_restorations))
    {
        return;
    }
    coordinator.SlotRestorations.Remove(proto_id);
    foreach (var restoration in pending_restorations)
    {
        if (Checkpoint.ObjectByProtoId.TryGetValue(restoration.OptimizerId, out _))
        {
            throw new NotImplementedException();
            // TODO(Rinne); implement it.
        }

        // The optimizer is not bound yet: defer until it is.
        Debug.Assert(bound_object is BaseResourceVariable);
        var deferred = new DeferredSlotVariableRestoration(bound_object as BaseResourceVariable, restoration.SlotVariableId, restoration.SlotName);
        Checkpoint.DeferredSlotRestorations.SetDefault(restoration.OptimizerId, new List<DeferredSlotVariableRestoration>())
            .Add(deferred);
    }
}
/// <summary>
/// Gathers the restore work for the bound trackable unless it has already been updated
/// for this restore (its <c>UpdateUid</c> is not lower than the coordinator's <c>RestoreUid</c>).
/// </summary>
/// <returns>
/// The tuple from <c>gather_ops_or_named_saveables</c>, or a tuple of empty collections
/// (and null) when nothing has to be done.
/// </returns>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
{
    var bound_object = this.Trackable;
    bound_object._maybe_initialize_trackable();

    if (_checkpoint.RestoreUid <= bound_object.UpdateUid)
    {
        return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
            new List<CheckpointPosition>(), null);
    }

    var gathered = gather_ops_or_named_saveables();
    // Mark the object as handled for this restore.
    bound_object.UpdateUid = _checkpoint.RestoreUid;
    return gathered;
}
| } | } | ||||
/// <summary>
/// Describes a slot-variable restoration that is parked in a checkpoint's
/// <c>DeferredSlotRestorations</c> map and later handed to <c>create_slot_variable_position</c>.
/// </summary>
/// <param name="OriginalVariable">The variable the slot variable belongs to.</param>
/// <param name="SlotVariableId">Identifier of the slot variable to restore.</param>
/// <param name="SlotName">Name of the slot.</param>
public record class DeferredSlotVariableRestoration(
    BaseResourceVariable OriginalVariable,
    int SlotVariableId,
    string SlotName);
| @@ -10,7 +10,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| internal class execute | |||||
| internal static class execute | |||||
| { | { | ||||
| public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | ||||
| { | { | ||||
| @@ -27,5 +27,9 @@ namespace Tensorflow.Eager | |||||
| return tensors; | return tensors; | ||||
| } | } | ||||
/// <summary>
/// Tells callers whether a gradient has to be recorded for the op they just executed.
/// </summary>
/// <returns>Always <c>false</c> in the current implementation.</returns>
public static bool must_record_gradient() => false;
| } | } | ||||
| } | } | ||||
| @@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations | |||||
| /// | /// | ||||
| /// Callers must ensure all the named tensors are indeed stored in the checkpoint. | /// Callers must ensure all the named tensors are indeed stored in the checkpoint. | ||||
| /// </remarks> | /// </remarks> | ||||
| public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||||
| public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||||
| { | { | ||||
| var ctx = tf.Context; | |||||
| if (ctx.executing_eagerly()) | |||||
| { | |||||
| try | |||||
| { | |||||
| Dictionary<string, object> attrs = new(); | |||||
| attrs["dtypes"] = dtypes; | |||||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||||
| "RestoreV2", name, prefix, tensor_names, shape_and_slices | |||||
| ) | |||||
| { attrs = attrs }); | |||||
| return result; | |||||
| } | |||||
| catch (Exception) | |||||
| { | |||||
| try | |||||
| { | |||||
| return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); | |||||
| } | |||||
| catch (Exception) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| var dict = new Dictionary<string, object>(); | var dict = new Dictionary<string, object>(); | ||||
| dict["prefix"] = prefix; | dict["prefix"] = prefix; | ||||
| dict["tensor_names"] = tensor_names; | dict["tensor_names"] = tensor_names; | ||||
| @@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations | |||||
| return (tensors); | return (tensors); | ||||
| } | } | ||||
/// <summary>
/// Runs the <c>RestoreV2</c> op through <c>execute.quick_execute</c> after converting every input
/// to a string tensor.
/// </summary>
/// <param name="prefix">Checkpoint prefix; converted to a <c>TF_STRING</c> tensor.</param>
/// <param name="tensor_names">Names of the tensors to restore; converted to a <c>TF_STRING</c> tensor.</param>
/// <param name="shape_and_slices">Slice specifications; converted to a <c>TF_STRING</c> tensor.</param>
/// <param name="dtypes">Passed as the <c>dtypes</c> attribute; its length is the requested output count.</param>
/// <param name="name">Name forwarded to <c>quick_execute</c>.</param>
/// <param name="ctx">Context forwarded to <c>quick_execute</c>.</param>
/// <returns>The tensors returned by <c>quick_execute</c>.</returns>
public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx)
{
    // Inputs are converted in declaration order: prefix, names, slices.
    Tensor prefix_tensor = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
    var names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
    var slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);

    object[] op_attrs = new object[] { "dtypes", dtypes };
    Tensor[] flat_inputs = new Tensor[] { prefix_tensor, names_tensor, slices_tensor };

    var outputs = execute.quick_execute("RestoreV2", dtypes.Length, flat_inputs, op_attrs, ctx, name);
    if (execute.must_record_gradient())
    {
        // TODO(Rinne): record the gradient
    }
    return outputs;
}
| /// <summary> | /// <summary> | ||||
| /// Reverses specific dimensions of a tensor. | /// Reverses specific dimensions of a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -62,6 +62,7 @@ namespace Tensorflow | |||||
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | ||||
| { | { | ||||
| // Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. | |||||
| var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | ||||
| return _op.outputs; | return _op.outputs; | ||||
| @@ -39,6 +39,24 @@ namespace Tensorflow | |||||
| _op = value; | _op = value; | ||||
| } | } | ||||
| } | } | ||||
/// <summary>
/// Gets or sets the resource variable stored in <c>_op</c>.
/// </summary>
/// <exception cref="TypeError">The getter is used while <c>_op</c> does not hold a <c>BaseResourceVariable</c>.</exception>
public BaseResourceVariable variable
{
    get => _op.TryGet<BaseResourceVariable>(out var held)
        ? held
        : throw new TypeError("The _op is not a variable.");
    set => _op = value;
}
| public SaveSpec[] specs; | public SaveSpec[] specs; | ||||
| public string name; | public string name; | ||||
| public string device; | public string device; | ||||
| @@ -63,7 +63,7 @@ namespace Tensorflow | |||||
| if (!save_options.experimental_skip_checkpoint) | if (!save_options.experimental_skip_checkpoint) | ||||
| { | { | ||||
| // TODO: implement it. | |||||
| _restore_checkpoint(); | |||||
| } | } | ||||
| foreach(var node in _nodes) | foreach(var node in _nodes) | ||||
| { | { | ||||
| @@ -398,13 +398,27 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
private void _restore_checkpoint()
{
    var variables_path = SavedModelUtils.get_variables_path(_export_dir);
    var saver = new TrackableSaver(new ObjectGraphView(get(0)));
    tf.device("CPU");
    saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);

    // Strict loads must match every existing object; partial loads only need a non-trivial match.
    LoadStatus status;
    if (!_save_options.allow_partial_checkpoint)
    {
        status = saver.restore(variables_path, _checkpoint_options);
        status.assert_existing_objects_matched();
    }
    else
    {
        status = saver.restore(variables_path, _checkpoint_options).expect_partial();
        status.assert_nontrivial_match();
    }

    // NOTE(review): `as` yields null for any other LoadStatus type, which would surface here as a
    // NullReferenceException - confirm restore always returns a CheckpointLoadStatus.
    var ckpt = (status as CheckpointLoadStatus).Checkpoint;

    // The graph-mode check runs after the restore call above.
    if (tf.Context.executing_eagerly())
    {
        return;
    }
    throw new NotImplementedException("The checkpoint restore has not supported graph mode. " +
        "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
| @@ -68,6 +68,34 @@ namespace Tensorflow | |||||
| return saveables.ToArray(); | return saveables.ToArray(); | ||||
| } | } | ||||
/// <summary>
/// Expands every named tensor into saveable objects and collects them into one array.
/// </summary>
/// <param name="names_to_saveables">Tensors to save, keyed by name.</param>
/// <returns>All saveables yielded by <c>saveable_objects_for_op</c>, in the order they were added.</returns>
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables)
{
    var collected = new List<MySaveableObject>();
    // Tensors already registered; passed to _add_saveable alongside the result list.
    var visited = new List<Tensor>();
    foreach (var (name, op) in enumerate(names_to_saveables))
    {
        foreach (var saveable in saveable_objects_for_op(op, name))
        {
            _add_saveable(collected, visited, saveable);
        }
    }
    return collected.ToArray();
}
/// <summary>
/// Expands every named variable into saveable objects and collects them into one array,
/// visiting the names in ordinal (culture-independent) order.
/// </summary>
/// <param name="names_to_saveables">Variables to save, keyed by name.</param>
/// <returns>All saveables yielded by <c>saveable_objects_for_op</c>, in the order they were added.</returns>
/// <exception cref="ValueError">The same variable is reached under two names.</exception>
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables)
{
    var saveables = new List<MySaveableObject>();
    var seen_ops = new List<BaseResourceVariable>();

    // Names are identifiers, not user-facing text: sort them ordinally so the resulting order
    // does not change with the current culture of the machine running the code.
    foreach (var item in names_to_saveables.OrderBy(x => x.Key, StringComparer.Ordinal))
    {
        foreach (var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
        {
            _add_saveable(saveables, seen_ops, converted_saveable_object);
        }
    }
    return saveables.ToArray();
}
| private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | ||||
| { | { | ||||
| if (seen_ops.Contains(saveable.op)) | if (seen_ops.Contains(saveable.op)) | ||||
| @@ -77,6 +105,15 @@ namespace Tensorflow | |||||
| seen_ops.Add(saveable.op); | seen_ops.Add(saveable.op); | ||||
| } | } | ||||
/// <summary>
/// Appends <paramref name="saveable"/> to <paramref name="saveables"/> unless its variable was already registered.
/// </summary>
/// <param name="saveables">Result list that receives the saveable.</param>
/// <param name="seen_ops">Variables registered so far; the saveable's variable is appended on success.</param>
/// <param name="saveable">Variable-backed saveable to register.</param>
/// <exception cref="ValueError">The saveable's variable is already present in <paramref name="seen_ops"/>.</exception>
private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable)
{
    // Read the variable once and use it for the check, the message and the bookkeeping, so the
    // error always names the object that was actually compared.
    // NOTE(review): the previous message went through `saveable.op`, which presumably is not
    // available for variable-backed saveables - confirm against the `op` getter.
    var variable = saveable.variable;
    if (seen_ops.Contains(variable))
    {
        throw new ValueError($"The same saveable will be restored with two names: {variable.Name}");
    }
    saveables.Add(saveable);
    seen_ops.Add(variable);
}
| /// <summary> | /// <summary> | ||||
| /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -136,19 +173,20 @@ namespace Tensorflow | |||||
| { | { | ||||
| full_name = name + "_" + attr; | full_name = name + "_" + attr; | ||||
| } | } | ||||
| if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||||
| var op = factory(full_name); | |||||
| if(op.TryGet<BaseResourceVariable>(out var variable)) | |||||
| { | { | ||||
| foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||||
| foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||||
| { | { | ||||
| yield return op; | |||||
| yield return v; | |||||
| } | } | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var saveable = factory.GetValue<MySaveableObject>(); | |||||
| foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||||
| var saveable = op.GetValue<MySaveableObject>(); | |||||
| foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | |||||
| { | { | ||||
| yield return op; | |||||
| yield return v; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -214,20 +252,19 @@ namespace Tensorflow | |||||
| return names_to_saveables; | return names_to_saveables; | ||||
| } | } | ||||
| public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||||
| { | { | ||||
| // skip the process of type `PythonState` | // skip the process of type `PythonState` | ||||
| if (trackable_has_serialize_to_tensor(obj)) | |||||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
| { | { | ||||
| var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; | |||||
| // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | ||||
| var tensor_dict = obj.serialize_to_tensors(); | var tensor_dict = obj.serialize_to_tensors(); | ||||
| List<SaveSpec> specs = new(); | List<SaveSpec> specs = new(); | ||||
| List<string> local_names = new(); | List<string> local_names = new(); | ||||
| string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | ||||
| foreach(var pair in tensor_dict) | |||||
| foreach (var pair in tensor_dict) | |||||
| { | { | ||||
| var tensor_name = pair.Key; | var tensor_name = pair.Key; | ||||
| var maybe_tensor = pair.Value; | var maybe_tensor = pair.Value; | ||||
| @@ -235,9 +272,9 @@ namespace Tensorflow | |||||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | ||||
| IDictionary<string, Tensor> internal_dict; | IDictionary<string, Tensor> internal_dict; | ||||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
| if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
| { | { | ||||
| internal_dict= new Dictionary<string, Tensor>(); | |||||
| internal_dict = new Dictionary<string, Tensor>(); | |||||
| internal_dict[""] = tensor; | internal_dict[""] = tensor; | ||||
| } | } | ||||
| else | else | ||||
| @@ -245,13 +282,18 @@ namespace Tensorflow | |||||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | ||||
| } | } | ||||
| foreach(var item in internal_dict) | |||||
| foreach (var item in internal_dict) | |||||
| { | { | ||||
| specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | ||||
| } | } | ||||
| } | } | ||||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new(); | |||||
| res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||||
| return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||||
| } | |||||
| if (trackable_has_serialize_to_tensor(obj)) | |||||
| { | |||||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
| res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | |||||
| return res; | return res; | ||||
| } | } | ||||
| else | else | ||||
| @@ -339,14 +381,21 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="saveable_fn_by_name"></param> | /// <param name="saveable_fn_by_name"></param> | ||||
| /// <param name="temp_session"></param> | /// <param name="temp_session"></param> | ||||
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
    IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{
    // Only the empty case is supported: it yields an empty factory map.
    if (saveable_fn_by_name.Count <= 0)
    {
        return new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
    }

    throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
/// <summary>
/// Produces a saveable by invoking <paramref name="factory"/> with <paramref name="key"/>.
/// </summary>
/// <param name="name">Not used by the current implementation.</param>
/// <param name="key">Argument passed to the factory.</param>
/// <param name="factory">Delegate that builds the variable or saveable object.</param>
/// <param name="call_with_mapped_captures">Not used by the current implementation.</param>
/// <returns>Whatever <paramref name="factory"/> returns.</returns>
public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
    bool call_with_mapped_captures = false)
    => factory(key);
| } | } | ||||
| @@ -41,9 +41,10 @@ namespace Tensorflow.Train | |||||
| protected IDictionary<string, Trackable> _unconditional_dependency_names; | protected IDictionary<string, Trackable> _unconditional_dependency_names; | ||||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | ||||
| protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | |||||
| protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||||
| new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||||
| protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||||
| new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||||
| private bool _manual_tracking = true; | private bool _manual_tracking = true; | ||||
| private static Trackable _none = new AutoTrackable(); | private static Trackable _none = new AutoTrackable(); | ||||
| @@ -71,7 +72,8 @@ namespace Tensorflow.Train | |||||
| public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | ||||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | ||||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | ||||
| public IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> SelfSaveableObjectFactories | |||||
| public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | |||||
| public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -147,9 +149,11 @@ namespace Tensorflow.Train | |||||
| _self_update_uid = -1; | _self_update_uid = -1; | ||||
| _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | ||||
| _unconditional_dependency_names = new Dictionary<string, Trackable>(); | _unconditional_dependency_names = new Dictionary<string, Trackable>(); | ||||
| _unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>(); | |||||
| } | } | ||||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, | |||||
| IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||||
| { | { | ||||
| _maybe_initialize_trackable(); | _maybe_initialize_trackable(); | ||||
| return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | ||||
| @@ -185,10 +189,19 @@ namespace Tensorflow.Train | |||||
| /// <param name="trackable"></param> | /// <param name="trackable"></param> | ||||
public virtual void _handle_deferred_dependencies(string name, Trackable trackable)
{
    _maybe_initialize_trackable();
    trackable._maybe_initialize_trackable();

    // TODO(Rinne): deal with `_self_name_based_restores`
    if (!_unconditional_deferred_dependencies.TryGetValue(name, out var pending))
    {
        return;
    }

    // The entry is removed before any of its positions is restored.
    _unconditional_deferred_dependencies.Remove(name);

    // Positions from the checkpoint with the highest restore UID are applied first.
    var newest_first = pending.OrderByDescending(position => position.Checkpoint.RestoreUid);
    foreach (var position in newest_first)
    {
        position.restore(trackable);
    }
}
| public virtual Trackable? _lookup_dependency(string name) | public virtual Trackable? _lookup_dependency(string name) | ||||
| @@ -236,12 +249,19 @@ namespace Tensorflow.Train | |||||
| return self_tensor_map.Keys.ToList(); | return self_tensor_map.Keys.ToList(); | ||||
| } | } | ||||
| public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||||
| public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||||
| { | { | ||||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| //return new TrackableSaveable(this, null, name, null, null); | |||||
| } | |||||
| if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | ||||
| { | { | ||||
| // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | ||||
| throw new NotImplementedException(); | |||||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||||
| res[""] = create_saveable; | |||||
| return res; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -21,9 +21,9 @@ public static class TrackableUtils | |||||
| LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | ||||
| } | } | ||||
| } | } | ||||
| private static string _ESCAPE_CHAR = "."; | |||||
| private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||||
| private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||||
| internal static string _ESCAPE_CHAR = "."; | |||||
| internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||||
| internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||||
| internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | ||||
| public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | ||||
| { | { | ||||
| @@ -293,10 +293,10 @@ namespace Tensorflow | |||||
| resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | ||||
| } | } | ||||
/// <summary>
/// Registers this variable as its own saveable under <c>Trackable.Constants.VARIABLE_VALUE_KEY</c>.
/// </summary>
/// <returns>A single-entry map whose factory ignores its argument and returns this variable.</returns>
public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{
    return new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>
    {
        [Trackable.Constants.VARIABLE_VALUE_KEY] = unused_name => this
    };
}
| @@ -21,7 +21,7 @@ public class SequentialModelLoad | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleModelFromSequential() | public void SimpleModelFromSequential() | ||||
| { | { | ||||
| var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); | |||||
| var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential"); | |||||
| Debug.Assert(model is Model); | Debug.Assert(model is Model); | ||||
| var m = model as Model; | var m = model as Model; | ||||