| @@ -149,4 +149,13 @@ public static class CheckPointUtils | |||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | |||
| // } | |||
| } | |||
/// <summary>
/// Traverse the object graph and list all accessible objects.
/// </summary>
/// <param name="graph_view">View over the object graph whose reachable trackables are listed.</param>
/// <returns>All trackables reachable from the graph view's root.</returns>
public static IList<Trackable> list_objects(ObjectGraphView graph_view)
{
    // Item1 of the tuple is the flat list of trackables; ids/slot-variables/paths are discarded here.
    return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
}
| } | |||
| @@ -6,8 +6,10 @@ using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -21,8 +23,20 @@ public class TrackableSaver | |||
// Object graph captured by the most recent save — presumably used to decide whether
// save ops can be reused; TODO confirm against the save path.
private TrackableObjectGraph _last_save_object_graph;
// Graph-mode feed tensor for the serialized object graph — set elsewhere; TODO confirm.
private Tensor? _object_graph_feed_tensor = null;
// Graph-mode feed tensor for the checkpoint file prefix — set elsewhere; TODO confirm.
private Tensor? _file_prefix_feed_tensor = null;
// Placeholder fed with the checkpoint path during a graph-mode restore (see restore()).
private Tensor? _file_prefix_placeholder = null;
// Optional substitution map from original trackables to replacements — presumably applied
// while saving; verify against caller.
private Dictionary<Trackable, Trackable>? _object_map = null;
// Saveables cache; kept untyped until implemented.
private object? _cache = null;
// Placeholder tensor holding the checkpoint file prefix; assigned by the SavedModel
// loader before restoring (see Loader._restore_checkpoint).
public Tensor? FilePrefixPlaceHolder
{
    get => _file_prefix_placeholder;
    set => _file_prefix_placeholder = value;
}
| public TrackableSaver(ObjectGraphView graph_view) | |||
| { | |||
| _graph_view = graph_view; | |||
| @@ -192,4 +206,207 @@ public class TrackableSaver | |||
| return save_path; | |||
| } | |||
| } | |||
/// <summary>
/// Restores a training checkpoint.
/// </summary>
/// <param name="save_path">Path of the checkpoint to restore; when null nothing is restored.</param>
/// <param name="options">Checkpoint options; a default instance is used when null.</param>
/// <returns>A load status object describing what was (or can be) restored.</returns>
public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
{
    if (options is null)
    {
        options = new CheckpointOptions();
    }
    if (save_path is null)
    {
        // No checkpoint: variables can only be initialized, nothing is restored.
        return new InitializationOnlyStatus(_graph_view, ops.uid());
    }
    CheckpointReader reader = new CheckpointReader(save_path);
    // BUG FIX: we are building a graph exactly when we are NOT executing eagerly.
    // The original assigned `tf.Context.executing_eagerly()` directly, which inverted
    // every branch below (cf. Python: `graph_building = not context.executing_eagerly()`).
    bool graph_building = !tf.Context.executing_eagerly();
    Dictionary<string, TF_DataType> dtype_map = null;
    if (!graph_building)
    {
        dtype_map = reader.VariableToDataTypeMap;
    }
    Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY);

    Dictionary<Tensor, string> file_prefix_feed_dict;
    Tensor file_prefix_tensor;
    if (graph_building)
    {
        if (_file_prefix_placeholder is null)
        {
            // NOTE(review): the device scope returned by `tf.device` is not captured;
            // confirm it actually applies to the constant created below.
            tf.device("/cpu:0");
            _file_prefix_placeholder = constant_op.constant("model");
        }
        file_prefix_tensor = _file_prefix_placeholder;
        file_prefix_feed_dict = new();
        file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
    }
    else
    {
        tf.device("/cpu:0");
        file_prefix_tensor = constant_op.constant(save_path);
        file_prefix_feed_dict = null;
    }
    TrackableObjectGraph object_graph_proto = new();
    object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
    CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
        object_graph_proto: object_graph_proto,
        save_path: save_path,
        save_path_tensor: file_prefix_tensor,
        reader: reader,
        restore_op_cache: null,
        graph_view: _graph_view,
        options: options,
        saveables_cache: null
    );
    // TODO: build and return a CheckpointLoadStatus from `checkpoint`.
    throw new NotImplementedException();
}
| } | |||
/// <summary>
/// Shares the state of an in-progress object-based checkpoint restore between the
/// positions/objects being bound (mirrors Python's `_CheckpointRestoreCoordinator`).
/// </summary>
internal class CheckpointRestoreCoordinator
{
    private CheckpointOptions _options;
    private TrackableObjectGraph _object_graph_proto;
    // Uid of this restore operation; compared against each trackable's update uid.
    private int _restore_uid;
    // Proto node ids that have been matched to a live trackable so far.
    private HashSet<int> _matched_proto_ids;
    private Tensor _save_path_tensor;
    private string _save_path_string;
    private CheckpointReader _reader;
    private Dictionary<string, TF_DataType> _dtype_map;
    private Dictionary<string, Shape> _shape_map;
    private ObjectGraphView _graph_view;
    // Maps an original-variable proto id to the slot-variable restorations derived from it.
    private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations;
    private bool _expect_partial_attr;
    private List<Operation> _restore_ops;
    private List<Trackable> _all_trackables;
    private Dictionary<int, Trackable> _object_by_proto_id;

    public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
        CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
    {
        // TODO(Rinne): cache.
        _options = options;
        _object_graph_proto = object_graph_proto;
        _restore_uid = ops.uid();
        _save_path_tensor = save_path_tensor;
        _save_path_string = save_path;
        _reader = reader;
        if (_reader is null)
        {
            // Open the checkpoint lazily when the caller did not supply a reader.
            _reader = new CheckpointReader(save_path);
        }
        _dtype_map = _reader.VariableToDataTypeMap;
        _shape_map = _reader.VariableToShapeMap;
        _graph_view = graph_view;
        _restore_ops = new List<Operation>();
        _all_trackables = new List<Trackable>();
        _matched_proto_ids = new HashSet<int>();
        _object_by_proto_id = new Dictionary<int, Trackable>();
        _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>();
        _expect_partial_attr = false;
        // Index every slot-variable reference by the id of its original variable so slots
        // can be restored once the original variable is bound.
        for (int i = 0; i < _object_graph_proto.Nodes.Count; i++)
        {
            var node = _object_graph_proto.Nodes[i];
            foreach (var slot_reference in node.SlotVariables)
            {
                _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>())
                    .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
            }
        }
        // skip the deleter and cache.
    }

    // When true, callers have declared the checkpoint may be only partially consumed
    // — presumably suppresses unused-value warnings; confirm when warnings are added.
    public bool ExpectPartial
    {
        get
        {
            return _expect_partial_attr;
        }
        set
        {
            _expect_partial_attr = value;
        }
    }

    public List<Trackable> AllTrackables => _all_trackables;
    public HashSet<int> MatchedProtoIds => _matched_proto_ids;
    public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
    public int RestoreUid => _restore_uid;

    /// <summary>
    /// Appends restore ops produced while binding objects.
    /// </summary>
    public void new_restore_ops(IEnumerable<Operation> new_ops)
    {
        _restore_ops.AddRange(new_ops);
        // skip the callback.
    }

    /// <summary>
    /// Runs or builds restore ops for the given saveables. Not implemented yet.
    /// </summary>
    public List<Operation> restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null)
    {
        throw new NotImplementedException();
    }
}
/// <summary>
/// Abstract base for the status objects returned by checkpoint restore calls.
/// </summary>
public abstract class LoadStatus
{
    /// <summary>Asserts the whole checkpoint was consumed; implementations throw otherwise.</summary>
    public abstract void assert_consumed();
    /// <summary>Asserts all existing objects were matched by checkpoint values.</summary>
    public abstract void assert_existing_objects_matched();
    /// <summary>Asserts that at least something non-trivial matched the checkpoint.</summary>
    public abstract void assert_nontrivial_match();
    /// <summary>Runs any pending restore ops (meaningful in graph mode).</summary>
    public abstract void run_restore_ops(Session? session = null);
    /// <summary>Restores from the checkpoint where possible, otherwise initializes.</summary>
    public abstract void initialize_or_restore(Session? session = null);
    /// <summary>
    /// Declares that only part of the checkpoint is expected to be used.
    /// Default implementation is a no-op that returns `this` for chaining.
    /// </summary>
    public virtual LoadStatus expect_partial()
    {
        return this;
    }
}
/// <summary>
/// LoadStatus returned when no checkpoint path was supplied: nothing can be restored,
/// so every assertion fails and only initialization is possible.
/// </summary>
public class InitializationOnlyStatus : LoadStatus
{
    private int _restore_uid;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
    {
        _restore_uid = restore_uid;
        _object_graph_view = object_graph_view;
        _root = object_graph_view.Root;
    }

    // There is no checkpoint, so every assertion about checkpoint consumption fails.
    public override void assert_consumed()
        => throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");

    public override void assert_existing_objects_matched()
        => throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");

    public override void assert_nontrivial_match()
        => throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");

    public override void run_restore_ops(Session? session = null)
        => throw new AssertionError("No checkpoint specified, so no restore ops are available (save_path=None to Saver.restore).");

    /// <summary>
    /// Initializes variables in graph mode; a no-op when executing eagerly
    /// (eager variables are initialized on creation).
    /// </summary>
    public override void initialize_or_restore(Session? session = null)
    {
        if (tf.Context.executing_eagerly())
        {
            return;
        }
        session ??= new Session();
        var objects_to_initialize = CheckPointUtils.list_objects(_object_graph_view);
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}
/// <summary>
/// Status of a checkpoint-backed restore.
/// NOTE(review): currently an empty stub that does not derive from `LoadStatus`;
/// presumably it will once `TrackableSaver.restore` is completed — confirm.
/// </summary>
public class CheckpointLoadStatus
{
    public CheckpointLoadStatus()
    {
    }
}
| @@ -0,0 +1,80 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Checkpoint; | |||
/// <summary>
/// Indexes a single node (`_proto_id`) of a checkpoint being restored and ties it
/// to a live Trackable through the shared <see cref="CheckpointRestoreCoordinator"/>.
/// </summary>
internal class CheckpointPosition
{
    private CheckpointRestoreCoordinator _checkpoint;
    private int _proto_id;
    // NOTE(review): never set to true anywhere in this class yet — presumably reserved
    // for future partial-restore support; confirm when the restore path is completed.
    private bool _skip_restore;

    public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id)
    {
        _checkpoint = checkpoint;
        _proto_id = proto_id;
        _skip_restore = false;
    }

    // The live object currently bound to this proto node (throws if not bound yet).
    public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];

    /// <summary>
    /// Restores this checkpoint position onto `trackable`.
    /// NOTE(review): still a stub — the binding is recorded but no restore work happens.
    /// </summary>
    /// <param name="trackable">Object to bind and restore.</param>
    public void restore(Trackable trackable)
    {
        using (ops.init_scope())
        {
            if (bind_project(trackable))
            {
                // TODO: gather ops / run deferred restorations for a fresh binding.
            }
        }
    }

    /// <summary>
    /// Set a checkpoint<->object correspondence.
    /// </summary>
    /// <param name="trackable">Live object to associate with this proto node.</param>
    /// <returns>True when the binding is new; false when the node was already bound.</returns>
    public bool bind_project(Trackable trackable)
    {
        _checkpoint.AllTrackables.Add(trackable);
        _checkpoint.MatchedProtoIds.Add(_proto_id);
        if (_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment))
        {
            // A previous object already claimed this node; keep the first binding.
            // skip the `logging.warning`.
            return false;
        }
        else
        {
            _checkpoint.ObjectByProtoId[_proto_id] = trackable;
            return true;
        }
    }

    public void gather_ops_or_named_saveables()
    {
        // skip the registered_saver
    }

    /// <summary>
    /// Restore the bound Trackable and dependencies (may be deferred).
    /// NOTE(review): only seeds the visit queue so far — the breadth-first drain is missing.
    /// </summary>
    private void _restore_descendants()
    {
        Queue<(CheckpointPosition, Trackable)> visit_queue = new();
        visit_queue.Enqueue((this, this.Trackable));
    }

    private void _single_restore()
    {
        var trackable = this.Trackable;
        trackable._maybe_initialize_trackable();
        // Only restore when this checkpoint is newer than the trackable's last update.
        if (_checkpoint.RestoreUid > trackable.UpdateUid)
        {
            // TODO: perform the actual restore.
        }
    }
}
| @@ -16,8 +16,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.IO | |||
| { | |||
| @@ -63,5 +65,15 @@ namespace Tensorflow.IO | |||
| dirs.AddRange(Directory.GetFiles(dir)); | |||
| return dirs.ToArray(); | |||
| } | |||
/// <summary>
/// Joins path segments into a single path (analogue of Python's tf.io.gfile.join).
/// </summary>
/// <param name="paths">One or more path segments; the first must not be a URL.</param>
/// <returns>The segments combined with the platform directory separator.</returns>
/// <exception cref="NotImplementedException">When the first segment contains a URL scheme (e.g. "gs://...").</exception>
public string join(params string[] paths)
{
    Debug.Assert(paths.Length >= 1);
    // BUG FIX: the original `paths[0].Substring(1)` threw ArgumentOutOfRangeException
    // for an empty first segment. Searching for "://" starting at index 1 is
    // equivalent for longer strings and safe for short ones.
    if (paths[0].Length > 1 && paths[0].IndexOf("://", 1, StringComparison.Ordinal) >= 1)
    {
        throw new NotImplementedException("The combination of urls has not been implemented.");
    }
    return Path.Combine(paths);
}
| } | |||
| } | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Common | |||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
| { | |||
| var axis = serializer.Deserialize(reader, typeof(long[])); | |||
| var axis = serializer.Deserialize(reader, typeof(int[])); | |||
| if (axis is null) | |||
| { | |||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow.ModelSaving | |||
| { | |||
| @@ -17,8 +17,8 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| @@ -0,0 +1,23 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public record class LoadOptions | |||
| { | |||
| public bool allow_partial_checkpoint; | |||
| public string experimental_io_device; | |||
| public bool experimental_skip_checkpoint; | |||
| public VariablePolicy experimental_variable_policy; | |||
| public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null, | |||
| bool experimental_skip_checkpoint = false, string experimental_variable_policy = null) | |||
| { | |||
| this.allow_partial_checkpoint = allow_partial_checkpoint; | |||
| this.experimental_io_device = experimental_io_device; | |||
| this.experimental_skip_checkpoint = experimental_skip_checkpoint; | |||
| this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.Train; | |||
| using System; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow; | |||
| @@ -14,4 +15,10 @@ public class RevivedTypes | |||
| // TODO: complete the implementation. | |||
| return null; | |||
| } | |||
/// <summary>
/// Deserializes a revived-type proto into a trackable plus a child setter.
/// NOTE(review): stub — always returns null until the revived-type registry exists.
/// </summary>
/// <param name="proto">Serialized object; kept untyped until the registry is implemented.</param>
/// <returns>Currently always null.</returns>
public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto)
{
    // TODO: complete the implementation.
    return null;
}
| } | |||
| @@ -2,7 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.ModelSaving | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Options for saving to SavedModel. | |||
| @@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving | |||
/// <summary>
/// Checks whether variable devices should be saved: every policy except `None` saves them.
/// </summary>
/// <returns>True when devices must be recorded in the SavedModel.</returns>
public bool save_variable_devices()
{
    // BUG FIX: the block contained two consecutive return statements (stray diff
    // residue); only the unqualified form is kept.
    return this != None;
}
| /// <summary> | |||
| @@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving | |||
| /// <returns></returns> | |||
/// <summary>
/// Converts a user-supplied value (null, VariablePolicy, or policy-name string)
/// into a VariablePolicy.
/// </summary>
/// <param name="obj">Value to convert; null maps to `None`.</param>
/// <returns>The matching policy.</returns>
/// <exception cref="ValueError">When the string does not name a known policy.</exception>
public static VariablePolicy from_obj(object obj)
{
    // BUG FIX: duplicated statements/switch arms (stray diff residue) removed,
    // keeping the unqualified member references.
    if (obj is null) return None;
    if (obj is VariablePolicy policy) return policy;
    var key = obj.ToString().ToLower();
    return key switch
    {
        null => None,
        "save_variable_devices" => SAVE_VARIABLE_DEVICES,
        "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES,
        _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
    };
}
| @@ -5,7 +5,6 @@ using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| @@ -0,0 +1,601 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Net.Sockets; | |||
| using System.Text; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Variables; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Helper class to load an object-based SavedModel. | |||
| /// </summary> | |||
| public partial class Loader | |||
| { | |||
| private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def; | |||
| private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes; | |||
| private SavedObjectGraph _proto; | |||
| private string _export_dir; | |||
| private CheckpointOptions _checkpoint_options; | |||
| private LoadOptions _save_options; | |||
| private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters; | |||
| private Dictionary<string, int>? _node_path_to_id; | |||
| private List<int>? _filtered_nodes; | |||
| private List<int> _ordered_node_ids; | |||
| private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
| private List<Trackable> _nodes; | |||
| private Dictionary<int, Action<object, object, object>> _node_setters; | |||
/// <summary>
/// Creates the loader and immediately recreates all objects from the SavedModel protos.
/// </summary>
/// <param name="object_graph_proto">Object graph describing the saved nodes.</param>
/// <param name="saved_model_proto">Full SavedModel proto; only the first MetaGraph is used.</param>
/// <param name="export_dir">Directory the SavedModel was exported to.</param>
/// <param name="ckpt_options">Options controlling checkpoint restoration.</param>
/// <param name="save_options">Load options (skip-checkpoint flag, variable policy, ...).</param>
/// <param name="filters">Mapping from "root.x.y" paths to pre-created nodes and setters.</param>
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir,
    CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters)
{
    var meta_graph = saved_model_proto.MetaGraphs[0];
    _asset_file_def = meta_graph.AssetFileDef;
    _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
    _proto = object_graph_proto;
    _export_dir = export_dir;
    // TODO: `this._concrete_functions` and `this._restored_concrete_functions`
    _checkpoint_options = ckpt_options;
    _save_options = save_options;
    // TODO: `this._pretty_printer`
    _node_filters = filters;
    _node_path_to_id = _convert_node_paths_to_ints();
    _loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
    // Seed the loaded-nodes table with the user-supplied filtered objects.
    foreach (var filter in filters)
    {
        _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
    }
    _filtered_nodes = _retrieve_all_filtered_nodes();
    // Dependencies must be created before their dependents; compute the order once.
    _ordered_node_ids = _generate_ordered_node_ids();
    _load_all();
    if (!save_options.experimental_skip_checkpoint)
    {
        // TODO: implement it.
    }
    // NOTE(review): loop body is intentionally empty for now.
    foreach (var node in _nodes)
    {
        // skip the process of `CapturableResource`.
    }
}
| /// <summary> | |||
| /// Maps all string node paths in node_filters to the int node ids. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
/// <summary>
/// Maps all string node paths in node_filters to the int node ids.
/// </summary>
/// <returns>Mapping from "root.a.b" paths to node ids, or null when there are no filters.</returns>
private Dictionary<string, int>? _convert_node_paths_to_ints()
{
    if (_node_filters is null)
    {
        return null;
    }
    Dictionary<string, int> path_to_int = new();
    foreach (var node_id in _node_filters.Keys)
    {
        int int_node_id;
        var node_path = node_id.Split('.');
        if (node_path[0] != "root")
        {
            throw new ValueError($"When passing string identifiers to node_filters, the first name" +
                $" must be root. Received {node_path[0]}.");
        }
        int_node_id = 0;
        for (int i = 0; i < node_path.Length - 1; i++)
        {
            var name = node_path[i + 1];
            // BUG FIX: the diagnostic path must include the child currently being
            // resolved (Take(i + 2)); Take(i + 1) stopped at its parent, producing
            // a misleading "Unable to find node" message.
            int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 2)));
        }
        path_to_int[node_id] = int_node_id;
    }
    return path_to_int;
}
/// <summary>
/// Resolves the id of a named child of the given object-graph node.
/// </summary>
/// <param name="node_id">Id of the parent node.</param>
/// <param name="child_name">Local name of the child to look up.</param>
/// <param name="path">Human-readable path used only in the error message.</param>
/// <returns>The child's node id.</returns>
private int _find_node_child(int node_id, string child_name, string path)
{
    var match = _proto.Nodes[node_id].Children.FirstOrDefault(child => child.LocalName == child_name);
    if (match is null)
    {
        throw new ValueError($"Unable to find node {path}.");
    }
    return match.NodeId;
}
/// <summary>
/// Expands the user-supplied filters to every node reachable from them, breadth-first.
/// </summary>
/// <returns>All node ids covered by the filters, or null when everything must be loaded.</returns>
private List<int>? _retrieve_all_filtered_nodes()
{
    if (_node_filters is null)
    {
        return null;
    }
    HashSet<int> all_filtered_nodes = new();
    Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys);
    while (nodes_to_visit.Count > 0)
    {
        var node_path = nodes_to_visit.Dequeue();
        var node_id = _node_path_to_id[node_path];
        if (all_filtered_nodes.Contains(node_id))
        {
            // Already expanded via another path.
            continue;
        }
        all_filtered_nodes.Add(node_id);
        Trackable node = null;
        Action<object, object, object> setter = null;
        if (_loaded_nodes.TryGetValue(node_id, out var res))
        {
            (node, setter) = res;
        }
        if (node is not null)
        {
            node._maybe_initialize_trackable();
        }
        foreach (var refer in _proto.Nodes[node_id].Children)
        {
            Trackable children_object = null;
            if (_loaded_nodes.TryGetValue(refer.NodeId, out var result))
            {
                children_object = result.Item1;
            }
            // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict.
            if (children_object is null && node is not null)
            {
                children_object = node._lookup_dependency(refer.LocalName);
                if (children_object is TrackableDataStructure)
                {
                    // TODO: set setter as lambda.
                    _loaded_nodes[refer.NodeId] = (children_object, setter);
                }
            }
            // Record the child's path and visit it later.
            string child_path = $"{node_path}.{refer.LocalName}";
            _node_path_to_id[child_path] = refer.NodeId;
            nodes_to_visit.Enqueue(child_path);
        }
    }
    // Filtering the root covers the whole graph; signal "no filtering" with null.
    if (all_filtered_nodes.Contains(0))
    {
        return null;
    }
    return all_filtered_nodes.ToList();
}
| /// <summary> | |||
| /// Orders the node ids so that dependencies appear first. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
/// <summary>
/// Orders the node ids so that dependencies appear first.
/// </summary>
/// <returns>Node ids sorted so each node follows everything it depends on.</returns>
/// <exception cref="ValueError">When filtering excludes a required dependency or the graph is cyclic.</exception>
private List<int> _generate_ordered_node_ids()
{
    List<int> unordered_ids;
    if (_filtered_nodes is null)
    {
        unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList();
    }
    else
    {
        unordered_ids = new List<int>(_filtered_nodes);
    }
    Dictionary<int, List<int>> dependency_map = new();
    foreach (var node_id in unordered_ids)
    {
        var deps = dependency_map.SetDefault(node_id, new List<int>());
        if (_loaded_nodes.ContainsKey(node_id))
        {
            // Already-loaded nodes need no dependency edges; they are never recreated.
            continue;
        }
        var proto = _proto.Nodes[node_id];
        foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
        {
            deps.Add(dep);
            if (_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
            {
                // TODO: add info with `_pretty_printer`.
                throw new ValueError($"Unable to partially load SavedModel since the specified filter " +
                    $"does not include all required objects for loading (e.g. " +
                    $"variables used in functions or deserialization dependencies). " +
                    $"Please include this path in the filter: {dep}");
            }
        }
        int? prev_slot = null;
        foreach (var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            // The optimizer and original variable must be created before the slot
            // variable, since the slot variable is generated using the Optimizer's
            // add_slot API.
            // BUG FIX: use SetDefault instead of the indexer — the slot variable's
            // entry usually does not exist yet (Python relies on defaultdict here),
            // so the indexer threw KeyNotFoundException.
            var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>());
            slot_deps.Add(node_id);
            slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);
            if (prev_slot is not null)
            {
                // Enforce a deterministic ordering between slots of the same optimizer.
                slot_deps.Add(prev_slot.Value);
            }
            prev_slot = slot_variable_node_id;
        }
    }
    try
    {
        return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
    }
    catch (TrackableUtils.CyclicDependencyError)
    {
        // BUG FIX: spaces added between the concatenated message fragments
        // (previously rendered "dependenciesin" / "pleasefile").
        throw new ValueError("Encountered a cycle in the deserialization dependencies " +
            "in the SavedModel. This is extremely unexpected, please " +
            "file a bug and make sure you are not manually modifying the SavedModel.");
    }
}
| /// <summary> | |||
| /// Returns a dictionary of all dependencies of an object. | |||
| /// </summary> | |||
| /// <param name="proto"></param> | |||
| /// <returns></returns> | |||
/// <summary>
/// Returns a dictionary of all dependencies of an object.
/// </summary>
/// <param name="proto">The saved object whose dependencies are collected.</param>
/// <returns>Mapping from dependency key (name or bound-input id) to node id.</returns>
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto)
{
    var deps = new Dictionary<Maybe<string, int>, int>();
    foreach (var child_ref in proto.Dependencies)
    {
        deps[child_ref.LocalName] = child_ref.NodeId;
    }
    switch (proto.KindCase)
    {
        case SavedObject.KindOneofCase.Function:
            // Functions depend on every tensor bound into their concrete functions.
            foreach (var fn_name in proto.Function.ConcreteFunctions)
            {
                foreach (var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
                {
                    deps[bound_input] = bound_input;
                }
            }
            break;
        case SavedObject.KindOneofCase.BareConcreteFunction:
        {
            var fn_name = proto.BareConcreteFunction.ConcreteFunctionName;
            foreach (var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
            {
                deps[bound_input] = bound_input;
            }
            break;
        }
        case SavedObject.KindOneofCase.Resource:
            // Resources depend on their `_create_resource` child, if any.
            foreach (var child in proto.Children)
            {
                if (child.LocalName == "_create_resource")
                {
                    deps["_create_resource"] = child.NodeId;
                }
            }
            break;
    }
    return deps;
}
| /// <summary> | |||
| /// Loads all nodes and functions from the SavedModel and their edges. | |||
| /// </summary> | |||
/// <summary>
/// Loads all nodes and functions from the SavedModel and their edges.
/// The call order matters: nodes must exist before edges and save/restore
/// functions can be attached to them.
/// </summary>
private void _load_all()
{
    _load_nodes();
    _load_edges();
    _setup_remaining_functions();
    _load_checkpoint_save_and_restore_functions();
}
| /// <summary> | |||
| /// Restores the checkpoint-related save/restore functions to all nodes. | |||
| /// </summary> | |||
/// <summary>
/// Restores the checkpoint-related save/restore functions to all nodes.
/// </summary>
private void _load_checkpoint_save_and_restore_functions()
{
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        var node = get(node_id);
        // A single entry under the reserved name means the modern
        // serialize/restore-from-tensors protocol was used.
        if (proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
        {
            // Restore Trackable serialize- and restore-from-tensor functions.
            Debug.Assert(proto.SaveableObjects.Count == 1);
            var saveable_object_proto = proto.SaveableObjects.Values.First();
            var save_fn_id = saveable_object_proto.SaveFunction;
            var restore_fn_id = saveable_object_proto.RestoreFunction;
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
        }
        else
        {
            // Restore legacy SaveableObject functions.
            Dictionary<string, (Trackable, Trackable)> saveable_fn_by_name = new();
            foreach (var item in proto.SaveableObjects)
            {
                var name = item.Key;
                var saveable_object_proto = item.Value;
                var save_fn_id = saveable_object_proto.SaveFunction;
                var restore_fn_id = saveable_object_proto.RestoreFunction;
                saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id));
            }
            node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
        }
    }
}
| /// <summary> | |||
| /// Load all saved objects. | |||
| /// </summary> | |||
/// <summary>
/// Load all saved objects.
/// </summary>
private void _load_nodes()
{
    // `nodes` maps from node ids to recreated objects
    // `node_setters` maps from node ids to setter functions
    // (same signature as setattr) for setting children.
    var (nodes, node_setters) = _initialize_loaded_nodes();
    Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)>
        slot_variable_node_ids = new();
    // First pass: record which nodes are slot variables and which node owns each.
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        foreach (var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto);
        }
    }
    // Re-create everything.
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        if (nodes.ContainsKey(node_id))
        {
            // Already supplied by the caller's filters.
            continue;
        }
        else if (slot_variable_node_ids.ContainsKey(node_id))
        {
            // Use the public Optimizer interface when creating slot variables.
            var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
            var optimizer_object = nodes[optimizer_node_id];
            var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];
            // TODO: implement it.
            throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
                " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        }
        else
        {
            var (node, setter) = _recreate(proto, node_id, nodes);
            nodes[node_id] = node;
            node_setters[node_id] = setter;
        }
    }
    // The root may be absent — presumably when filtering excluded it; synthesize one.
    if (!nodes.ContainsKey(0))
    {
        nodes[0] = _recreate_base_user_object().Item1;
    }
    _nodes = new List<Trackable>();
    for (int i = 0; i < _proto.Nodes.Count; i++)
    {
        _nodes.Add(nodes[i]);
    }
    _node_setters = node_setters;
}
| /// <summary> | |||
| /// Load state from checkpoint into the deserialized objects. | |||
| /// </summary> | |||
/// <summary>
/// Load state from checkpoint into the deserialized objects.
/// NOTE(review): partially implemented — only the saver setup exists so far.
/// </summary>
private void _restore_checkpoint()
{
    var variables_path = SavedModelUtils.get_variables_dir(_export_dir);
    var saver = new TrackableSaver(new ObjectGraphView(get(0)));
    // NOTE(review): the device scope returned by `tf.device` is not captured here;
    // confirm the constant below is actually placed on the CPU.
    tf.device("CPU");
    saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
    if (_save_options.allow_partial_checkpoint)
    {
        // TODO: expect-partial restore path.
    }
}
| /// <summary> | |||
| /// Adds edges from objects to other objects and functions. | |||
| /// </summary> | |||
/// <summary>
/// Adds edges from objects to other objects and functions.
/// </summary>
private void _load_edges()
{
    foreach (var (node_id, object_proto) in _iter_all_nodes())
    {
        _add_object_graph_edges(object_proto, node_id);
    }
    // When the root itself was filtered, re-attach the loaded nodes onto the root.
    // NOTE(review): this branch always throws — attribute get/set is not implemented yet.
    if (_filtered_nodes is not null && _filtered_nodes.Contains(0))
    {
        var root = get(0);
        foreach (var node_path in _node_filters.Keys)
        {
            var loaded_node = _nodes[_node_path_to_id[node_path]];
            var path = node_path.Split('.');
            var current_node = root;
            foreach (var name in path.Skip(1).Take(path.Length - 2))
            {
                // `hasattr` and `setattr` is used here
                throw new NotImplementedException();
            }
            // `hasattr` and `setattr` is used here
            throw new NotImplementedException();
        }
    }
}
/// <summary>
/// Placeholder for wiring up remaining functions. Currently a stub.
/// </summary>
private void _setup_remaining_functions()
{
    // TODO: implement it with concrete functions.
}
/// <summary>
/// Returns the recreated object at the given node id.
/// </summary>
public Trackable get(int node_id) => _nodes[node_id];
/// <summary>
/// Returns the recreated object addressed by a "root.x.y" path.
/// </summary>
public Trackable get(string node_id) => get(_node_path_to_id[node_id]);
/// <summary>
/// Adds edges from an object to its children.
/// </summary>
/// <param name="proto">saved-object proto listing the node's children.</param>
/// <param name="node_id">id of the parent node.</param>
private void _add_object_graph_edges(SavedObject proto, int node_id)
{
var parent = _nodes[node_id];
var set_attribute = _node_setters[node_id];
foreach (var child in proto.Children)
{
// Attach each child to the parent under the child's local name.
// skip the process of "__call__"
set_attribute.Invoke(parent, child.LocalName, _nodes[child.NodeId]);
}
}
// Splits the user-provided `_loaded_nodes` mapping into two parallel maps:
// node id -> trackable object and node id -> attribute setter.
private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{
Dictionary<int, Trackable> nodes = new();
Dictionary<int, Action<object, object, object>> node_setters = new();
foreach (var (node_id, (node, setter)) in _loaded_nodes)
{
nodes[node_id] = node;
node_setters[node_id] = setter;
}
return (nodes, node_setters);
}
// Lazily yields (node_id, proto) pairs following the pre-computed node order,
// mirroring the original iterator's deferred enumeration.
private IEnumerable<(int, SavedObject)> _iter_all_nodes()
{
return _ordered_node_ids.Select(node_id => (node_id, _proto.Nodes[node_id]));
}
// Recreates the object for `proto`, first resolving each of its dependency
// ids to the already-created trackable in `nodes`.
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{
// skip the registered classes.
Dictionary<Maybe<string, int>, Trackable> dependencies =
_get_node_dependencies(proto).ToDictionary(d => d.Key, d => nodes[d.Value]);
return _recreate_default(proto, node_id, dependencies);
}
/// <summary>
/// Creates an object from a SavedObject protocol buffer, dispatching on the
/// proto's kind. Only user objects and variables are supported so far.
/// </summary>
/// <param name="proto">the saved object to recreate.</param>
/// <param name="node_id">id of the node in the object graph.</param>
/// <param name="dependencies">already-created dependencies of this node.</param>
/// <exception cref="NotImplementedException">for unsupported kinds.</exception>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies)
{
    return proto.KindCase switch
    {
        SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
        SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
        SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
        SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
        SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
        // Previously unhandled kinds (e.g. assets, constants, resources) fell
        // through to an opaque SwitchExpressionException; fail explicitly.
        _ => throw new NotImplementedException($"Loading of SavedObject kind {proto.KindCase} is not supported.")
    };
}
// Recreates a saved user object: first consults the registry of revived
// types, falling back to a plain base user object when no type matches.
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id)
{
// skip the check of proto identifier because of lack of property.
var revived = RevivedTypes.deserialize(proto);
if (revived is not null)
{
return (revived.Item1, revived.Item2);
}
return _recreate_base_user_object(proto, node_id);
}
// Fallback creation path: a bare auto-trackable stand-in paired with the
// generic reflection-based attribute setter.
private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null)
=> (new _UserObject(), setattr);
/// <summary>
/// Recreates a variable from its SavedVariable proto as an
/// UninitializedVariable (its value is restored later from the checkpoint).
/// </summary>
/// <param name="proto">saved description of the variable.</param>
private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto)
{
    string name = proto.Name;
    string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>";
    // TODO(Rinne): `validate_synchronization_aggregation_trainable`
    var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable(
        proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name);

    var saved_device = proto.Device;
    var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device);
    if (load_with_device)
    {
        // NOTE(review): the result of tf.device is discarded (as in the
        // original), so the device scope may not actually apply — confirm.
        tf.device(saved_device);
    }
    // Both branches previously built an identical variable; construct it once.
    return (new UninitializedVariable(
        shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
        dtype: (TF_DataType)proto.Dtype,
        name: name,
        trainable: trainable,
        aggregation: aggregation
        ), setattr);
}
// TODO: remove this to a common class.
// Reflection-based analogue of Python's `setattr`: assigns value `z` to the
// public property named `y` on object `x`. When no property with that name
// exists, it silently does nothing (see the commented-out ValueError).
public static Action<object, object, object> setattr = (x, y, z) =>
{
Debug.Assert(y is string);
var target = x.GetType().GetProperties().FirstOrDefault(p => p.Name == (string)y);
target?.SetValue(x, z);
// TODO(Rinne): check if the property has been set successfully.
//throw new ValueError($"Cannot find the property {y} of {x}.");
};
// Minimal trackable stand-in used when a saved user object has no registered
// revived type.
public class _UserObject: AutoTrackable
{
}
| } | |||
| } | |||
| @@ -0,0 +1,122 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
public partial class Loader
{
    /// <summary>
    /// Reads the SavedModel proto from `export_dir`, preferring the binary
    /// `saved_model.pb` over the text `saved_model.pbtxt`.
    /// </summary>
    /// <param name="export_dir">directory containing the saved model.</param>
    /// <returns>the parsed <see cref="SavedModel"/> message.</returns>
    /// <exception cref="IOException">when neither proto file exists.</exception>
    public static SavedModel parse_saved_model(string export_dir)
    {
        var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT);
        var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB);

        SavedModel saved_model = new SavedModel();
        if (File.Exists(path_to_pb))
        {
            // File.ReadAllBytes always reads the entire file; the previous
            // single Stream.Read call ignored its return value and could
            // legally read fewer bytes than requested.
            // TODO: change to stream mode.
            saved_model.MergeFrom(File.ReadAllBytes(path_to_pb));
            return saved_model;
        }
        else if (File.Exists(path_to_pbtxt))
        {
            // Text-format protos are not supported yet.
            throw new NotImplementedException();
        }
        else
        {
            // Path.DirectorySeparatorChar is the path-component separator;
            // Path.PathSeparator (used before) separates PATH-list entries
            // (';' on Windows) and produced a misleading message.
            throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.DirectorySeparatorChar}" +
                $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}");
        }
    }

    // TODO: revise the type of `tags`
    /// <summary>
    /// Loads the root trackable object from a SavedModel directory.
    /// </summary>
    public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null)
    {
        return load_partial(export_dir, null, tags, options)["root"];
    }

    /// <summary>
    /// Loads a SavedModel, optionally restricting reconstruction to the nodes
    /// selected by <paramref name="filters"/>.
    /// </summary>
    /// <returns>loaded objects keyed by node path, or under "root" when no
    /// filters are given.</returns>
    public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null)
    {
        if (options is null)
        {
            options = new LoadOptions();
        }
        if (tags is not null)
        {
            // Tag-based MetaGraph selection is not supported yet.
            throw new NotImplementedException();
        }
        var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir);

        Trackable root = null;
        Loader loader = null;
        if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null)
        {
            // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)`
            var meta_graph_def = saved_model_proto.MetaGraphs[0];
            // Tensor content is serialized little-endian; swap on big-endian hosts.
            if (!BitConverter.IsLittleEndian)
            {
                SavedModelUtils.swap_function_tensor_content(meta_graph_def);
            }

            var object_graph_proto = meta_graph_def.ObjectGraphDef;
            var ckpt_options = new CheckpointOptions(options.experimental_io_device);
            tf_with(ops.init_scope(), x =>
            {
                loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
                root = loader.get(0);
                // skip the assignment of `graph_debug_info`.
            });
            // skip the assignment of `tensorflow_version`
            // skip the assignment of `tensorflow_git_version`
            // skip the process of `metrics`.
        }
        else
        {
            if (filters is not null && filters.Count > 0)
            {
                throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                    + " version) cannot be loaded with node filters.");
            }
            tf_with(ops.init_scope(), x =>
            {
                throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
            });
        }
        if (filters != null && filters.Count > 0)
        {
            return filters.Keys.ToDictionary(x => x, x => loader.get(x));
        }
        else
        {
            var res = new Dictionary<string, Trackable>();
            res["root"] = root;
            return res;
        }
    }

    /// <summary>
    /// Parses the SavedModel together with its debug info; the debug-info part
    /// is not implemented yet and is always null.
    /// </summary>
    public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir)
    {
        var saved_model = parse_saved_model(export_dir);
        // TODO: implement debug info.
        return (saved_model, null);
    }
}
| } | |||
| @@ -6,7 +6,6 @@ using System.Text; | |||
| using Google.Protobuf; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.Binding; | |||
| @@ -1,7 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.ModelSaving; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| @@ -333,6 +333,21 @@ namespace Tensorflow | |||
| return restored_ops; | |||
| }; | |||
| } | |||
/// <summary>
/// Returns a dict of SaveableObject factories generated from loaded fns.
/// </summary>
/// <param name="saveable_fn_by_name">checkpoint name to loaded save/restore function pair.</param>
/// <param name="temp_session">currently unused; kept for API parity.</param>
public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> recreate_saveable_objects(
IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{
// Only the empty case is supported so far; any loaded saveable fns are an error.
if (saveable_fn_by_name.Count == 0)
{
return new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
}
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
| } | |||
| public class SaveableCompatibilityConverter: Trackable | |||
| @@ -20,8 +20,8 @@ using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Train | |||
| @@ -71,6 +71,17 @@ namespace Tensorflow.Train | |||
| public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | |||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | |||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | |||
// Factories producing saveable objects for this trackable (backed by
// `_self_saveable_object_factories`).
public IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> SelfSaveableObjectFactories
{
get => _self_saveable_object_factories;
set => _self_saveable_object_factories = value;
}
| /// <summary> | |||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
| @@ -259,4 +270,6 @@ namespace Tensorflow.Train | |||
| } | |||
/// <summary>A named edge from a trackable parent to a child trackable.</summary>
public record class TrackableReference(string Name, Trackable Refer);
/// <summary>Identifies an optimizer slot variable to restore: owning optimizer node, slot-variable node, and slot name.</summary>
public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName);
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Train; | |||
| @@ -5,9 +5,9 @@ using Tensorflow.Variables; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.ModelSaving; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -19,7 +19,11 @@ namespace Tensorflow | |||
| protected TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| protected string _handle_name; | |||
| protected string handle_name => _handle_name; | |||
// Name of the variable's handle (backed by `_handle_name`); settable so
// loaders can patch it after revival.
public string handle_name
{
get => _handle_name;
set => _handle_name = value;
}
| protected string _unique_id; | |||
| public string UniqueId => _unique_id; | |||
| @@ -238,5 +238,23 @@ namespace Tensorflow | |||
| { | |||
| return _graph_element.eval(session); | |||
| } | |||
/// <summary>
/// Fills in defaults for optional variable settings.
/// </summary>
/// <param name="synchronization">defaults to Auto when null.</param>
/// <param name="aggregation">defaults to None when null.</param>
/// <param name="trainable">defaults to true unless synchronization is OnRead.</param>
/// <param name="name">variable name for diagnostics (currently unused).</param>
/// <returns>the resolved (synchronization, aggregation, trainable) triple.</returns>
public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable(
    VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name)
{
    aggregation ??= VariableAggregation.None;
    synchronization ??= VariableSynchronization.Auto;
    // Variables synchronized on read are not trainable by default.
    trainable ??= synchronization != VariableSynchronization.OnRead;
    return (synchronization.Value, aggregation.Value, trainable.Value);
}
| } | |||
| } | |||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| /// <param name="config"></param> | |||
| /// <returns></returns> | |||
| static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||
| { | |||
| // Layer instances created during the graph reconstruction process. | |||
| var created_layers = new Dictionary<string, ILayer>(); | |||
| @@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine | |||
| Inputs = inputs, | |||
| Outputs = outputs | |||
| }) | |||
| { | |||
| Initialize(inputs, outputs, name); | |||
| } | |||
| internal void Initialize(Tensors inputs, Tensors outputs, string name = null) | |||
| { | |||
| _input_layers = new List<ILayer>(); | |||
| _output_layers = new List<ILayer>(); | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| @@ -75,7 +76,17 @@ namespace Tensorflow.Keras.Engine | |||
| public int Id => id; | |||
| protected string name; | |||
| protected string base_name; | |||
| public string Name => name; | |||
// Layer name (backed by the `name` field); settable so revived layers can be
// renamed during SavedModel reconstruction.
public string Name
{
get => name;
set => name = value;
}
| protected bool computePreviousMask; | |||
| protected List<Operation> updates; | |||
| @@ -89,6 +100,8 @@ namespace Tensorflow.Keras.Engine | |||
| List<INode> outboundNodes; | |||
| public List<INode> OutboundNodes => outboundNodes; | |||
| public JObject SerializedAttributes { get; set; } | |||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
| public CallContext CallContext => callContext.Value; | |||
| public Tensor[] input | |||
| @@ -117,6 +130,11 @@ namespace Tensorflow.Keras.Engine | |||
| protected List<ILayer> _self_tracked_trackables; | |||
| public Layer(LayerArgs args) | |||
| { | |||
| Initialize(args); | |||
| } | |||
| internal virtual void Initialize(LayerArgs args) | |||
| { | |||
| this.args = args; | |||
| // A stateful layer is a layer whose updates are run during inference too, | |||
| @@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| using (SharedObjectSavingScope.Enter()) | |||
| { | |||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
| KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine | |||
| IVariableV1 _predict_counter; | |||
| bool _base_model_initialized; | |||
| bool stop_training; | |||
| public bool IsGraphNetwork => _is_graph_network; | |||
| public OptimizerV2 Optimizer | |||
| { | |||
| @@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine | |||
| _init_batch_counters(); | |||
| } | |||
// Initializes the model from layer args and (re)creates the batch counters.
// NOTE(review): `_init_batch_counters` runs *before* base.Initialize assigns
// the layer state — confirm it does not depend on `args`.
internal override void Initialize(LayerArgs args)
{
_init_batch_counters();
base.Initialize(args);
}
| void _configure_steps_per_execution(int steps_per_execution) | |||
| { | |||
| _steps_per_execution = tf.Variable(steps_per_execution, | |||
| @@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine | |||
| : base(args.Inputs, args.Outputs, name: args.Name) | |||
| { | |||
| this.args = args; | |||
| if (args.Layers == null) | |||
| args.Layers = new List<ILayer>(); | |||
| // SupportsMasking = true; | |||
| _compute_output_and_mask_jointly = true; | |||
| _auto_track_sub_layers = false; | |||
| @@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine | |||
| _created_nodes = new List<INode>(); | |||
| // Add to the model any layers passed to the constructor. | |||
| if (args.Layers != null) | |||
| if (args.Layers is not null) | |||
| { | |||
| foreach (var layer in args.Layers) | |||
| add(layer); | |||
| InitLayers(args.Layers); | |||
| } | |||
| } | |||
/// <summary>
/// Adds the given layers to this sequential model, in order.
/// </summary>
/// <param name="layers">layers to append; must not be null.</param>
/// <exception cref="ArgumentNullException">when <paramref name="layers"/> is null.</exception>
public void InitLayers(IEnumerable<ILayer> layers)
{
    if (layers is null)
    {
        // Fail fast with a clear message instead of a NullReferenceException
        // surfacing from the enumeration below.
        throw new ArgumentNullException(nameof(layers));
    }
    foreach (var layer in layers)
    {
        add(layer);
    }
}
| @@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||
| { | |||
| throw new ValueError("Alpha must be a number greater than 0."); | |||
| } | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| @@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||
| } | |||
| public override void build(Shape input_shape) | |||
| { | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| @@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers { | |||
| if ( alpha < 0f ) { | |||
| throw new ValueError("Alpha must be a number greater than 0."); | |||
| } | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| Tensor output = inputs; | |||
| @@ -1,12 +1,24 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using System.Text.RegularExpressions; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.ApiDef.Types; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| @@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| public class KerasObjectLoader | |||
| { | |||
| SavedMetadata _metadata; | |||
| SavedObjectGraph _proto; | |||
| Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
| Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); | |||
| List<int> _traversed_nodes_from_config = new List<int>(); | |||
| private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
| private SavedMetadata _metadata; | |||
| private SavedObjectGraph _proto; | |||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
| private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>(); | |||
| private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>(); | |||
| private List<int> _traversed_nodes_from_config = new List<int>(); | |||
| private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes; | |||
| private List<int> _models_to_reconstruct; | |||
| public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes; | |||
// Static initializer: registers the Keras-specific attribute key in the shared
// PUBLIC_ATTRIBUTES map (the value is populated when objects are loaded).
static KerasObjectLoader()
{
PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null;
}
// Builds a loader over the saved metadata and object graph, recording each
// node's dotted path and preparing reconstruction bookkeeping.
public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
{
_metadata = metadata;
_proto = object_graph_def;
// node id -> dotted object path, used for lookups and diagnostics.
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath);
_models_to_reconstruct = new List<int>();
loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
}
| /// <summary> | |||
| @@ -42,15 +66,251 @@ namespace Tensorflow.Keras.Saving | |||
| continue; | |||
| } | |||
| _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||
| } | |||
| foreach(var node_metadata in metric_list) | |||
| { | |||
| try | |||
| { | |||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | |||
| node_metadata.Metadata); | |||
| } | |||
| catch(ValueError e) | |||
| { | |||
| if (compile) | |||
| { | |||
| throw e; | |||
| } | |||
| // TODO: add logging.warning. | |||
| } | |||
| } | |||
| } | |||
/// <summary>
/// Returns the saved object path recorded for the given node id.
/// </summary>
public string get_path(int node_id) => _node_paths[node_id];
/// <summary>
/// Finish setting up Keras objects.
///
/// This function is executed after all objects and functions have been created.
/// Call functions and losses are attached to each layer, and once all layers
/// have been fully set up, graph networks are initialized.
///
/// Subclassed models that are revived from the SavedModel are treated like
/// layers, and have their call/loss functions attached here.
/// </summary>
public void finalize_objects()
{
List<Layer> layers_revived_from_config = new();
List<Layer> layers_revived_from_saved_model = new();
foreach (var (node_id, (trackable, _)) in loaded_nodes)
{
// Layers owned by a model are finalized when the model is reconstructed.
if (trackable is not Layer layer || model_layer_ids_dependencies.ContainsKey(node_id))
{
continue;
}
_unblock_model_reconstruction(node_id, layer);
// Input layers and metrics need no call/loss wiring.
if (layer is InputLayer or Metric)
{
continue;
}
// TODO: deal with `RevivedLayer` and `RevivedInputLayer`.
layers_revived_from_config.Add(layer);
}
_finalize_saved_model_layers(layers_revived_from_saved_model);
_finalize_config_layers(layers_revived_from_config);
_reconstruct_all_models();
}
/// <summary>
/// Reconstructs, in reverse discovery order, every model whose child layers
/// are all available, then finalizes each one as a config layer.
/// </summary>
private void _reconstruct_all_models()
{
    HashSet<int> all_initialized_models = new();
    for (int i = _models_to_reconstruct.Count - 1; i >= 0; i--)
    {
        int model_id = _models_to_reconstruct[i];
        all_initialized_models.Add(model_id);
        var (model, layers) = model_layer_dependencies[model_id];
        _reconstruct_model(model_id, model, layers.ToList());
        _finalize_config_layers(new List<Layer>() { model });
    }
    // SequenceEqual compared the set's iteration order against the dictionary
    // key order, which is not guaranteed to match; set equality is the
    // intended invariant (every dependent model was reconstructed).
    Debug.Assert(all_initialized_models.SetEquals(model_layer_dependencies.Keys));
}
// Rebuilds a revived model's structure from its saved JSON config and the
// already-created child layers.
private void _reconstruct_model(int model_id, Model model, List<Layer> layers)
{
// The node's config is stored as JSON inside the saved metadata.
var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"];
if(model.input is not null && model.input.Length > 0)
{
// Inputs are already connected; nothing to rebuild here.
}
else if(model is Sequential s)
{
// Ensure the layer stack starts with an input layer when the config
// declares one (or an input shape — not implemented yet).
if(layers is null || layers.Count == 0 || layers[0] is not InputLayer)
{
if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
{
layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
}
else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
{
// TODO: implement it
}
}
// `model.__init__(layers, config["name"])`
s.InitLayers(layers);
s.Name = config["name"].ToObject<string>();
if(s.input is null || s.input.Length == 0)
{
// Infer input specs/shapes from the first child layer and build.
var first_layer = _get_child_layer_node_ids(model_id)[0];
var input_specs = _infer_inputs(first_layer);
var input_shapes = _infer_inputs(first_layer, true);
// `model._set_inputs(input_specs)`
// skip the check of input_specs is Dictionary
if (!s.Built)
{
s.build(input_shapes);
}
}
}
else
{
// Functional model: rebuild the graph from config and re-initialize.
// skip the parameter `created_layers`.
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>());
// skip the `model.__init__`
(model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
(model as Functional).connect_ancillary_layers(created_layers);
}
_set_network_attributes_from_metadata(model);
_unblock_model_reconstruction(model_id, model);
}
| void _load_layer(int node_id, string identifier, string metadata_json) | |||
// Placeholder: presumably restores network-level attributes from saved
// metadata onto the revived model — not implemented yet.
private void _set_network_attributes_from_metadata(Model revived_object)
{
// TODO: implement it.
}
/// <summary>
/// Runs the final steps of loading Keras Layers from config.
/// </summary>
/// <param name="layers">layers revived from their saved config.</param>
private void _finalize_config_layers(List<Layer> layers)
{
foreach (var revived in layers)
{
// Unconditional losses only apply to graph networks; activation losses
// and metrics are restored for every layer.
if (_is_graph_network(revived))
{
_restore_layer_unconditional_losses(revived);
}
_restore_layer_activation_loss(revived);
_restore_layer_metrics(revived);
// TODO(Rinne): deal with RNN.
}
}
/// <summary>
/// Runs the final steps of loading Keras Layers from SavedModel.
/// </summary>
/// <param name="layers">layers revived directly from the SavedModel.</param>
private void _finalize_saved_model_layers(List<Layer> layers)
{
// TODO(Rinne): deal with `RevivedNetwork`.
foreach (var revived in layers)
{
_restore_layer_unconditional_losses(revived);
_restore_layer_activation_loss(revived);
_restore_layer_metrics(revived);
}
}
// Placeholder: restoration of a layer's unconditional losses — not implemented yet.
private void _restore_layer_unconditional_losses(Layer layer)
{
// TODO(Rinne): implement it.
}
// Placeholder: restoration of a layer's activation loss — not implemented yet.
private void _restore_layer_activation_loss(Layer layer)
{
// TODO(Rinne): implement it.
}
// Placeholder: restoration of a layer's metrics — not implemented yet.
private void _restore_layer_metrics(Layer layer)
{
// TODO(Rinne): implement it.
}
/// <summary>
/// Removes layer from blocking model reconstruction.
/// </summary>
/// <param name="layer_id">node id of the newly available layer.</param>
/// <param name="layer">the layer instance to slot into waiting models.</param>
private void _unblock_model_reconstruction(int layer_id, Layer layer)
{
foreach (var (model_id, (model, layer_ids)) in model_layer_ids_dependencies)
{
// Lazily create the parallel array of resolved layers for this model.
var resolved = model_layer_dependencies.SetDefault(model_id,
(model, new Layer[layer_ids.Length])).Item2;
var slot = Array.IndexOf(layer_ids, layer_id);
if (slot < 0)
{
continue;
}
resolved[slot] = layer;
// Once every slot is filled the model is ready to be rebuilt.
if (resolved.All(x => x is not null))
{
_models_to_reconstruct.Add(model_id);
}
}
}
/// <summary>
/// Creates (or retrieves from `loaded_nodes`) the layer for a saved-object
/// node, recording model/layer dependencies for later reconstruction.
/// </summary>
/// <param name="node_id">id of the node in the saved object graph.</param>
/// <param name="identifier">saved-object identifier string.</param>
/// <param name="metadata_json">raw JSON metadata attached to the node.</param>
private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{
    // Workaround: map the string dtype to its enum value before deserialization.
    metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
    var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);

    if (loaded_nodes.ContainsKey(node_id))
    {
        // The node was supplied by the caller — do NOT revive it again. (A
        // previous unconditional `_revive_from_config` call here duplicated
        // the side effects of child creation before this check.)
        var (node, setter) = loaded_nodes[node_id];
        _maybe_add_serialized_attributes(node as Layer, metadata);
        var config = metadata.Config;
        if (_is_graph_network(node as Layer) && generic_utils.validate_config(config))
        {
            Debug.Assert(node is Model);
            var child_nodes = _get_child_layer_node_ids(node_id);
            model_layer_ids_dependencies[node_id] = (node as Model, child_nodes);
            if (child_nodes is null || child_nodes.Length == 0)
            {
                _models_to_reconstruct.Add(node_id);
            }
        }
        return (node, setter);
    }
    else
    {
        var (obj, setter) = _revive_from_config(identifier, metadata, node_id);
        if (obj is null)
        {
            (obj, setter) = _revive_custom_object(identifier, metadata);
        }
        Debug.Assert(obj is Layer);
        _maybe_add_serialized_attributes(obj as Layer, metadata);
        return (obj, setter);
    }
}
| /// <summary> | |||
| @@ -59,11 +319,32 @@ namespace Tensorflow.Keras.Saving | |||
| /// <param name="identifier"></param> | |||
| /// <param name="metadata"></param> | |||
| /// <param name="node_id"></param> | |||
private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
{
    // The block previously contained both the pre- and post-refactor bodies
    // (duplicate `obj` declarations and a stale `void` signature), which does
    // not compile; this is the cleaned-up new version.
    Trackable obj;
    if (identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
    {
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
    }
    else
    {
        // Try to revive as a graph network first, otherwise fall back to a
        // plain layer/model revived from its config.
        obj = _revive_graph_network(identifier, metadata, node_id);
        obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
    }

    if (obj is null)
    {
        return (null, null);
    }
    var setter = _config_node_setter(_revive_setter);
    _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id);
    return (obj, setter);
}
| private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||
| { | |||
| // TODO: implement it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | |||
| @@ -71,6 +352,12 @@ namespace Tensorflow.Keras.Saving | |||
| var config = metadata.Config; | |||
| var class_name = metadata.ClassName; | |||
| Model model = null; | |||
| if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") | |||
| { | |||
| return null; | |||
| } | |||
| if (class_name == "Sequential") | |||
| { | |||
| model = new Sequential(new SequentialArgs | |||
| @@ -78,34 +365,83 @@ namespace Tensorflow.Keras.Saving | |||
| Name = config.GetValue("name").ToString() | |||
| }); | |||
| } | |||
| else if (class_name == "Functional") | |||
| else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| model = model = new Sequential(new SequentialArgs | |||
| { | |||
| Name = class_name | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| // TODO: implement it. | |||
| throw new NotImplementedException("Not implemented"); | |||
| } | |||
| if (!metadata.IsGraphNetwork) | |||
| return null; | |||
| // Record this model and its layers. This will later be used to reconstruct | |||
| // the model. | |||
| var layers = _get_child_layer_node_ids(node_id); | |||
| model_layer_dependencies[node_id] = (model, layers); | |||
| model_layer_ids_dependencies[node_id] = (model, layers); | |||
| if(layers is null || layers.Length == 0) | |||
| { | |||
| _models_to_reconstruct.Add(node_id); | |||
| } | |||
| return model; | |||
| } | |||
| Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||
| Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||
| { | |||
| var config = metadata.Config; | |||
| var class_name = metadata.ClassName; | |||
| var shared_object_id = metadata.SharedObjectId; | |||
| var must_restore_from_config = metadata.MustRestoreFromConfig; | |||
| var obj = class_name switch | |||
| { | |||
| "Resizing" => Resizing.from_config(config), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| var obj = generic_utils.deserialize_keras_object(class_name, config); | |||
| obj.Name = metadata.Name; | |||
| // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | |||
| var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||
| return null; | |||
| if (!built) | |||
| { | |||
| return null; | |||
| } | |||
| return obj; | |||
| } | |||
| private void _revive_setter(object layer, object name, object value) | |||
| { | |||
| Debug.Assert(name is string); | |||
| Debug.Assert(layer is Layer); | |||
| if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||
| { | |||
| if(value is Trackable) | |||
| { | |||
| (layer as Layer)._track_trackable(value as Trackable, name as string); | |||
| } | |||
| if((layer as Layer).SerializedAttributes is null) | |||
| { | |||
| (layer as Layer).SerializedAttributes = new JObject(); | |||
| } | |||
| (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); | |||
| } | |||
| else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
| { | |||
| (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); | |||
| } | |||
| else | |||
| { | |||
| var properties = layer.GetType().GetProperties(); | |||
| foreach(var p in properties) | |||
| { | |||
| if(p.Name == name as string && p.GetValue(layer) is not null) | |||
| { | |||
| return; | |||
| } | |||
| } | |||
| Loader.setattr(layer, name, value); | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -143,34 +479,186 @@ namespace Tensorflow.Keras.Saving | |||
| /// <param name="obj"></param> | |||
| /// <param name="proto"></param> | |||
| /// <param name="node_id"></param> | |||
| void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) | |||
| void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) | |||
| { | |||
| if (_traversed_nodes_from_config.Contains(node_id)) | |||
| return; | |||
| var parent_path = _node_paths[node_id]; | |||
| _traversed_nodes_from_config.Add(node_id); | |||
| if (!obj.Built) | |||
| obj._maybe_initialize_trackable(); | |||
| if(obj is Layer layer && !layer.Built) | |||
| { | |||
| var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||
| _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata); | |||
| _try_build_layer(layer, node_id, metadata.BuildInputShape); | |||
| } | |||
| List<(Trackable, int, string)> children = new(); | |||
| foreach(var refer in proto.Children) | |||
| { | |||
| var obj_child = obj._lookup_dependency(refer.LocalName); | |||
| children.Add((obj_child, refer.NodeId, refer.LocalName)); | |||
| } | |||
| var metric_list_node_id = _search_for_child_node(node_id, new string[] { | |||
| Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||
| }); | |||
| if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | |||
| { | |||
| var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); | |||
| foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) | |||
| { | |||
| if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) | |||
| { | |||
| var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; | |||
| children.Add((metric, refer.NodeId, metric_path)); | |||
| } | |||
| } | |||
| } | |||
| foreach(var (obj_child, child_id, child_name) in children) | |||
| { | |||
| if(obj_child is null) | |||
| { | |||
| continue; | |||
| } | |||
| var child_proto = _proto.Nodes[child_id]; | |||
| // skip the check for registered identifier | |||
| Action<object, object, object> setter; | |||
| if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||
| { | |||
| setter = _revive_setter; | |||
| } | |||
| else | |||
| { | |||
| setter = Loader.setattr; | |||
| } | |||
| if (loaded_nodes.ContainsKey(child_id)) | |||
| { | |||
| // skip the logging.warning | |||
| continue; | |||
| } | |||
| if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) | |||
| { | |||
| (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; | |||
| } | |||
| if(obj_child is TrackableDataStructure) | |||
| { | |||
| setter = (x, y, z) => { }; | |||
| } | |||
| var child_path = $"{parent_path}.{child_name}"; | |||
| _node_paths[child_id] = child_path; | |||
| _add_children_recreated_from_config(obj_child, child_proto, child_id); | |||
| loaded_nodes[child_id] = (obj_child, setter); | |||
| } | |||
| } | |||
| bool _try_build_layer(Model obj, int node_id, Shape build_input_shape) | |||
| private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||
| { | |||
| if (obj.Built) | |||
| return true; | |||
| if(build_input_shape is null) | |||
| { | |||
| build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||
| } | |||
| if(build_input_shape is not null) | |||
| { | |||
| obj.build(build_input_shape); | |||
| // In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. | |||
| // On the one hand, C# does not support call a method from specified parent class. | |||
| // On the other hand, currently All class derived from Layer call `Layer.Build` or | |||
| // move the implementation of `Layer.build` to its own `build` method. | |||
| // Therefore we do not call it here. | |||
| // However, it's still quite risky once in the future a certain class derived from | |||
| // `Layer` does not call `Layer.build`. | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||
| /// <summary> | |||
| /// Infers input shape of layer from SavedModel functions. | |||
| /// </summary> | |||
| /// <param name="layer_node_id"></param> | |||
| /// <param name="convert_to_shapes"></param> | |||
| /// <returns></returns> | |||
| private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||
| { | |||
| if (obj.Built) | |||
| return true; | |||
| var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | |||
| if(call_fn_id is null) | |||
| { | |||
| return null; | |||
| } | |||
| var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; | |||
| if(concrete_functions is null) | |||
| { | |||
| return null; | |||
| } | |||
| var call_fn_name = concrete_functions[0]; | |||
| var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| } | |||
| private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||
| { | |||
| if(path_to_child is null || path_to_child.Count() == 0) | |||
| { | |||
| return parent_id; | |||
| } | |||
| foreach(var child in _proto.Nodes[parent_id].Children) | |||
| { | |||
| if(child.LocalName == path_to_child.First()) | |||
| { | |||
| return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); | |||
| } | |||
| } | |||
| return null; | |||
| } | |||
| private bool _is_graph_network(Layer layer) | |||
| { | |||
| // TODO: deal with `RevivedLayer` | |||
| if(layer is Functional) | |||
| { | |||
| return (layer as Functional).IsGraphNetwork || layer is Sequential; | |||
| } | |||
| return false; | |||
| } | |||
    // Intentionally a no-op for now; revived layers do not yet restore serialized attributes here.
    private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata)
    {
        // TODO: deal with `RevivedLayer`
    }
| /// <summary> | |||
| /// Creates edges for nodes that are recreated from config. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| private Action<object, object, object> _config_node_setter(Action<object, object, object> setter) | |||
| { | |||
| void setattr_wrapper(object obj, object name, object value) | |||
| { | |||
| Debug.Assert(obj is Trackable); | |||
| Debug.Assert(name is string); | |||
| if((obj as Trackable)._lookup_dependency(name as string) is null) | |||
| { | |||
| setter(obj, name, value); | |||
| } | |||
| } | |||
| return setattr_wrapper; | |||
| } | |||
| } | |||
| } | |||
| @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public partial class KerasSavedModelUtils | |||
| { | |||
| public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||
| public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||
| SaveOptions? options, bool save_traces = true) | |||
| { | |||
| if (!overwrite && File.Exists(filepath)) | |||
| @@ -0,0 +1,96 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| public class KerasLoadModelUtils | |||
| { | |||
| /// <summary> | |||
| /// Corresponding to keras/saving/save.py/load_model | |||
| /// </summary> | |||
| /// <param name="filepath"></param> | |||
| /// <param name="custom_objects"></param> | |||
| /// <param name="compile"></param> | |||
| /// <param name="options"></param> | |||
| /// <returns></returns> | |||
| public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||
| bool compile = true, LoadOptions? options = null) | |||
| { | |||
| using (SharedObjectSavingScope.Enter()) | |||
| { | |||
| using (LoadContext.load_context(options)) | |||
| { | |||
| if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||
| { | |||
| throw new IOException($"No file or directory found at {filepath}."); | |||
| } | |||
| if (Directory.Exists(filepath)) | |||
| { | |||
| return load(filepath, compile, options); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| public static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||
| { | |||
| SavedMetadata metadata = new SavedMetadata(); | |||
| var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||
| var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
| string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||
| if (File.Exists(path_to_metadata_pb)) | |||
| { | |||
| metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| } | |||
| if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||
| { | |||
| return Loader.load(path, options: options) as Model; | |||
| } | |||
| var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
| keras_loader.load_layers(compile: compile); | |||
| Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||
| nodes_to_load["root"] = (null, null); | |||
| foreach(var item in keras_loader.LoadedNodes) | |||
| { | |||
| nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||
| } | |||
| var loaded = Loader.load_partial(path, nodes_to_load, options); | |||
| keras_loader.finalize_objects(); | |||
| // keras_loader.del_tracking(); | |||
| var model = loaded["root"]; | |||
| if(model is Model && compile) | |||
| { | |||
| // TODO: implement it. | |||
| } | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| // TODO: implement it. | |||
| } | |||
| return model; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,69 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Threading; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| // TODO: remove this class to common project. | |||
| public class ContextHandler: IDisposable | |||
| { | |||
| public Action<bool> DisposeCallBack { get; set; } | |||
| public void Dispose() | |||
| { | |||
| DisposeCallBack.Invoke(true); | |||
| } | |||
| } | |||
| public class LoadContext | |||
| { | |||
| private bool _entered_load_context; | |||
| private LoadOptions? _load_options; | |||
| private static ThreadLocal<LoadContext> _load_context = new(); | |||
| private LoadContext() | |||
| { | |||
| _entered_load_context = false; | |||
| _load_options = null; | |||
| } | |||
| public void set_load_options(LoadOptions load_options) | |||
| { | |||
| _load_options = load_options; | |||
| _entered_load_context = true; | |||
| } | |||
| private void clear_load_options() | |||
| { | |||
| _load_options = null; | |||
| _entered_load_context = false; | |||
| } | |||
| private LoadOptions? load_options() | |||
| { | |||
| return _load_options; | |||
| } | |||
| public static ContextHandler load_context(LoadOptions? load_options) | |||
| { | |||
| if(_load_context.Value is null) | |||
| { | |||
| _load_context.Value = new LoadContext(); | |||
| } | |||
| _load_context.Value.set_load_options(load_options); | |||
| return new ContextHandler() | |||
| { | |||
| DisposeCallBack = _ => _load_context.Value.clear_load_options() | |||
| }; | |||
| } | |||
| public static LoadOptions? get_load_option() | |||
| { | |||
| return _load_context.Value.load_options(); | |||
| } | |||
| public static bool in_load_context() | |||
| { | |||
| return _load_context.Value._entered_load_context; | |||
| } | |||
| } | |||
| } | |||
| @@ -22,12 +22,16 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| public class generic_utils | |||
| { | |||
| private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; | |||
| /// <summary> | |||
| /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | |||
| /// </summary> | |||
| @@ -51,6 +55,21 @@ namespace Tensorflow.Keras.Utils | |||
| return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | |||
| } | |||
| public static Layer deserialize_keras_object(string class_name, JObject config) | |||
| { | |||
| return class_name switch | |||
| { | |||
| "Sequential" => new Sequential(config.ToObject<SequentialArgs>()), | |||
| "InputLayer" => new InputLayer(config.ToObject<InputLayerArgs>()), | |||
| "Flatten" => new Flatten(config.ToObject<FlattenArgs>()), | |||
| "ELU" => new ELU(config.ToObject<ELUArgs>()), | |||
| "Dense" => new Dense(config.ToObject<DenseArgs>()), | |||
| "Softmax" => new Softmax(config.ToObject<SoftmaxArgs>()), | |||
| _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + | |||
| $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") | |||
| }; | |||
| } | |||
| public static string to_snake_case(string name) | |||
| { | |||
| return string.Concat(name.Select((x, i) => | |||
| @@ -60,5 +79,15 @@ namespace Tensorflow.Keras.Utils | |||
| x.ToString(); | |||
| })).ToLower(); | |||
| } | |||
| /// <summary> | |||
| /// Determines whether config appears to be a valid layer config. | |||
| /// </summary> | |||
| /// <param name="config"></param> | |||
| /// <returns></returns> | |||
| public static bool validate_config(JObject config) | |||
| { | |||
| return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,45 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using static Tensorflow.KerasApi; | |||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
[TestClass]
public class SequentialModelLoad
{
    // NOTE(review): depends on a machine-specific absolute path and a local "mnist"
    // directory; it will fail on any other environment — TODO make the fixture portable.
    [TestMethod]
    public void SimpleModelFromSequential()
    {
        // Load a previously saved sequential model from a SavedModel directory.
        var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential");
        Debug.Assert(model is Model);
        var m = model as Model;
        m.summary();

        m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

        var data_loader = new MnistModelLoader();
        var num_epochs = 1;
        var batch_size = 50;

        // ValidationSize = 50000 — presumably to keep the training split small for a
        // fast smoke test; confirm against the loader's semantics.
        // NOTE(review): .Result blocks the test thread; acceptable in a unit test.
        var dataset = data_loader.LoadAsync(new ModelLoadSetting
        {
            TrainDir = "mnist",
            OneHot = false,
            ValidationSize = 50000,
        }).Result;

        // Smoke check: a reloaded model can still be trained.
        m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
    }
}
| @@ -63,6 +63,8 @@ public class SequentialModelTest | |||
| keras.layers.Softmax(1) | |||
| }); | |||
| model.summary(); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| var data_loader = new MnistModelLoader(); | |||