diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
index 8ae2dae8..9812d3c6 100644
--- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
@@ -149,4 +149,13 @@ public static class CheckPointUtils
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}
+
+ ///
+ /// Traverse the object graph and list all accessible objects.
+ ///
+ ///
+ public static IList list_objects(ObjectGraphView graph_view)
+ {
+ return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
+ }
}
diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
index 0c2862da..a10e8953 100644
--- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
@@ -6,8 +6,10 @@ using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
+using Tensorflow.Exceptions;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;
+using Tensorflow.Operations;
namespace Tensorflow.Checkpoint;
@@ -21,8 +23,20 @@ public class TrackableSaver
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
+ private Tensor? _file_prefix_placeholder = null;
private Dictionary? _object_map = null;
private object? _cache = null;
+ public Tensor? FilePrefixPlaceHolder
+ {
+ get
+ {
+ return _file_prefix_placeholder;
+ }
+ set
+ {
+ _file_prefix_placeholder = value;
+ }
+ }
public TrackableSaver(ObjectGraphView graph_view)
{
_graph_view = graph_view;
@@ -192,4 +206,207 @@ public class TrackableSaver
return save_path;
}
}
+
+ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
+ {
+ if (options is null)
+ {
+ options = new CheckpointOptions();
+ }
+ if(save_path is null)
+ {
+ return new InitializationOnlyStatus(_graph_view, ops.uid());
+ }
+
+ CheckpointReader reader = new CheckpointReader(save_path);
+ bool graph_building = tf.Context.executing_eagerly();
+ Dictionary dtype_map = null;
+ if (!graph_building)
+ {
+ dtype_map = reader.VariableToDataTypeMap;
+ }
+ Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY);
+
+ Dictionary file_prefix_feed_dict;
+ Tensor file_prefix_tensor;
+ if (graph_building)
+ {
+ if(_file_prefix_placeholder is null)
+ {
+ tf.device("/cpu:0");
+ _file_prefix_placeholder = constant_op.constant("model");
+ }
+ file_prefix_tensor = _file_prefix_placeholder;
+ file_prefix_feed_dict = new();
+ file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
+ }
+ else
+ {
+ tf.device("/cpu:0");
+ file_prefix_tensor = constant_op.constant(save_path);
+ file_prefix_feed_dict = null;
+ }
+ TrackableObjectGraph object_graph_proto = new();
+ object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
+ CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
+ object_graph_proto: object_graph_proto,
+ save_path: save_path,
+ save_path_tensor: file_prefix_tensor,
+ reader: reader,
+ restore_op_cache: null,
+ graph_view: _graph_view,
+ options: options,
+ saveables_cache: null
+ );
+
+ throw new NotImplementedException();
+ }
+}
+
+internal class CheckpointRestoreCoordinator
+{
+ private CheckpointOptions _options;
+ private TrackableObjectGraph _object_graph_proto;
+ private int _restore_uid;
+ private HashSet _matched_proto_ids;
+ private Tensor _save_path_tensor;
+ private string _save_path_string;
+ private CheckpointReader _reader;
+ private Dictionary _dtype_map;
+ private Dictionary _shape_map;
+ private ObjectGraphView _graph_view;
+ private Dictionary> _slot_restorations;
+ private bool _expect_partial_attr;
+ private List _restore_ops;
+ private List _all_trackables;
+ private Dictionary _object_by_proto_id;
+
+ public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
+ CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
+ {
+ // TODO(Rinne): cache.
+ _options = options;
+ _object_graph_proto = object_graph_proto;
+ _restore_uid = ops.uid();
+ _save_path_tensor = save_path_tensor;
+ _save_path_string = save_path;
+ _reader = reader;
+ if(_reader is null)
+ {
+ _reader = new CheckpointReader(save_path);
+ }
+ _dtype_map = _reader.VariableToDataTypeMap;
+ _shape_map = _reader.VariableToShapeMap;
+ _graph_view = graph_view;
+ _restore_ops = new List();
+ _all_trackables = new List();
+ _matched_proto_ids = new HashSet();
+ _object_by_proto_id = new Dictionary();
+ _slot_restorations = new Dictionary>();
+
+ _expect_partial_attr = false;
+ for(int i = 0; i < _object_graph_proto.Nodes.Count; i++)
+ {
+ var node = _object_graph_proto.Nodes[i];
+ foreach(var slot_reference in node.SlotVariables)
+ {
+ _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List())
+ .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
+ }
+ }
+
+ // skip the deleter and cache.
+ }
+
+ public bool ExpectPartial
+ {
+ get
+ {
+ return _expect_partial_attr;
+ }
+ set
+ {
+ _expect_partial_attr = value;
+ }
+ }
+
+ public List AllTrackables => _all_trackables;
+ public HashSet MatchedProtoIds => _matched_proto_ids;
+ public Dictionary ObjectByProtoId => _object_by_proto_id;
+ public int RestoreUid => _restore_uid;
+
+ public void new_restore_ops(IEnumerable new_ops)
+ {
+ _restore_ops.AddRange(new_ops);
+ // skip the callback.
+ }
+
+ public List restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null)
+ {
+ throw new NotImplementedException();
+ }
+}
+
+public abstract class LoadStatus
+{
+ public abstract void assert_consumed();
+ public abstract void assert_existing_objects_matched();
+ public abstract void assert_nontrivial_match();
+ public abstract void run_restore_ops(Session? session = null);
+ public abstract void initialize_or_restore(Session? session = null);
+ public virtual LoadStatus expect_partial()
+ {
+ return this;
+ }
+}
+
+public class InitializationOnlyStatus: LoadStatus
+{
+ private int _restore_uid;
+ private ObjectGraphView _object_graph_view;
+ private Trackable _root;
+ public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
+ {
+ _restore_uid = restore_uid;
+ _object_graph_view = object_graph_view;
+ _root = object_graph_view.Root;
+ }
+ public override void assert_consumed()
+ {
+ throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
+ }
+ public override void assert_existing_objects_matched()
+ {
+ throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
+ }
+ public override void assert_nontrivial_match()
+ {
+ throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
+ }
+ public override void run_restore_ops(Session? session = null)
+ {
+ throw new AssertionError("No checkpoint specified, so no restore ops are available "
+ + "(save_path=None to Saver.restore).");
+ }
+ public override void initialize_or_restore(Session? session = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ return;
+ }
+ if(session is null)
+ {
+ session = new Session();
+ }
+ var trackable_objects = CheckPointUtils.list_objects(_object_graph_view);
+ throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
+ }
+}
+
+public class CheckpointLoadStatus
+{
+ public CheckpointLoadStatus()
+ {
+
+ }
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs
new file mode 100644
index 00000000..2d8bf096
--- /dev/null
+++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs
@@ -0,0 +1,80 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Train;
+
+namespace Tensorflow.Checkpoint;
+
+internal class CheckpointPosition
+{
+ private CheckpointRestoreCoordinator _checkpoint;
+ private int _proto_id;
+ private bool _skip_restore;
+ public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id)
+ {
+ _checkpoint = checkpoint;
+ _proto_id = proto_id;
+ _skip_restore = false;
+ }
+
+ public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];
+
+ public void restore(Trackable trackable)
+ {
+ using (ops.init_scope())
+ {
+ if (bind_project(trackable))
+ {
+
+ }
+ }
+ }
+
+ ///
+ /// Set a checkpoint<->object correspondence.
+ ///
+ ///
+ ///
+ public bool bind_project(Trackable trackable)
+ {
+ _checkpoint.AllTrackables.Add(trackable);
+ _checkpoint.MatchedProtoIds.Add(_proto_id);
+ if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment))
+ {
+ // skip the `logging.warning`.
+ return false;
+ }
+ else
+ {
+ _checkpoint.ObjectByProtoId[_proto_id] = trackable;
+ return true;
+ }
+ }
+
+ public void gather_ops_or_named_saveables()
+ {
+ // skip the registered_saver
+
+
+ }
+
+ ///
+ /// Restore the bound Trackable and dependencies (may be deferred).
+ ///
+ private void _restore_descendants()
+ {
+ Queue<(CheckpointPosition, Trackable)> visit_queue = new();
+ visit_queue.Enqueue((this, this.Trackable));
+
+ }
+
+ private void _single_restore()
+ {
+ var trackable = this.Trackable;
+ trackable._maybe_initialize_trackable();
+ if(_checkpoint.RestoreUid > trackable.UpdateUid)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs
index 5f08702d..142b8b64 100644
--- a/src/TensorFlowNET.Core/IO/gfile.cs
+++ b/src/TensorFlowNET.Core/IO/gfile.cs
@@ -16,8 +16,10 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.IO;
using System.Linq;
+using static Tensorflow.Binding;
namespace Tensorflow.IO
{
@@ -63,5 +65,15 @@ namespace Tensorflow.IO
dirs.AddRange(Directory.GetFiles(dir));
return dirs.ToArray();
}
+
+ public string join(params string[] paths)
+ {
+ Debug.Assert(paths.Length >= 1);
+ if (paths[0].Substring(1).Contains("://"))
+ {
+ throw new NotImplementedException("The combination of urls has not been implemented.");
+ }
+ return Path.Combine(paths);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
index 4e190605..dfd8735b 100644
--- a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
@@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Common
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
- var axis = serializer.Deserialize(reader, typeof(long[]));
+ var axis = serializer.Deserialize(reader, typeof(int[]));
if (axis is null)
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
diff --git a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
index 4437ba0a..9ff38129 100644
--- a/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
+++ b/src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
+using Tensorflow.Training.Saving.SavedModel;
namespace Tensorflow.ModelSaving
{
diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
index 1b1fa003..6ce7a0b0 100644
--- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
@@ -17,8 +17,8 @@
using System;
using System.Linq;
using Tensorflow.Framework;
-using Tensorflow.ModelSaving;
using Tensorflow.Train;
+using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types;
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs
new file mode 100644
index 00000000..df9bdc1b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public record class LoadOptions
+ {
+ public bool allow_partial_checkpoint;
+ public string experimental_io_device;
+ public bool experimental_skip_checkpoint;
+ public VariablePolicy experimental_variable_policy;
+
+ public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null,
+ bool experimental_skip_checkpoint = false, string experimental_variable_policy = null)
+ {
+ this.allow_partial_checkpoint = allow_partial_checkpoint;
+ this.experimental_io_device = experimental_io_device;
+ this.experimental_skip_checkpoint = experimental_skip_checkpoint;
+ this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
index fe0403c3..60188293 100644
--- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
@@ -1,4 +1,5 @@
-using Tensorflow.Train;
+using System;
+using Tensorflow.Train;
namespace Tensorflow;
@@ -14,4 +15,10 @@ public class RevivedTypes
// TODO: complete the implementation.
return null;
}
+
+ public static Tuple> deserialize(object proto)
+ {
+ // TODO: complete the implementation.
+ return null;
+ }
}
diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs
similarity index 83%
rename from src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
rename to src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs
index 45ebd884..d42f5253 100644
--- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs
@@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Text;
-namespace Tensorflow.ModelSaving
+namespace Tensorflow
{
///
/// Options for saving to SavedModel.
@@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving
public bool save_variable_devices()
{
- return this != VariablePolicy.None;
+ return this != None;
}
///
@@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving
///
public static VariablePolicy from_obj(object obj)
{
- if (obj is null) return VariablePolicy.None;
+ if (obj is null) return None;
if (obj is VariablePolicy) return (VariablePolicy)obj;
var key = obj.ToString().ToLower();
return key switch
{
- null => VariablePolicy.None,
- "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES,
- "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
+ null => None,
+ "save_variable_devices" => SAVE_VARIABLE_DEVICES,
+ "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES,
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
};
}
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
index 1be54287..5752d728 100644
--- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
@@ -5,7 +5,6 @@ using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Functions;
-using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
new file mode 100644
index 00000000..1f8d1a01
--- /dev/null
+++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
@@ -0,0 +1,601 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Net.Sockets;
+using System.Text;
+using Tensorflow.Checkpoint;
+using Tensorflow.Train;
+using Tensorflow.Training;
+using pbc = global::Google.Protobuf.Collections;
+using static Tensorflow.Binding;
+using System.Runtime.CompilerServices;
+using Tensorflow.Variables;
+
+namespace Tensorflow
+{
+ ///
+ /// Helper class to load an object-based SavedModel.
+ ///
+ public partial class Loader
+ {
+ private pbc::RepeatedField _asset_file_def;
+ private Dictionary> _operation_attributes;
+ private SavedObjectGraph _proto;
+ private string _export_dir;
+ private CheckpointOptions _checkpoint_options;
+ private LoadOptions _save_options;
+ private IDictionary)> _node_filters;
+ private Dictionary? _node_path_to_id;
+ private List? _filtered_nodes;
+ private List _ordered_node_ids;
+ private Dictionary)> _loaded_nodes;
+ private List _nodes;
+ private Dictionary> _node_setters;
+ public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir,
+ CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary)> filters)
+ {
+ var meta_graph = saved_model_proto.MetaGraphs[0];
+ _asset_file_def = meta_graph.AssetFileDef;
+ _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
+ _proto = object_graph_proto;
+ _export_dir = export_dir;
+ // TODO: `this._concrete_functions` and `this._restored_concrete_functions`
+ _checkpoint_options = ckpt_options;
+ _save_options = save_options;
+
+ // TODO: `this._pretty_printer`
+
+ _node_filters = filters;
+ _node_path_to_id = _convert_node_paths_to_ints();
+ _loaded_nodes = new Dictionary)>();
+ foreach(var filter in filters)
+ {
+ _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
+ }
+
+ _filtered_nodes = _retrieve_all_filtered_nodes();
+
+ _ordered_node_ids = _generate_ordered_node_ids();
+
+ _load_all();
+
+
+ if (!save_options.experimental_skip_checkpoint)
+ {
+ // TODO: implement it.
+ }
+ foreach(var node in _nodes)
+ {
+ // skip the process of `CapturableResource`.
+ }
+ }
+
+ ///
+ /// Maps all string node paths in node_filters to the int node ids.
+ ///
+ ///
+ private Dictionary? _convert_node_paths_to_ints()
+ {
+ if( _node_filters is null)
+ {
+ return null;
+ }
+ Dictionary path_to_int = new();
+ foreach(var node_id in _node_filters.Keys)
+ {
+ int int_node_id;
+ var node_path = node_id.Split('.');
+ if (node_path[0] != "root")
+ {
+ throw new ValueError($"When passing string identifiers to node_filters, the first name" +
+ $" must be root. Received {node_path[0]}.");
+ }
+ int_node_id = 0;
+ for(int i = 0; i < node_path.Length - 1; i++)
+ {
+ var name = node_path[i + 1];
+ int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1)));
+ }
+ path_to_int[node_id] = int_node_id;
+ }
+ return path_to_int;
+ }
+
+ private int _find_node_child(int node_id, string child_name, string path)
+ {
+ foreach(var refer in _proto.Nodes[node_id].Children)
+ {
+ if(refer.LocalName == child_name)
+ {
+ return refer.NodeId;
+ }
+ }
+ throw new ValueError($"Unable to find node {path}.");
+ }
+
+ private List? _retrieve_all_filtered_nodes()
+ {
+ if(_node_filters is null)
+ {
+ return null;
+ }
+
+ HashSet all_filtered_nodes = new();
+ Queue nodes_to_visit = new Queue(_node_filters.Keys);
+
+ while(nodes_to_visit.Count > 0)
+ {
+ var node_path = nodes_to_visit.Dequeue();
+ var node_id = _node_path_to_id[node_path];
+ if (all_filtered_nodes.Contains(node_id))
+ {
+ continue;
+ }
+ all_filtered_nodes.Add(node_id);
+ Trackable node = null;
+ Action