| @@ -158,4 +158,13 @@ public static class CheckPointUtils | |||
| { | |||
| return objects_ids_and_slot_variables_and_paths(graph_view).Item1; | |||
| } | |||
| internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||
| { | |||
| return full_list.TakeWhile(x => | |||
| { | |||
| var saveables = x.gather_saveables_for_checkpoint(); | |||
| return saveables is not null && saveables.Count > 0; | |||
| }); | |||
| } | |||
| } | |||
| @@ -1,12 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| namespace Tensorflow.Checkpoint | |||
| { | |||
| internal class CheckpointReader : IDisposable | |||
| public class CheckpointReader : IDisposable | |||
| { | |||
| private IntPtr _reader; | |||
| public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } | |||
| @@ -61,14 +62,14 @@ namespace Tensorflow.Checkpoint | |||
| return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); | |||
| } | |||
| public Tensor GetTensor(string name) | |||
| public unsafe Tensor GetTensor(string name) | |||
| { | |||
| Status status = new Status(); | |||
| var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); | |||
| status.Check(true); | |||
| var shape = GetVariableShape(name); | |||
| var dtype = GetVariableDataType(name); | |||
| return new Tensor(tensor, shape, dtype); | |||
| return new Tensor(c_api.TF_TensorData(tensor), shape, dtype); | |||
| } | |||
| private void ReadAllShapeAndType() | |||
| @@ -175,9 +175,9 @@ public static class SaveUtilV1 | |||
| { | |||
| var name = factory_data.name; | |||
| var key = factory_data.checkpoint_key; | |||
| var maybe_saveable = factory_data.factory; | |||
| var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); | |||
| // TODO: oneflow python has a process with callable `saveable_factory`. | |||
| // TODO: tensorflow python has a process with callable `saveable_factory`. | |||
| List<MySaveableObject> saveables = new(); | |||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
| { | |||
| @@ -217,7 +217,7 @@ public static class SaveUtilV1 | |||
| public record class CheckpointFactoryData | |||
| ( | |||
| Maybe<BaseResourceVariable, MySaveableObject> factory, | |||
| Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
| string name, | |||
| string checkpoint_key | |||
| ); | |||
| @@ -24,6 +24,6 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); | |||
| [DllImport(TensorFlowLibName)] | |||
| internal static extern IntPtr TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); | |||
| internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); | |||
| } | |||
| } | |||
| @@ -10,6 +10,8 @@ using Tensorflow.Exceptions; | |||
| using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -259,11 +261,48 @@ public class TrackableSaver | |||
| saveables_cache: null | |||
| ); | |||
| throw new NotImplementedException(); | |||
| new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); | |||
| if(_graph_view.AttachedDependencies is not null) | |||
| { | |||
| foreach(var refer in _graph_view.AttachedDependencies) | |||
| { | |||
| if(refer.Name == "root") | |||
| { | |||
| continue; | |||
| } | |||
| int? proto_id = null; | |||
| // Find proto ID of attached dependency (if it is in the proto). | |||
| foreach (var proto_refer in object_graph_proto.Nodes[0].Children) | |||
| { | |||
| if(proto_refer.LocalName == refer.Name) | |||
| { | |||
| proto_id = proto_refer.NodeId; | |||
| break; | |||
| } | |||
| } | |||
| if (proto_id is null) | |||
| { | |||
| continue; | |||
| } | |||
| // Object has already been restored. This can happen when there's an | |||
| // indirect connection from the attached object to the root. | |||
| if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) | |||
| { | |||
| continue; | |||
| } | |||
| new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); | |||
| } | |||
| } | |||
| return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); | |||
| } | |||
| } | |||
| internal class CheckpointRestoreCoordinator | |||
| public class CheckpointRestoreCoordinator | |||
| { | |||
| private CheckpointOptions _options; | |||
| private TrackableObjectGraph _object_graph_proto; | |||
| @@ -280,6 +319,9 @@ internal class CheckpointRestoreCoordinator | |||
| private List<Operation> _restore_ops; | |||
| private List<Trackable> _all_trackables; | |||
| private Dictionary<int, Trackable> _object_by_proto_id; | |||
| private Dictionary<string, Operation> _restore_ops_by_name; | |||
| private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations; | |||
| private Dictionary<int, IList<string>> _unused_attributes; | |||
| public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, | |||
| CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) | |||
| @@ -299,10 +341,12 @@ internal class CheckpointRestoreCoordinator | |||
| _shape_map = _reader.VariableToShapeMap; | |||
| _graph_view = graph_view; | |||
| _restore_ops = new List<Operation>(); | |||
| _restore_ops_by_name = new Dictionary<string, Operation>(); | |||
| _all_trackables = new List<Trackable>(); | |||
| _matched_proto_ids = new HashSet<int>(); | |||
| _object_by_proto_id = new Dictionary<int, Trackable>(); | |||
| _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>(); | |||
| _deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>(); | |||
| _expect_partial_attr = false; | |||
| for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) | |||
| @@ -330,10 +374,18 @@ internal class CheckpointRestoreCoordinator | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Corresponding to `all_python_objects` of tensorflow python | |||
| /// </summary> | |||
| public List<Trackable> AllTrackables => _all_trackables; | |||
| public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||
| public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||
| public int RestoreUid => _restore_uid; | |||
| public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; | |||
| public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations; | |||
| public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations; | |||
| public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name; | |||
| public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes; | |||
| public void new_restore_ops(IEnumerable<Operation> new_ops) | |||
| { | |||
| @@ -341,18 +393,52 @@ internal class CheckpointRestoreCoordinator | |||
| // skip the callback. | |||
| } | |||
| public List<Operation> restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null) | |||
| public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| List<Operation> restore_ops = new(); | |||
| foreach(var position in positions) | |||
| { | |||
| var key = position.ObjectProto.Attributes[0].CheckpointKey; | |||
| throw new NotImplementedException(); | |||
| } | |||
| Dictionary<string, BaseResourceVariable> variable_dict = new(); | |||
| foreach(var item in tensor_saveables) | |||
| { | |||
| if(item.Value.TryGet<BaseResourceVariable>(out var variable)) | |||
| { | |||
| variable_dict[item.Key] = variable; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError(); | |||
| } | |||
| } | |||
| if (tensor_saveables is not null && tensor_saveables.Count > 0) | |||
| { | |||
| var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict); | |||
| var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options); | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| foreach(var item in new_restore_ops) | |||
| { | |||
| restore_ops.Add(item.Value); | |||
| Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key)); | |||
| _restore_ops_by_name[item.Key] = item.Value; | |||
| } | |||
| } | |||
| } | |||
| return restore_ops; | |||
| } | |||
| } | |||
| public abstract class LoadStatus | |||
| { | |||
| public abstract void assert_consumed(); | |||
| public abstract void assert_existing_objects_matched(); | |||
| public abstract void assert_nontrivial_match(); | |||
| public abstract void run_restore_ops(Session? session = null); | |||
| public abstract LoadStatus assert_consumed(); | |||
| public abstract LoadStatus assert_existing_objects_matched(); | |||
| public abstract LoadStatus assert_nontrivial_match(); | |||
| public abstract LoadStatus run_restore_ops(Session? session = null); | |||
| public abstract void initialize_or_restore(Session? session = null); | |||
| public virtual LoadStatus expect_partial() | |||
| { | |||
| @@ -371,19 +457,19 @@ public class InitializationOnlyStatus: LoadStatus | |||
| _object_graph_view = object_graph_view; | |||
| _root = object_graph_view.Root; | |||
| } | |||
| public override void assert_consumed() | |||
| public override LoadStatus assert_consumed() | |||
| { | |||
| throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||
| } | |||
| public override void assert_existing_objects_matched() | |||
| public override LoadStatus assert_existing_objects_matched() | |||
| { | |||
| throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||
| } | |||
| public override void assert_nontrivial_match() | |||
| public override LoadStatus assert_nontrivial_match() | |||
| { | |||
| throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||
| } | |||
| public override void run_restore_ops(Session? session = null) | |||
| public override LoadStatus run_restore_ops(Session? session = null) | |||
| { | |||
| throw new AssertionError("No checkpoint specified, so no restore ops are available " | |||
| + "(save_path=None to Saver.restore)."); | |||
| @@ -403,10 +489,78 @@ public class InitializationOnlyStatus: LoadStatus | |||
| } | |||
| } | |||
| public class CheckpointLoadStatus | |||
| internal class CheckpointLoadStatus: LoadStatus | |||
| { | |||
| public CheckpointLoadStatus() | |||
| private CheckpointRestoreCoordinator _checkpoint; | |||
| private Dictionary<Tensor, string> _feed_dict; | |||
| private ObjectGraphView _object_graph_view; | |||
| private Trackable _root; | |||
| public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view):base() | |||
| { | |||
| _checkpoint = checkpoint; | |||
| _feed_dict = feed_dict; | |||
| _object_graph_view = graph_view; | |||
| _root = graph_view.Root; | |||
| } | |||
| public CheckpointRestoreCoordinator Checkpoint => _checkpoint; | |||
| public override LoadStatus assert_consumed() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public override LoadStatus assert_existing_objects_matched() | |||
| { | |||
| for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++) | |||
| { | |||
| var node = _checkpoint.ObjectGraphProto.Nodes[i]; | |||
| if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) && | |||
| trackable.UpdateUid < _checkpoint.RestoreUid) | |||
| { | |||
| throw new AssertionError($"Object {node} not assigned a value from checkpoint."); | |||
| } | |||
| } | |||
| foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view)) | |||
| { | |||
| if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0) | |||
| { | |||
| continue; | |||
| } | |||
| _checkpoint.AllTrackables.Add(trackable_object); | |||
| } | |||
| var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables) | |||
| .Except(_checkpoint.ObjectByProtoId.Values); | |||
| if (unused_trackables.Any()) | |||
| { | |||
| var num_unused_trackables = unused_trackables.Count(); | |||
| var num_variables_to_show = Math.Min(10, num_unused_trackables); | |||
| throw new AssertionError($"Found {num_unused_trackables} Python objects that were " + | |||
| $"not bound to checkpointed values, likely due to changes in the " + | |||
| $"Python program. Showing {num_variables_to_show} of " + | |||
| $"{num_unused_trackables} unmatched objects: " + | |||
| $"{{list(unused_python_objects)[:num_variables_to_show]}}"); | |||
| } | |||
| return this; | |||
| } | |||
| public override LoadStatus assert_nontrivial_match() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public override LoadStatus expect_partial() | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public override void initialize_or_restore(Session? session = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public override LoadStatus run_restore_ops(Session? session = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| @@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint | |||
| // tf python has code `with ops.device(restore_device):` here. | |||
| tf.device(restore_device); // may be risky. | |||
| var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
| var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
| Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
| int idx = 0; | |||
| @@ -1,11 +1,15 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Checkpoint; | |||
| internal class CheckpointPosition | |||
| public class CheckpointPosition | |||
| { | |||
| private CheckpointRestoreCoordinator _checkpoint; | |||
| private int _proto_id; | |||
| @@ -18,6 +22,8 @@ internal class CheckpointPosition | |||
| } | |||
| public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; | |||
| public CheckpointRestoreCoordinator Checkpoint => _checkpoint; | |||
| public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; | |||
| public void restore(Trackable trackable) | |||
| { | |||
| @@ -25,7 +31,11 @@ internal class CheckpointPosition | |||
| { | |||
| if (bind_project(trackable)) | |||
| { | |||
| var restore_ops = _restore_descendants(); | |||
| if(restore_ops is not null && restore_ops.Count > 0) | |||
| { | |||
| _checkpoint.new_restore_ops(restore_ops); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -51,30 +61,271 @@ internal class CheckpointPosition | |||
| } | |||
| } | |||
| public void gather_ops_or_named_saveables() | |||
| public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
| { | |||
| // skip the registered_saver | |||
| if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); | |||
| List<Operation> existing_restore_ops; | |||
| List<CheckpointPosition> positions = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
| if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
| { | |||
| (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||
| } | |||
| else if(saveable_factories.Count > 0) | |||
| { | |||
| (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return (existing_restore_ops, named_saveables, positions, null); | |||
| } | |||
| public CheckpointPosition create_child_position(int node_id) | |||
| { | |||
| return new CheckpointPosition(_checkpoint, node_id); | |||
| } | |||
| public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable, | |||
| int slot_variable_id, string slot_name) | |||
| { | |||
| //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id); | |||
| // TODO(Rinne): implement it. | |||
| return (null, null); | |||
| } | |||
| /// <summary> | |||
| /// Creates a saveable using the _serialize_to_tensor method. | |||
| /// </summary> | |||
| /// <param name="saveable_factories"></param> | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||
| suffix = suffix ?? ""; | |||
| var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| throw new NotImplementedException("The restore under graph mode has not been implemented. " + | |||
| "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||
| // skip the cache. | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
| dict[saveable_name] = saveable; | |||
| return (new List<Operation>(), dict); | |||
| } | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| if(ObjectProto.Attributes is null) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||
| } | |||
| List<Operation> existing_restore_ops = new(); | |||
| HashSet<string> created_compat_names = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
| foreach (var serialized_tensor in ObjectProto.Attributes) | |||
| { | |||
| Operation existing_op; | |||
| if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) | |||
| { | |||
| existing_op = null; | |||
| } | |||
| else | |||
| { | |||
| existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; | |||
| } | |||
| if(existing_op is not null) | |||
| { | |||
| existing_restore_ops.Add(existing_op); | |||
| continue; | |||
| } | |||
| if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) | |||
| { | |||
| continue; | |||
| } | |||
| // TODO(Rinne): deal with cache. | |||
| var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); | |||
| if(saveable is null) | |||
| { | |||
| _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||
| continue; | |||
| } | |||
| named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||
| } | |||
| return (existing_restore_ops, named_saveables); | |||
| } | |||
| private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
| TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||
| { | |||
| var expected_factory_name = serialized_tensor.Name; | |||
| var factory_input_name = serialized_tensor.CheckpointKey; | |||
| if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) | |||
| { | |||
| foreach(var item in saveable_factories) | |||
| { | |||
| var factory_name = item.Key; | |||
| var factory = item.Value; | |||
| if (expected_factory_name.StartsWith(factory_name)) | |||
| { | |||
| if(matched_factory is not null) | |||
| { | |||
| throw new ValueError($"Forward compatibility load error: Unable to load " + | |||
| "checkpoint saved in future version of TensorFlow. " + | |||
| "Please update your version of TensorFlow to the " + | |||
| "version in which the checkpoint was saved."); | |||
| } | |||
| } | |||
| matched_factory = factory; | |||
| factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; | |||
| created_compat_names.Add(factory_name); | |||
| } | |||
| } | |||
| return matched_factory(factory_input_name); | |||
| } | |||
| private string _extract_saveable_name(string checkpoint_key) | |||
| { | |||
| var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; | |||
| return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); | |||
| } | |||
| /// <summary> | |||
| /// Restore the bound Trackable and dependencies (may be deferred). | |||
| /// </summary> | |||
| private void _restore_descendants() | |||
| private List<Operation> _restore_descendants() | |||
| { | |||
| Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||
| visit_queue.Enqueue((this, this.Trackable)); | |||
| List<Operation> restore_ops = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new(); | |||
| List<CheckpointPosition> positions = new(); | |||
| CheckpointPosition current_position = null; | |||
| while (visit_queue.Count > 0) | |||
| { | |||
| current_position = visit_queue.Dequeue().Item1; | |||
| var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore(); | |||
| restore_ops.AddRange(new_restore_ops); | |||
| foreach(var item in new_tensor_saveables) | |||
| { | |||
| tensor_saveables.Add(item.Key, item.Value); | |||
| } | |||
| positions.AddRange(new_positions); | |||
| _queue_children_for_restoration(current_position, visit_queue); | |||
| _queue_slot_variables(current_position, visit_queue); | |||
| } | |||
| restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null)); | |||
| return restore_ops; | |||
| } | |||
| private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) | |||
| { | |||
| var trackable = checkpoint_position.Trackable; | |||
| foreach(var child in checkpoint_position.ObjectProto.Children) | |||
| { | |||
| var child_position = checkpoint_position.create_child_position(child.NodeId); | |||
| var local_object = trackable._lookup_dependency(child.LocalName); | |||
| var child_proto = child_position.ObjectProto; | |||
| if(local_object is null) | |||
| { | |||
| if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any()) | |||
| { | |||
| trackable.DeferredDependencies.SetDefault(child.LocalName, new List<CheckpointPosition>()).Add(child_position); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (child_position.bind_project(local_object)) | |||
| { | |||
| visit_queue.Enqueue((child_position, local_object)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue) | |||
| { | |||
| var trackable = checkpoint_position.Trackable; | |||
| var checkpoint = checkpoint_position.Checkpoint; | |||
| if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions)) | |||
| { | |||
| checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id); | |||
| foreach (var deferred_slot_restoration in positions) | |||
| { | |||
| var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position( | |||
| trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId, | |||
| deferred_slot_restoration.SlotName | |||
| ); | |||
| if(slot_variable_position is not null) | |||
| { | |||
| visit_queue.Enqueue((slot_variable_position, slot_variable)); | |||
| } | |||
| } | |||
| } | |||
| if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations)) | |||
| { | |||
| checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id); | |||
| foreach (var slot_restoration in restorations) | |||
| { | |||
| if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object)) | |||
| { | |||
| throw new NotImplementedException(); | |||
| // TODO(Rinne); implement it. | |||
| } | |||
| else | |||
| { | |||
| Debug.Assert(trackable is BaseResourceVariable); | |||
| Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List<DeferredSlotVariableRestoration>()) | |||
| .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private void _single_restore() | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore() | |||
| { | |||
| var trackable = this.Trackable; | |||
| trackable._maybe_initialize_trackable(); | |||
| if(_checkpoint.RestoreUid > trackable.UpdateUid) | |||
| { | |||
| var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables(); | |||
| trackable.UpdateUid = _checkpoint.RestoreUid; | |||
| return (restore_ops, tensor_saveables, positions, registered_savers); | |||
| } | |||
| else | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| } | |||
| } | |||
| public record class DeferredSlotVariableRestoration( | |||
| BaseResourceVariable OriginalVariable, | |||
| int SlotVariableId, | |||
| string SlotName | |||
| ); | |||
| @@ -10,7 +10,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| internal class execute | |||
| internal static class execute | |||
| { | |||
| public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||
| { | |||
| @@ -27,5 +27,9 @@ namespace Tensorflow.Eager | |||
| return tensors; | |||
| } | |||
| public static bool must_record_gradient() | |||
| { | |||
| return false; | |||
| } | |||
| } | |||
| } | |||
| @@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations | |||
| /// | |||
| /// Callers must ensure all the named tensors are indeed stored in the checkpoint. | |||
| /// </remarks> | |||
| public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||
| public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| Dictionary<string, object> attrs = new(); | |||
| attrs["dtypes"] = dtypes; | |||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
| "RestoreV2", name, prefix, tensor_names, shape_and_slices | |||
| ) | |||
| { attrs = attrs }); | |||
| return result; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| try | |||
| { | |||
| return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["prefix"] = prefix; | |||
| dict["tensor_names"] = tensor_names; | |||
| @@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations | |||
| return (tensors); | |||
| } | |||
| public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx) | |||
| { | |||
| prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||
| var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||
| var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||
| object[] attrs = new object[] { "dtypes", dtypes }; | |||
| Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; | |||
| var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); | |||
| if (execute.must_record_gradient()) | |||
| { | |||
| // TODO(Rinne); record the gradient | |||
| } | |||
| return result; | |||
| } | |||
| /// <summary> | |||
| /// Reverses specific dimensions of a tensor. | |||
| /// </summary> | |||
| @@ -62,6 +62,7 @@ namespace Tensorflow | |||
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
| { | |||
| // Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. | |||
| var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||
| return _op.outputs; | |||
| @@ -39,6 +39,24 @@ namespace Tensorflow | |||
| _op = value; | |||
| } | |||
| } | |||
| public BaseResourceVariable variable | |||
| { | |||
| get | |||
| { | |||
| if (_op.TryGet<BaseResourceVariable>(out var v)) | |||
| { | |||
| return v; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("The _op is not a variable."); | |||
| } | |||
| } | |||
| set | |||
| { | |||
| _op = value; | |||
| } | |||
| } | |||
| public SaveSpec[] specs; | |||
| public string name; | |||
| public string device; | |||
| @@ -63,7 +63,7 @@ namespace Tensorflow | |||
| if (!save_options.experimental_skip_checkpoint) | |||
| { | |||
| // TODO: implement it. | |||
| _restore_checkpoint(); | |||
| } | |||
| foreach(var node in _nodes) | |||
| { | |||
| @@ -398,13 +398,27 @@ namespace Tensorflow | |||
| /// </summary> | |||
| private void _restore_checkpoint() | |||
| { | |||
| var variables_path = SavedModelUtils.get_variables_dir(_export_dir); | |||
| var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
| var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
| tf.device("CPU"); | |||
| saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
| LoadStatus load_status; | |||
| if (_save_options.allow_partial_checkpoint) | |||
| { | |||
| load_status = saver.restore(variables_path, _checkpoint_options).expect_partial(); | |||
| load_status.assert_nontrivial_match(); | |||
| } | |||
| else | |||
| { | |||
| load_status = saver.restore(variables_path, _checkpoint_options); | |||
| load_status.assert_existing_objects_matched(); | |||
| } | |||
| var ckpt = (load_status as CheckpointLoadStatus).Checkpoint; | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| throw new NotImplementedException("The checkpoint restore has not supported graph mode. " + | |||
| "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| } | |||
| @@ -68,6 +68,34 @@ namespace Tensorflow | |||
| return saveables.ToArray(); | |||
| } | |||
| public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables) | |||
| { | |||
| var saveables = new List<MySaveableObject>(); | |||
| var seen_ops = new List<Tensor>(); | |||
| foreach (var (name, op) in enumerate(names_to_saveables)) | |||
| { | |||
| foreach (var converted_saveable_object in saveable_objects_for_op(op, name)) | |||
| _add_saveable(saveables, seen_ops, converted_saveable_object); | |||
| } | |||
| return saveables.ToArray(); | |||
| } | |||
| public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables) | |||
| { | |||
| var saveables = new List<MySaveableObject>(); | |||
| var seen_ops = new List<BaseResourceVariable>(); | |||
| foreach(var item in names_to_saveables.OrderBy(x => x.Key)) | |||
| { | |||
| foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key)) | |||
| { | |||
| _add_saveable(saveables, seen_ops, converted_saveable_object); | |||
| } | |||
| } | |||
| return saveables.ToArray(); | |||
| } | |||
| private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | |||
| { | |||
| if (seen_ops.Contains(saveable.op)) | |||
| @@ -77,6 +105,15 @@ namespace Tensorflow | |||
| seen_ops.Add(saveable.op); | |||
| } | |||
| private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable) | |||
| { | |||
| if (seen_ops.Contains(saveable.variable)) | |||
| throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}"); | |||
| saveables.Add(saveable); | |||
| seen_ops.Add(saveable.variable); | |||
| } | |||
| /// <summary> | |||
| /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | |||
| /// </summary> | |||
| @@ -136,19 +173,20 @@ namespace Tensorflow | |||
| { | |||
| full_name = name + "_" + attr; | |||
| } | |||
| if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||
| var op = factory(full_name); | |||
| if(op.TryGet<BaseResourceVariable>(out var variable)) | |||
| { | |||
| foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
| foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
| { | |||
| yield return op; | |||
| yield return v; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| var saveable = factory.GetValue<MySaveableObject>(); | |||
| foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||
| var saveable = op.GetValue<MySaveableObject>(); | |||
| foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | |||
| { | |||
| yield return op; | |||
| yield return v; | |||
| } | |||
| } | |||
| } | |||
| @@ -214,20 +252,19 @@ namespace Tensorflow | |||
| return names_to_saveables; | |||
| } | |||
| public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||
| { | |||
| // skip the process of type `PythonState` | |||
| if (trackable_has_serialize_to_tensor(obj)) | |||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| { | |||
| var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; | |||
| // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | |||
| var tensor_dict = obj.serialize_to_tensors(); | |||
| List<SaveSpec> specs = new(); | |||
| List<string> local_names = new(); | |||
| string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | |||
| foreach(var pair in tensor_dict) | |||
| foreach (var pair in tensor_dict) | |||
| { | |||
| var tensor_name = pair.Key; | |||
| var maybe_tensor = pair.Value; | |||
| @@ -235,9 +272,9 @@ namespace Tensorflow | |||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
| IDictionary<string, Tensor> internal_dict; | |||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| { | |||
| internal_dict= new Dictionary<string, Tensor>(); | |||
| internal_dict = new Dictionary<string, Tensor>(); | |||
| internal_dict[""] = tensor; | |||
| } | |||
| else | |||
| @@ -245,13 +282,18 @@ namespace Tensorflow | |||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
| } | |||
| foreach(var item in internal_dict) | |||
| foreach (var item in internal_dict) | |||
| { | |||
| specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
| } | |||
| } | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new(); | |||
| res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
| return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
| } | |||
| if (trackable_has_serialize_to_tensor(obj)) | |||
| { | |||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | |||
| return res; | |||
| } | |||
| else | |||
| @@ -339,14 +381,21 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="saveable_fn_by_name"></param> | |||
| /// <param name="temp_session"></param> | |||
| public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> recreate_saveable_objects( | |||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects( | |||
| IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) | |||
| { | |||
| if (saveable_fn_by_name.Count > 0) | |||
| { | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| return new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| return res; | |||
| } | |||
| public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory, | |||
| bool call_with_mapped_captures = false) | |||
| { | |||
| return factory(key); | |||
| } | |||
| } | |||
| @@ -41,9 +41,10 @@ namespace Tensorflow.Train | |||
| protected IDictionary<string, Trackable> _unconditional_dependency_names; | |||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||
| protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | |||
| protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||
| new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||
| new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| private bool _manual_tracking = true; | |||
| private static Trackable _none = new AutoTrackable(); | |||
| @@ -71,7 +72,8 @@ namespace Tensorflow.Train | |||
| public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | |||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | |||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | |||
| public IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> SelfSaveableObjectFactories | |||
| public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | |||
| public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||
| { | |||
| get | |||
| { | |||
| @@ -147,9 +149,11 @@ namespace Tensorflow.Train | |||
| _self_update_uid = -1; | |||
| _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||
| _unconditional_dependency_names = new Dictionary<string, Trackable>(); | |||
| _unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>(); | |||
| } | |||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, | |||
| IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
| { | |||
| _maybe_initialize_trackable(); | |||
| return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | |||
| @@ -185,10 +189,19 @@ namespace Tensorflow.Train | |||
| /// <param name="trackable"></param> | |||
| public virtual void _handle_deferred_dependencies(string name, Trackable trackable) | |||
| { | |||
| //_maybe_initialize_trackable(); | |||
| //trackable._maybe_initialize_trackable(); | |||
| // TODO: complete the implementation. | |||
| _maybe_initialize_trackable(); | |||
| trackable._maybe_initialize_trackable(); | |||
| if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) | |||
| { | |||
| _unconditional_deferred_dependencies.Remove(name); | |||
| foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) | |||
| { | |||
| checkpoint_position.restore(trackable); | |||
| } | |||
| } | |||
| // TODO(Rinne): deal with `_self_name_based_restores` | |||
| } | |||
| public virtual Trackable? _lookup_dependency(string name) | |||
| @@ -236,12 +249,19 @@ namespace Tensorflow.Train | |||
| return self_tensor_map.Keys.ToList(); | |||
| } | |||
| public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| { | |||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| { | |||
| throw new NotImplementedException(); | |||
| //return new TrackableSaveable(this, null, name, null, null); | |||
| } | |||
| if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||
| { | |||
| // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | |||
| throw new NotImplementedException(); | |||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| res[""] = create_saveable; | |||
| return res; | |||
| } | |||
| else | |||
| { | |||
| @@ -21,9 +21,9 @@ public static class TrackableUtils | |||
| LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | |||
| } | |||
| } | |||
| private static string _ESCAPE_CHAR = "."; | |||
| private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
| private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
| internal static string _ESCAPE_CHAR = "."; | |||
| internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
| internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
| internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||
| public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | |||
| { | |||
| @@ -293,10 +293,10 @@ namespace Tensorflow | |||
| resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | |||
| } | |||
| public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| { | |||
| var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | |||
| return res; | |||
| } | |||
| @@ -21,7 +21,7 @@ public class SequentialModelLoad | |||
| [TestMethod] | |||
| public void SimpleModelFromSequential() | |||
| { | |||
| var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); | |||
| var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential"); | |||
| Debug.Assert(model is Model); | |||
| var m = model as Model; | |||