| @@ -149,4 +149,13 @@ public static class CheckPointUtils | |||||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | ||||
| // } | // } | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Traverse the object graph and list all accessible objects. | |||||
| /// </summary> | |||||
| /// <param name="object_graph_view"></param> | |||||
| public static IList<Trackable> list_objects(ObjectGraphView graph_view) | |||||
| { | |||||
| return objects_ids_and_slot_variables_and_paths(graph_view).Item1; | |||||
| } | |||||
| } | } | ||||
| @@ -6,8 +6,10 @@ using System.Linq; | |||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Exceptions; | |||||
| using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Operations; | |||||
| namespace Tensorflow.Checkpoint; | namespace Tensorflow.Checkpoint; | ||||
| @@ -21,8 +23,20 @@ public class TrackableSaver | |||||
| private TrackableObjectGraph _last_save_object_graph; | private TrackableObjectGraph _last_save_object_graph; | ||||
| private Tensor? _object_graph_feed_tensor = null; | private Tensor? _object_graph_feed_tensor = null; | ||||
| private Tensor? _file_prefix_feed_tensor = null; | private Tensor? _file_prefix_feed_tensor = null; | ||||
| private Tensor? _file_prefix_placeholder = null; | |||||
| private Dictionary<Trackable, Trackable>? _object_map = null; | private Dictionary<Trackable, Trackable>? _object_map = null; | ||||
| private object? _cache = null; | private object? _cache = null; | ||||
| public Tensor? FilePrefixPlaceHolder | |||||
| { | |||||
| get | |||||
| { | |||||
| return _file_prefix_placeholder; | |||||
| } | |||||
| set | |||||
| { | |||||
| _file_prefix_placeholder = value; | |||||
| } | |||||
| } | |||||
| public TrackableSaver(ObjectGraphView graph_view) | public TrackableSaver(ObjectGraphView graph_view) | ||||
| { | { | ||||
| _graph_view = graph_view; | _graph_view = graph_view; | ||||
| @@ -192,4 +206,207 @@ public class TrackableSaver | |||||
| return save_path; | return save_path; | ||||
| } | } | ||||
| } | } | ||||
| public LoadStatus restore(string? save_path, CheckpointOptions? options = null) | |||||
| { | |||||
| if (options is null) | |||||
| { | |||||
| options = new CheckpointOptions(); | |||||
| } | |||||
| if(save_path is null) | |||||
| { | |||||
| return new InitializationOnlyStatus(_graph_view, ops.uid()); | |||||
| } | |||||
| CheckpointReader reader = new CheckpointReader(save_path); | |||||
| bool graph_building = tf.Context.executing_eagerly(); | |||||
| Dictionary<string, TF_DataType> dtype_map = null; | |||||
| if (!graph_building) | |||||
| { | |||||
| dtype_map = reader.VariableToDataTypeMap; | |||||
| } | |||||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY); | |||||
| Dictionary<Tensor, string> file_prefix_feed_dict; | |||||
| Tensor file_prefix_tensor; | |||||
| if (graph_building) | |||||
| { | |||||
| if(_file_prefix_placeholder is null) | |||||
| { | |||||
| tf.device("/cpu:0"); | |||||
| _file_prefix_placeholder = constant_op.constant("model"); | |||||
| } | |||||
| file_prefix_tensor = _file_prefix_placeholder; | |||||
| file_prefix_feed_dict = new(); | |||||
| file_prefix_feed_dict[_file_prefix_placeholder] = save_path; | |||||
| } | |||||
| else | |||||
| { | |||||
| tf.device("/cpu:0"); | |||||
| file_prefix_tensor = constant_op.constant(save_path); | |||||
| file_prefix_feed_dict = null; | |||||
| } | |||||
| TrackableObjectGraph object_graph_proto = new(); | |||||
| object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||||
| CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( | |||||
| object_graph_proto: object_graph_proto, | |||||
| save_path: save_path, | |||||
| save_path_tensor: file_prefix_tensor, | |||||
| reader: reader, | |||||
| restore_op_cache: null, | |||||
| graph_view: _graph_view, | |||||
| options: options, | |||||
| saveables_cache: null | |||||
| ); | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| internal class CheckpointRestoreCoordinator | |||||
| { | |||||
| private CheckpointOptions _options; | |||||
| private TrackableObjectGraph _object_graph_proto; | |||||
| private int _restore_uid; | |||||
| private HashSet<int> _matched_proto_ids; | |||||
| private Tensor _save_path_tensor; | |||||
| private string _save_path_string; | |||||
| private CheckpointReader _reader; | |||||
| private Dictionary<string, TF_DataType> _dtype_map; | |||||
| private Dictionary<string, Shape> _shape_map; | |||||
| private ObjectGraphView _graph_view; | |||||
| private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations; | |||||
| private bool _expect_partial_attr; | |||||
| private List<Operation> _restore_ops; | |||||
| private List<Trackable> _all_trackables; | |||||
| private Dictionary<int, Trackable> _object_by_proto_id; | |||||
| public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, | |||||
| CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) | |||||
| { | |||||
| // TODO(Rinne): cache. | |||||
| _options = options; | |||||
| _object_graph_proto = object_graph_proto; | |||||
| _restore_uid = ops.uid(); | |||||
| _save_path_tensor = save_path_tensor; | |||||
| _save_path_string = save_path; | |||||
| _reader = reader; | |||||
| if(_reader is null) | |||||
| { | |||||
| _reader = new CheckpointReader(save_path); | |||||
| } | |||||
| _dtype_map = _reader.VariableToDataTypeMap; | |||||
| _shape_map = _reader.VariableToShapeMap; | |||||
| _graph_view = graph_view; | |||||
| _restore_ops = new List<Operation>(); | |||||
| _all_trackables = new List<Trackable>(); | |||||
| _matched_proto_ids = new HashSet<int>(); | |||||
| _object_by_proto_id = new Dictionary<int, Trackable>(); | |||||
| _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>(); | |||||
| _expect_partial_attr = false; | |||||
| for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) | |||||
| { | |||||
| var node = _object_graph_proto.Nodes[i]; | |||||
| foreach(var slot_reference in node.SlotVariables) | |||||
| { | |||||
| _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>()) | |||||
| .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName)); | |||||
| } | |||||
| } | |||||
| // skip the deleter and cache. | |||||
| } | |||||
| public bool ExpectPartial | |||||
| { | |||||
| get | |||||
| { | |||||
| return _expect_partial_attr; | |||||
| } | |||||
| set | |||||
| { | |||||
| _expect_partial_attr = value; | |||||
| } | |||||
| } | |||||
| public List<Trackable> AllTrackables => _all_trackables; | |||||
| public HashSet<int> MatchedProtoIds => _matched_proto_ids; | |||||
| public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; | |||||
| public int RestoreUid => _restore_uid; | |||||
| public void new_restore_ops(IEnumerable<Operation> new_ops) | |||||
| { | |||||
| _restore_ops.AddRange(new_ops); | |||||
| // skip the callback. | |||||
| } | |||||
| public List<Operation> restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| public abstract class LoadStatus | |||||
| { | |||||
| public abstract void assert_consumed(); | |||||
| public abstract void assert_existing_objects_matched(); | |||||
| public abstract void assert_nontrivial_match(); | |||||
| public abstract void run_restore_ops(Session? session = null); | |||||
| public abstract void initialize_or_restore(Session? session = null); | |||||
| public virtual LoadStatus expect_partial() | |||||
| { | |||||
| return this; | |||||
| } | |||||
| } | |||||
| public class InitializationOnlyStatus: LoadStatus | |||||
| { | |||||
| private int _restore_uid; | |||||
| private ObjectGraphView _object_graph_view; | |||||
| private Trackable _root; | |||||
| public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid) | |||||
| { | |||||
| _restore_uid = restore_uid; | |||||
| _object_graph_view = object_graph_view; | |||||
| _root = object_graph_view.Root; | |||||
| } | |||||
| public override void assert_consumed() | |||||
| { | |||||
| throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||||
| } | |||||
| public override void assert_existing_objects_matched() | |||||
| { | |||||
| throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||||
| } | |||||
| public override void assert_nontrivial_match() | |||||
| { | |||||
| throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); | |||||
| } | |||||
| public override void run_restore_ops(Session? session = null) | |||||
| { | |||||
| throw new AssertionError("No checkpoint specified, so no restore ops are available " | |||||
| + "(save_path=None to Saver.restore)."); | |||||
| } | |||||
| public override void initialize_or_restore(Session? session = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| return; | |||||
| } | |||||
| if(session is null) | |||||
| { | |||||
| session = new Session(); | |||||
| } | |||||
| var trackable_objects = CheckPointUtils.list_objects(_object_graph_view); | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| } | |||||
| public class CheckpointLoadStatus | |||||
| { | |||||
| public CheckpointLoadStatus() | |||||
| { | |||||
| } | |||||
| } | } | ||||
| @@ -0,0 +1,80 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Checkpoint; | |||||
| internal class CheckpointPosition | |||||
| { | |||||
| private CheckpointRestoreCoordinator _checkpoint; | |||||
| private int _proto_id; | |||||
| private bool _skip_restore; | |||||
| public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id) | |||||
| { | |||||
| _checkpoint = checkpoint; | |||||
| _proto_id = proto_id; | |||||
| _skip_restore = false; | |||||
| } | |||||
| public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; | |||||
| public void restore(Trackable trackable) | |||||
| { | |||||
| using (ops.init_scope()) | |||||
| { | |||||
| if (bind_project(trackable)) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Set a checkpoint<->object correspondence. | |||||
| /// </summary> | |||||
| /// <param name="trackable"></param> | |||||
| /// <returns></returns> | |||||
| public bool bind_project(Trackable trackable) | |||||
| { | |||||
| _checkpoint.AllTrackables.Add(trackable); | |||||
| _checkpoint.MatchedProtoIds.Add(_proto_id); | |||||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||||
| { | |||||
| // skip the `logging.warning`. | |||||
| return false; | |||||
| } | |||||
| else | |||||
| { | |||||
| _checkpoint.ObjectByProtoId[_proto_id] = trackable; | |||||
| return true; | |||||
| } | |||||
| } | |||||
| public void gather_ops_or_named_saveables() | |||||
| { | |||||
| // skip the registered_saver | |||||
| } | |||||
| /// <summary> | |||||
| /// Restore the bound Trackable and dependencies (may be deferred). | |||||
| /// </summary> | |||||
| private void _restore_descendants() | |||||
| { | |||||
| Queue<(CheckpointPosition, Trackable)> visit_queue = new(); | |||||
| visit_queue.Enqueue((this, this.Trackable)); | |||||
| } | |||||
| private void _single_restore() | |||||
| { | |||||
| var trackable = this.Trackable; | |||||
| trackable._maybe_initialize_trackable(); | |||||
| if(_checkpoint.RestoreUid > trackable.UpdateUid) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,8 +16,10 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.IO | namespace Tensorflow.IO | ||||
| { | { | ||||
| @@ -63,5 +65,15 @@ namespace Tensorflow.IO | |||||
| dirs.AddRange(Directory.GetFiles(dir)); | dirs.AddRange(Directory.GetFiles(dir)); | ||||
| return dirs.ToArray(); | return dirs.ToArray(); | ||||
| } | } | ||||
| public string join(params string[] paths) | |||||
| { | |||||
| Debug.Assert(paths.Length >= 1); | |||||
| if (paths[0].Substring(1).Contains("://")) | |||||
| { | |||||
| throw new NotImplementedException("The combination of urls has not been implemented."); | |||||
| } | |||||
| return Path.Combine(paths); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Common | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
| { | { | ||||
| var axis = serializer.Deserialize(reader, typeof(long[])); | |||||
| var axis = serializer.Deserialize(reader, typeof(int[])); | |||||
| if (axis is null) | if (axis is null) | ||||
| { | { | ||||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | throw new ValueError("Cannot deserialize 'null' to `Axis`."); | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| namespace Tensorflow.ModelSaving | namespace Tensorflow.ModelSaving | ||||
| { | { | ||||
| @@ -17,8 +17,8 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
| using static Tensorflow.CppShapeInferenceResult.Types; | using static Tensorflow.CppShapeInferenceResult.Types; | ||||
| @@ -0,0 +1,23 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public record class LoadOptions | |||||
| { | |||||
| public bool allow_partial_checkpoint; | |||||
| public string experimental_io_device; | |||||
| public bool experimental_skip_checkpoint; | |||||
| public VariablePolicy experimental_variable_policy; | |||||
| public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null, | |||||
| bool experimental_skip_checkpoint = false, string experimental_variable_policy = null) | |||||
| { | |||||
| this.allow_partial_checkpoint = allow_partial_checkpoint; | |||||
| this.experimental_io_device = experimental_io_device; | |||||
| this.experimental_skip_checkpoint = experimental_skip_checkpoint; | |||||
| this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,4 +1,5 @@ | |||||
| using Tensorflow.Train; | |||||
| using System; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow; | namespace Tensorflow; | ||||
| @@ -14,4 +15,10 @@ public class RevivedTypes | |||||
| // TODO: complete the implementation. | // TODO: complete the implementation. | ||||
| return null; | return null; | ||||
| } | } | ||||
| public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto) | |||||
| { | |||||
| // TODO: complete the implementation. | |||||
| return null; | |||||
| } | |||||
| } | } | ||||
| @@ -2,7 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.ModelSaving | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Options for saving to SavedModel. | /// Options for saving to SavedModel. | ||||
| @@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving | |||||
| public bool save_variable_devices() | public bool save_variable_devices() | ||||
| { | { | ||||
| return this != VariablePolicy.None; | |||||
| return this != None; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static VariablePolicy from_obj(object obj) | public static VariablePolicy from_obj(object obj) | ||||
| { | { | ||||
| if (obj is null) return VariablePolicy.None; | |||||
| if (obj is null) return None; | |||||
| if (obj is VariablePolicy) return (VariablePolicy)obj; | if (obj is VariablePolicy) return (VariablePolicy)obj; | ||||
| var key = obj.ToString().ToLower(); | var key = obj.ToString().ToLower(); | ||||
| return key switch | return key switch | ||||
| { | { | ||||
| null => VariablePolicy.None, | |||||
| "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||||
| "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||||
| null => None, | |||||
| "save_variable_devices" => SAVE_VARIABLE_DEVICES, | |||||
| "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, | |||||
| _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | ||||
| }; | }; | ||||
| } | } | ||||
| @@ -5,7 +5,6 @@ using System.Linq; | |||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using pbc = global::Google.Protobuf.Collections; | using pbc = global::Google.Protobuf.Collections; | ||||
| @@ -0,0 +1,601 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using System.Net.Sockets; | |||||
| using System.Text; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using pbc = global::Google.Protobuf.Collections; | |||||
| using static Tensorflow.Binding; | |||||
| using System.Runtime.CompilerServices; | |||||
| using Tensorflow.Variables; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// Helper class to load an object-based SavedModel. | |||||
| /// </summary> | |||||
| public partial class Loader | |||||
| { | |||||
| private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def; | |||||
| private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes; | |||||
| private SavedObjectGraph _proto; | |||||
| private string _export_dir; | |||||
| private CheckpointOptions _checkpoint_options; | |||||
| private LoadOptions _save_options; | |||||
| private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters; | |||||
| private Dictionary<string, int>? _node_path_to_id; | |||||
| private List<int>? _filtered_nodes; | |||||
| private List<int> _ordered_node_ids; | |||||
| private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||||
| private List<Trackable> _nodes; | |||||
| private Dictionary<int, Action<object, object, object>> _node_setters; | |||||
| public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir, | |||||
| CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters) | |||||
| { | |||||
| var meta_graph = saved_model_proto.MetaGraphs[0]; | |||||
| _asset_file_def = meta_graph.AssetFileDef; | |||||
| _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); | |||||
| _proto = object_graph_proto; | |||||
| _export_dir = export_dir; | |||||
| // TODO: `this._concrete_functions` and `this._restored_concrete_functions` | |||||
| _checkpoint_options = ckpt_options; | |||||
| _save_options = save_options; | |||||
| // TODO: `this._pretty_printer` | |||||
| _node_filters = filters; | |||||
| _node_path_to_id = _convert_node_paths_to_ints(); | |||||
| _loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); | |||||
| foreach(var filter in filters) | |||||
| { | |||||
| _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value; | |||||
| } | |||||
| _filtered_nodes = _retrieve_all_filtered_nodes(); | |||||
| _ordered_node_ids = _generate_ordered_node_ids(); | |||||
| _load_all(); | |||||
| if (!save_options.experimental_skip_checkpoint) | |||||
| { | |||||
| // TODO: implement it. | |||||
| } | |||||
| foreach(var node in _nodes) | |||||
| { | |||||
| // skip the process of `CapturableResource`. | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Maps all string node paths in node_filters to the int node ids. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| private Dictionary<string, int>? _convert_node_paths_to_ints() | |||||
| { | |||||
| if( _node_filters is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| Dictionary<string, int> path_to_int = new(); | |||||
| foreach(var node_id in _node_filters.Keys) | |||||
| { | |||||
| int int_node_id; | |||||
| var node_path = node_id.Split('.'); | |||||
| if (node_path[0] != "root") | |||||
| { | |||||
| throw new ValueError($"When passing string identifiers to node_filters, the first name" + | |||||
| $" must be root. Received {node_path[0]}."); | |||||
| } | |||||
| int_node_id = 0; | |||||
| for(int i = 0; i < node_path.Length - 1; i++) | |||||
| { | |||||
| var name = node_path[i + 1]; | |||||
| int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1))); | |||||
| } | |||||
| path_to_int[node_id] = int_node_id; | |||||
| } | |||||
| return path_to_int; | |||||
| } | |||||
| private int _find_node_child(int node_id, string child_name, string path) | |||||
| { | |||||
| foreach(var refer in _proto.Nodes[node_id].Children) | |||||
| { | |||||
| if(refer.LocalName == child_name) | |||||
| { | |||||
| return refer.NodeId; | |||||
| } | |||||
| } | |||||
| throw new ValueError($"Unable to find node {path}."); | |||||
| } | |||||
| private List<int>? _retrieve_all_filtered_nodes() | |||||
| { | |||||
| if(_node_filters is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| HashSet<int> all_filtered_nodes = new(); | |||||
| Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys); | |||||
| while(nodes_to_visit.Count > 0) | |||||
| { | |||||
| var node_path = nodes_to_visit.Dequeue(); | |||||
| var node_id = _node_path_to_id[node_path]; | |||||
| if (all_filtered_nodes.Contains(node_id)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| all_filtered_nodes.Add(node_id); | |||||
| Trackable node = null; | |||||
| Action<object, object, object> setter = null; | |||||
| if(_loaded_nodes.TryGetValue(node_id, out var res)) | |||||
| { | |||||
| (node, setter) = res; | |||||
| } | |||||
| if(node is not null) | |||||
| { | |||||
| node._maybe_initialize_trackable(); | |||||
| } | |||||
| foreach(var refer in _proto.Nodes[node_id].Children) | |||||
| { | |||||
| Trackable children_object = null; | |||||
| if(_loaded_nodes.TryGetValue(refer.NodeId, out var result)) | |||||
| { | |||||
| children_object = result.Item1; | |||||
| } | |||||
| // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict. | |||||
| if(children_object is null && node is not null) | |||||
| { | |||||
| children_object = node._lookup_dependency(refer.LocalName); | |||||
| if(children_object is TrackableDataStructure) | |||||
| { | |||||
| // TODO: set setter as lambda. | |||||
| _loaded_nodes[refer.NodeId] = (children_object, setter); | |||||
| } | |||||
| } | |||||
| string child_path = $"{node_path}.{refer.LocalName}"; | |||||
| _node_path_to_id[child_path] = refer.NodeId; | |||||
| nodes_to_visit.Enqueue(child_path); | |||||
| } | |||||
| } | |||||
| if (all_filtered_nodes.Contains(0)) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| return all_filtered_nodes.ToList(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Orders the node ids so that dependencies appear first. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| private List<int> _generate_ordered_node_ids() | |||||
| { | |||||
| List<int> unordered_ids; | |||||
| if(_filtered_nodes is null) | |||||
| { | |||||
| unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList(); | |||||
| } | |||||
| else | |||||
| { | |||||
| unordered_ids = new List<int>(_filtered_nodes); | |||||
| } | |||||
| Dictionary<int, List<int>> dependency_map = new(); | |||||
| foreach(var node_id in unordered_ids) | |||||
| { | |||||
| var deps = dependency_map.SetDefault(node_id, new List<int>()); | |||||
| if (_loaded_nodes.ContainsKey(node_id)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| var proto = _proto.Nodes[node_id]; | |||||
| foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) | |||||
| { | |||||
| deps.Add(dep); | |||||
| if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) | |||||
| { | |||||
| // TODO: add info with `_pretty_printer`. | |||||
| throw new ValueError($"Unable to partially load SavedModel since the specified filter " + | |||||
| $"does not include all required objects for loading (e.g. " + | |||||
| $"variables used in functions or deserialization dependencies). " + | |||||
| $"Please include this path in the filter: {dep}"); | |||||
| } | |||||
| } | |||||
| int? prev_slot = null; | |||||
| foreach(var slot_variable_proto in proto.SlotVariables) | |||||
| { | |||||
| var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; | |||||
| // The optimizer and original variable must be created before the slot | |||||
| // variable, since the slot variable is generated using the Optimizer's | |||||
| // add_slot API. | |||||
| var slot_deps = dependency_map[slot_variable_node_id]; | |||||
| slot_deps.Add(node_id); | |||||
| slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); | |||||
| if(prev_slot is not null) | |||||
| { | |||||
| slot_deps.Add(prev_slot.Value); | |||||
| } | |||||
| prev_slot = slot_variable_node_id; | |||||
| } | |||||
| } | |||||
| try | |||||
| { | |||||
| return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>)); | |||||
| } | |||||
| catch (TrackableUtils.CyclicDependencyError ex) | |||||
| { | |||||
| throw new ValueError("Encountered a cycle in the deserialization dependencies" + | |||||
| "in the SavedModel. This is extremely unexpected, please" + | |||||
| "file a bug and make sure you are not manually modifying the SavedModel."); | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns a dictionary of all dependencies of an object. | |||||
| /// </summary> | |||||
| /// <param name="proto"></param> | |||||
| /// <returns></returns> | |||||
| private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto) | |||||
| { | |||||
| Dictionary<Maybe<string, int>, int> dependencies = new(); | |||||
| foreach(var refer in proto.Dependencies) | |||||
| { | |||||
| dependencies[refer.LocalName] = refer.NodeId; | |||||
| } | |||||
| if(proto.KindCase == SavedObject.KindOneofCase.Function) | |||||
| { | |||||
| var concreete_functions = proto.Function.ConcreteFunctions; | |||||
| foreach(var fn_name in concreete_functions) | |||||
| { | |||||
| foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) | |||||
| { | |||||
| dependencies[bound_input] = bound_input; | |||||
| } | |||||
| } | |||||
| } | |||||
| else if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction) | |||||
| { | |||||
| var fn_name = proto.BareConcreteFunction.ConcreteFunctionName; | |||||
| foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs) | |||||
| { | |||||
| dependencies[bound_input] = bound_input; | |||||
| } | |||||
| } | |||||
| else if(proto.KindCase == SavedObject.KindOneofCase.Resource) | |||||
| { | |||||
| foreach(var child in proto.Children) | |||||
| { | |||||
| if(child.LocalName == "_create_resource") | |||||
| { | |||||
| dependencies["_create_resource"] = child.NodeId; | |||||
| } | |||||
| } | |||||
| } | |||||
| return dependencies; | |||||
| } | |||||
| /// <summary> | |||||
| /// Loads all nodes and functions from the SavedModel and their edges. | |||||
| /// </summary> | |||||
| private void _load_all() | |||||
| { | |||||
| _load_nodes(); | |||||
| _load_edges(); | |||||
| _setup_remaining_functions(); | |||||
| _load_checkpoint_save_and_restore_functions(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Restores the checkpoint-related save/restore functions to all nodes. | |||||
| /// </summary> | |||||
| private void _load_checkpoint_save_and_restore_functions() | |||||
| { | |||||
| foreach(var (node_id, proto) in _iter_all_nodes()) | |||||
| { | |||||
| var node = get(node_id); | |||||
| if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||||
| { | |||||
| // Restore Trackable serialize- and restore-from-tensor functions. | |||||
| Debug.Assert(proto.SaveableObjects.Count == 1); | |||||
| var saveable_object_proto = proto.SaveableObjects.Values.First(); | |||||
| var save_fn_id = saveable_object_proto.SaveFunction; | |||||
| var restore_fn_id = saveable_object_proto.RestoreFunction; | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| else | |||||
| { | |||||
| // Restore legacy SaveableObject functions. | |||||
| Dictionary<string, (Trackable, Trackable)> saveable_fn_by_name = new(); | |||||
| foreach(var item in proto.SaveableObjects) | |||||
| { | |||||
| var name = item.Key; | |||||
| var saveable_object_proto = item.Value; | |||||
| var save_fn_id = saveable_object_proto.SaveFunction; | |||||
| var restore_fn_id = saveable_object_proto.RestoreFunction; | |||||
| saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); | |||||
| } | |||||
| node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); | |||||
| } | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Load all saved objects. | |||||
| /// </summary> | |||||
| private void _load_nodes() | |||||
| { | |||||
| // `nodes` maps from node ids to recreated objects | |||||
| // `node_setters` maps from node ids to setter functions | |||||
| // (same signature as setattr) for setting children. | |||||
| var (nodes, node_setters) = _initialize_loaded_nodes(); | |||||
| Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)> | |||||
| slot_variable_node_ids = new(); | |||||
| foreach(var (node_id, proto) in _iter_all_nodes()) | |||||
| { | |||||
| foreach(var slot_variable_proto in proto.SlotVariables) | |||||
| { | |||||
| var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId; | |||||
| slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto); | |||||
| } | |||||
| } | |||||
| // Re-create everything. | |||||
| foreach (var (node_id, proto) in _iter_all_nodes()) | |||||
| { | |||||
| if (nodes.ContainsKey(node_id)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| else if (slot_variable_node_ids.ContainsKey(node_id)) | |||||
| { | |||||
| // Use the public Optimizer interface when creating slot variables. | |||||
| var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; | |||||
| var optimizer_object = nodes[optimizer_node_id]; | |||||
| var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; | |||||
| // TODO: implement it. | |||||
| throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + | |||||
| " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var (node, setter) = _recreate(proto, node_id, nodes); | |||||
| nodes[node_id] = node; | |||||
| node_setters[node_id] = setter; | |||||
| } | |||||
| } | |||||
| if (!nodes.ContainsKey(0)) | |||||
| { | |||||
| nodes[0] = _recreate_base_user_object().Item1; | |||||
| } | |||||
| _nodes = new List<Trackable>(); | |||||
| for(int i = 0; i < _proto.Nodes.Count; i++) | |||||
| { | |||||
| _nodes.Add(nodes[i]); | |||||
| } | |||||
| _node_setters = node_setters; | |||||
| } | |||||
| /// <summary> | |||||
| /// Load state from checkpoint into the deserialized objects. | |||||
| /// </summary> | |||||
| private void _restore_checkpoint() | |||||
| { | |||||
| var variables_path = SavedModelUtils.get_variables_dir(_export_dir); | |||||
| var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||||
| tf.device("CPU"); | |||||
| saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||||
| if (_save_options.allow_partial_checkpoint) | |||||
| { | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Adds edges from objects to other objects and functions. | |||||
| /// </summary> | |||||
| private void _load_edges() | |||||
| { | |||||
| foreach(var (node_id, object_proto) in _iter_all_nodes()) | |||||
| { | |||||
| _add_object_graph_edges(object_proto, node_id); | |||||
| } | |||||
| if(_filtered_nodes is not null && _filtered_nodes.Contains(0)) | |||||
| { | |||||
| var root = get(0); | |||||
| foreach(var node_path in _node_filters.Keys) | |||||
| { | |||||
| var loaded_node = _nodes[_node_path_to_id[node_path]]; | |||||
| var path = node_path.Split('.'); | |||||
| var current_node = root; | |||||
| foreach(var name in path.Skip(1).Take(path.Length - 2)) | |||||
| { | |||||
| // `hasattr` and `setattr` is used here | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| // `hasattr` and `setattr` is used here | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| private void _setup_remaining_functions() | |||||
| { | |||||
| // TODO: implement it with concrete functions. | |||||
| } | |||||
| public Trackable get(int node_id) | |||||
| { | |||||
| return _nodes[node_id]; | |||||
| } | |||||
| public Trackable get(string node_id) | |||||
| { | |||||
| return get(_node_path_to_id[node_id]); | |||||
| } | |||||
| /// <summary> | |||||
| /// Adds edges from an object to its children. | |||||
| /// </summary> | |||||
| /// <param name="proto"></param> | |||||
| /// <param name="node_id"></param> | |||||
| private void _add_object_graph_edges(SavedObject proto, int node_id) | |||||
| { | |||||
| var obj = _nodes[node_id]; | |||||
| var setter = _node_setters[node_id]; | |||||
| foreach(var refer in proto.Children) | |||||
| { | |||||
| setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | |||||
| // skip the process of "__call__" | |||||
| } | |||||
| } | |||||
| private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes() | |||||
| { | |||||
| Dictionary<int, Trackable> nodes = new(); | |||||
| Dictionary<int, Action<object, object, object>> node_setters = new(); | |||||
| foreach(var item in _loaded_nodes) | |||||
| { | |||||
| var node_id = item.Key; | |||||
| var (node, setter) = item.Value; | |||||
| nodes[node_id] = node; | |||||
| node_setters[node_id] = setter; | |||||
| } | |||||
| return (nodes, node_setters); | |||||
| } | |||||
| private IEnumerable<(int, SavedObject)> _iter_all_nodes() | |||||
| { | |||||
| foreach(var node_id in _ordered_node_ids) | |||||
| { | |||||
| yield return (node_id, _proto.Nodes[node_id]); | |||||
| } | |||||
| } | |||||
| private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | |||||
| { | |||||
| // skip the registered classes. | |||||
| Dictionary<Maybe<string, int>, Trackable> dependencies = new(); | |||||
| foreach(var item in _get_node_dependencies(proto)) | |||||
| { | |||||
| dependencies[item.Key] = nodes[item.Value]; | |||||
| } | |||||
| return _recreate_default(proto, node_id, dependencies); | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates a Python object from a SavedObject protocol buffer. | |||||
| /// </summary> | |||||
| /// <param name="proto"></param> | |||||
| /// <param name="node_id"></param> | |||||
| /// <param name="dependencies"></param> | |||||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| { | |||||
| return proto.KindCase switch | |||||
| { | |||||
| SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), | |||||
| SavedObject.KindOneofCase.Function => throw new NotImplementedException(), | |||||
| SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | |||||
| SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | |||||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException() | |||||
| }; | |||||
| } | |||||
| private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id) | |||||
| { | |||||
| // skip the check of proto identifier because of lack of property. | |||||
| var looked_up = RevivedTypes.deserialize(proto); | |||||
| if(looked_up is null) | |||||
| { | |||||
| return _recreate_base_user_object(proto, node_id); | |||||
| } | |||||
| return (looked_up.Item1, looked_up.Item2); | |||||
| } | |||||
| private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) | |||||
| { | |||||
| return (new _UserObject(), setattr); | |||||
| } | |||||
| private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto) | |||||
| { | |||||
| string name = proto.Name; | |||||
| string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>"; | |||||
| // TODO(Rinne): `validate_synchronization_aggregation_trainable` | |||||
| var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable( | |||||
| proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name); | |||||
| var saved_device = proto.Device; | |||||
| var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device); | |||||
| if (load_with_device) | |||||
| { | |||||
| tf.device(saved_device); | |||||
| return (new UninitializedVariable( | |||||
| shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||||
| dtype: (TF_DataType)proto.Dtype, | |||||
| name: name, | |||||
| trainable: trainable, | |||||
| aggregation: aggregation | |||||
| ), setattr); | |||||
| } | |||||
| else | |||||
| { | |||||
| return (new UninitializedVariable( | |||||
| shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||||
| dtype: (TF_DataType)proto.Dtype, | |||||
| name: name, | |||||
| trainable: trainable, | |||||
| aggregation: aggregation | |||||
| ), setattr); | |||||
| } | |||||
| } | |||||
| // TODO: remove this to a common class. | |||||
| public static Action<object, object, object> setattr = (x, y, z) => | |||||
| { | |||||
| Debug.Assert(y is string); | |||||
| var properties = x.GetType().GetProperties(); | |||||
| foreach(var p in properties) | |||||
| { | |||||
| if((string)y == p.Name) | |||||
| { | |||||
| p.SetValue(x, z); | |||||
| return; | |||||
| } | |||||
| } | |||||
| // TODO(Rinne): check if the property has been set successfully. | |||||
| //throw new ValueError($"Cannot find the property {y} of {x}."); | |||||
| }; | |||||
| public class _UserObject: AutoTrackable | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,122 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Operations; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Loader | |||||
| { | |||||
| public static SavedModel parse_saved_model(string export_dir) | |||||
| { | |||||
| var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT); | |||||
| var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB); | |||||
| SavedModel saved_model = new SavedModel(); | |||||
| if (File.Exists(path_to_pb)) | |||||
| { | |||||
| byte[] file_content; | |||||
| using(var f = new FileStream(path_to_pb, FileMode.Open, FileAccess.Read)) | |||||
| { | |||||
| file_content = new byte[f.Length]; | |||||
| Debug.Assert(f.Length <= int.MaxValue); | |||||
| f.Read(file_content, 0, (int)f.Length); | |||||
| } | |||||
| // TODO: change to stream mode. | |||||
| saved_model.MergeFrom(file_content); | |||||
| return saved_model; | |||||
| } | |||||
| else if (File.Exists(path_to_pbtxt)) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.PathSeparator}" + | |||||
| $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}"); | |||||
| } | |||||
| } | |||||
| // TODO: revise the type of `tags` | |||||
| public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null) | |||||
| { | |||||
| return load_partial(export_dir, null, tags, options)["root"]; | |||||
| } | |||||
| public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null) | |||||
| { | |||||
| if (options is null) | |||||
| { | |||||
| options = new LoadOptions(); | |||||
| } | |||||
| if (tags is not null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir); | |||||
| Trackable root = null; | |||||
| Loader loader = null; | |||||
| if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null) | |||||
| { | |||||
| // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)` | |||||
| var meta_graph_def = saved_model_proto.MetaGraphs[0]; | |||||
| if (!BitConverter.IsLittleEndian) | |||||
| { | |||||
| SavedModelUtils.swap_function_tensor_content(meta_graph_def); | |||||
| } | |||||
| var object_graph_proto = meta_graph_def.ObjectGraphDef; | |||||
| var ckpt_options = new CheckpointOptions(options.experimental_io_device); | |||||
| tf_with(ops.init_scope(), x => | |||||
| { | |||||
| loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); | |||||
| root = loader.get(0); | |||||
| // skip the assignment of `graph_debug_info`. | |||||
| }); | |||||
| // skip the assignment of `tensorflow_version` | |||||
| // skip the assignment of `tensorflow_git_version` | |||||
| // skip the process of `metrics`. | |||||
| } | |||||
| else | |||||
| { | |||||
| if(filters is not null && filters.Count > 0) | |||||
| { | |||||
| throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any" | |||||
| + " version) cannot be loaded with node filters."); | |||||
| } | |||||
| tf_with(ops.init_scope(), x => | |||||
| { | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| }); | |||||
| } | |||||
| if(filters != null && filters.Count > 0) | |||||
| { | |||||
| return filters.Keys.ToDictionary(x => x, x => loader.get(x)); | |||||
| } | |||||
| else | |||||
| { | |||||
| var res = new Dictionary<string, Trackable>(); | |||||
| res["root"] = root; | |||||
| return res; | |||||
| } | |||||
| } | |||||
| public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir) | |||||
| { | |||||
| var saved_model = parse_saved_model(export_dir); | |||||
| // TODO: implement debug info. | |||||
| return (saved_model, null); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -6,7 +6,6 @@ using System.Text; | |||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -1,7 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.ModelSaving; | |||||
| namespace Tensorflow.Training.Saving.SavedModel | namespace Tensorflow.Training.Saving.SavedModel | ||||
| { | { | ||||
| @@ -333,6 +333,21 @@ namespace Tensorflow | |||||
| return restored_ops; | return restored_ops; | ||||
| }; | }; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Returns a dict of SaveableObject factories generated from loaded fns. | |||||
| /// </summary> | |||||
| /// <param name="saveable_fn_by_name"></param> | |||||
| /// <param name="temp_session"></param> | |||||
| public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> recreate_saveable_objects( | |||||
| IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) | |||||
| { | |||||
| if (saveable_fn_by_name.Count > 0) | |||||
| { | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||||
| } | |||||
| return new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||||
| } | |||||
| } | } | ||||
| public class SaveableCompatibilityConverter: Trackable | public class SaveableCompatibilityConverter: Trackable | ||||
| @@ -20,8 +20,8 @@ using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
| using Tensorflow.ModelSaving; | |||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
| @@ -71,6 +71,17 @@ namespace Tensorflow.Train | |||||
| public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | ||||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | ||||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | ||||
| public IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> SelfSaveableObjectFactories | |||||
| { | |||||
| get | |||||
| { | |||||
| return _self_saveable_object_factories; | |||||
| } | |||||
| set | |||||
| { | |||||
| _self_saveable_object_factories = value; | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | /// Restore-on-create for a variable be saved with this `Checkpointable`. | ||||
| @@ -259,4 +270,6 @@ namespace Tensorflow.Train | |||||
| } | } | ||||
| public record class TrackableReference(string Name, Trackable Refer); | public record class TrackableReference(string Name, Trackable Refer); | ||||
| public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); | |||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| @@ -5,9 +5,9 @@ using Tensorflow.Variables; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.ModelSaving; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -19,7 +19,11 @@ namespace Tensorflow | |||||
| protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
| public TF_DataType dtype => _dtype; | public TF_DataType dtype => _dtype; | ||||
| protected string _handle_name; | protected string _handle_name; | ||||
| protected string handle_name => _handle_name; | |||||
| public string handle_name | |||||
| { | |||||
| get { return _handle_name; } | |||||
| set { _handle_name = value; } | |||||
| } | |||||
| protected string _unique_id; | protected string _unique_id; | ||||
| public string UniqueId => _unique_id; | public string UniqueId => _unique_id; | ||||
| @@ -238,5 +238,23 @@ namespace Tensorflow | |||||
| { | { | ||||
| return _graph_element.eval(session); | return _graph_element.eval(session); | ||||
| } | } | ||||
| public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable( | |||||
| VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name) | |||||
| { | |||||
| if(aggregation is null) | |||||
| { | |||||
| aggregation = VariableAggregation.None; | |||||
| } | |||||
| if(synchronization is null) | |||||
| { | |||||
| synchronization = VariableSynchronization.Auto; | |||||
| } | |||||
| if (trainable is null) | |||||
| { | |||||
| trainable = synchronization != VariableSynchronization.OnRead; | |||||
| } | |||||
| return (synchronization.Value, aggregation.Value, trainable.Value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="config"></param> | /// <param name="config"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||||
| { | { | ||||
| // Layer instances created during the graph reconstruction process. | // Layer instances created during the graph reconstruction process. | ||||
| var created_layers = new Dictionary<string, ILayer>(); | var created_layers = new Dictionary<string, ILayer>(); | ||||
| @@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine | |||||
| Inputs = inputs, | Inputs = inputs, | ||||
| Outputs = outputs | Outputs = outputs | ||||
| }) | }) | ||||
| { | |||||
| Initialize(inputs, outputs, name); | |||||
| } | |||||
| internal void Initialize(Tensors inputs, Tensors outputs, string name = null) | |||||
| { | { | ||||
| _input_layers = new List<ILayer>(); | _input_layers = new List<ILayer>(); | ||||
| _output_layers = new List<ILayer>(); | _output_layers = new List<ILayer>(); | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| @@ -75,7 +76,17 @@ namespace Tensorflow.Keras.Engine | |||||
| public int Id => id; | public int Id => id; | ||||
| protected string name; | protected string name; | ||||
| protected string base_name; | protected string base_name; | ||||
| public string Name => name; | |||||
| public string Name | |||||
| { | |||||
| get | |||||
| { | |||||
| return name; | |||||
| } | |||||
| set | |||||
| { | |||||
| name = value; | |||||
| } | |||||
| } | |||||
| protected bool computePreviousMask; | protected bool computePreviousMask; | ||||
| protected List<Operation> updates; | protected List<Operation> updates; | ||||
| @@ -89,6 +100,8 @@ namespace Tensorflow.Keras.Engine | |||||
| List<INode> outboundNodes; | List<INode> outboundNodes; | ||||
| public List<INode> OutboundNodes => outboundNodes; | public List<INode> OutboundNodes => outboundNodes; | ||||
| public JObject SerializedAttributes { get; set; } | |||||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | ||||
| public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
| public Tensor[] input | public Tensor[] input | ||||
| @@ -117,6 +130,11 @@ namespace Tensorflow.Keras.Engine | |||||
| protected List<ILayer> _self_tracked_trackables; | protected List<ILayer> _self_tracked_trackables; | ||||
| public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
| { | |||||
| Initialize(args); | |||||
| } | |||||
| internal virtual void Initialize(LayerArgs args) | |||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| // A stateful layer is a layer whose updates are run during inference too, | // A stateful layer is a layer whose updates are run during inference too, | ||||
| @@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| using (SharedObjectSavingScope.Enter()) | using (SharedObjectSavingScope.Enter()) | ||||
| { | { | ||||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
| KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine | |||||
| IVariableV1 _predict_counter; | IVariableV1 _predict_counter; | ||||
| bool _base_model_initialized; | bool _base_model_initialized; | ||||
| bool stop_training; | bool stop_training; | ||||
| public bool IsGraphNetwork => _is_graph_network; | |||||
| public OptimizerV2 Optimizer | public OptimizerV2 Optimizer | ||||
| { | { | ||||
| @@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine | |||||
| _init_batch_counters(); | _init_batch_counters(); | ||||
| } | } | ||||
| internal override void Initialize(LayerArgs args) | |||||
| { | |||||
| _init_batch_counters(); | |||||
| base.Initialize(args); | |||||
| } | |||||
| void _configure_steps_per_execution(int steps_per_execution) | void _configure_steps_per_execution(int steps_per_execution) | ||||
| { | { | ||||
| _steps_per_execution = tf.Variable(steps_per_execution, | _steps_per_execution = tf.Variable(steps_per_execution, | ||||
| @@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine | |||||
| : base(args.Inputs, args.Outputs, name: args.Name) | : base(args.Inputs, args.Outputs, name: args.Name) | ||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| if (args.Layers == null) | |||||
| args.Layers = new List<ILayer>(); | |||||
| // SupportsMasking = true; | // SupportsMasking = true; | ||||
| _compute_output_and_mask_jointly = true; | _compute_output_and_mask_jointly = true; | ||||
| _auto_track_sub_layers = false; | _auto_track_sub_layers = false; | ||||
| @@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine | |||||
| _created_nodes = new List<INode>(); | _created_nodes = new List<INode>(); | ||||
| // Add to the model any layers passed to the constructor. | // Add to the model any layers passed to the constructor. | ||||
| if (args.Layers != null) | |||||
| if (args.Layers is not null) | |||||
| { | { | ||||
| foreach (var layer in args.Layers) | |||||
| add(layer); | |||||
| InitLayers(args.Layers); | |||||
| } | |||||
| } | |||||
| public void InitLayers(IEnumerable<ILayer> layers) | |||||
| { | |||||
| foreach(var layer in layers) | |||||
| { | |||||
| add(layer); | |||||
| } | } | ||||
| } | } | ||||
| @@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| { | { | ||||
| throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
| } | } | ||||
| _buildInputShape = input_shape; | |||||
| built = true; | |||||
| base.build(input_shape); | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| } | } | ||||
| public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
| { | { | ||||
| _buildInputShape = input_shape; | |||||
| built = true; | |||||
| base.build(input_shape); | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| { | { | ||||
| @@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| if ( alpha < 0f ) { | if ( alpha < 0f ) { | ||||
| throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
| } | } | ||||
| _buildInputShape = input_shape; | |||||
| built = true; | |||||
| base.build(input_shape); | |||||
| } | } | ||||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | ||||
| Tensor output = inputs; | Tensor output = inputs; | ||||
| @@ -1,12 +1,24 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.ComponentModel; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Reflection; | |||||
| using System.Text.RegularExpressions; | using System.Text.RegularExpressions; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| using Tensorflow.Keras.Layers.Rnn; | |||||
| using Tensorflow.Keras.Losses; | |||||
| using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
| using static Tensorflow.ApiDef.Types; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| public class KerasObjectLoader | public class KerasObjectLoader | ||||
| { | { | ||||
| SavedMetadata _metadata; | |||||
| SavedObjectGraph _proto; | |||||
| Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||||
| Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); | |||||
| List<int> _traversed_nodes_from_config = new List<int>(); | |||||
| private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||||
| private SavedMetadata _metadata; | |||||
| private SavedObjectGraph _proto; | |||||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||||
| private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>(); | |||||
| private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>(); | |||||
| private List<int> _traversed_nodes_from_config = new List<int>(); | |||||
| private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes; | |||||
| private List<int> _models_to_reconstruct; | |||||
| public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes; | |||||
| static KerasObjectLoader() | |||||
| { | |||||
| PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; | |||||
| } | |||||
| public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) | ||||
| { | { | ||||
| _metadata = metadata; | _metadata = metadata; | ||||
| _proto = object_graph_def; | _proto = object_graph_def; | ||||
| _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); | _metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath); | ||||
| _models_to_reconstruct = new List<int>(); | |||||
| loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -42,15 +66,251 @@ namespace Tensorflow.Keras.Saving | |||||
| continue; | continue; | ||||
| } | } | ||||
| _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||||
| } | |||||
| foreach(var node_metadata in metric_list) | |||||
| { | |||||
| try | |||||
| { | |||||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | |||||
| node_metadata.Metadata); | |||||
| } | |||||
| catch(ValueError e) | |||||
| { | |||||
| if (compile) | |||||
| { | |||||
| throw e; | |||||
| } | |||||
| // TODO: add logging.warning. | |||||
| } | |||||
| } | |||||
| } | |||||
| public string get_path(int node_id) | |||||
| { | |||||
| return _node_paths[node_id]; | |||||
| } | |||||
| /// <summary> | |||||
| /// Finish setting up Keras objects. | |||||
| /// | |||||
| /// This function is executed after all objects and functions have been created. | |||||
| /// Call functions and losses are attached to each layer, and once all layers | |||||
| /// have been fully set up, graph networks are initialized. | |||||
| /// | |||||
| /// Subclassed models that are revived from the SavedModel are treated like | |||||
| /// layers, and have their call/loss functions attached here. | |||||
| /// </summary> | |||||
| public void finalize_objects() | |||||
| { | |||||
| List<Layer> layers_revived_from_config = new(); | |||||
| List<Layer> layers_revived_from_saved_model = new(); | |||||
| foreach(var item in loaded_nodes) | |||||
| { | |||||
| var node_id = item.Key; | |||||
| var node = item.Value.Item1; | |||||
| if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| _unblock_model_reconstruction(node_id, node as Layer); | |||||
| if(node is InputLayer or Metric) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| // TODO: deal with `RevivedLayer` and `RevivedInputLayer`. | |||||
| layers_revived_from_config.Add(node as Layer); | |||||
| } | |||||
| _finalize_saved_model_layers(layers_revived_from_saved_model); | |||||
| _finalize_config_layers(layers_revived_from_config); | |||||
| _reconstruct_all_models(); | |||||
| } | |||||
| private void _reconstruct_all_models() | |||||
| { | |||||
| HashSet<int> all_initialized_models = new(); | |||||
| for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--) | |||||
| { | |||||
| int model_id = _models_to_reconstruct[i]; | |||||
| all_initialized_models.Add(model_id); | |||||
| var (model, layers) = model_layer_dependencies[model_id]; | |||||
| _reconstruct_model(model_id, model, layers.ToList()); | |||||
| _finalize_config_layers(new List<Layer>() { model }); | |||||
| } | |||||
| Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys)); | |||||
| } | |||||
| private void _reconstruct_model(int model_id, Model model, List<Layer> layers) | |||||
| { | |||||
| var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"]; | |||||
| if(model.input is not null && model.input.Length > 0) | |||||
| { | |||||
| } | |||||
| else if(model is Sequential s) | |||||
| { | |||||
| if(layers is null || layers.Count == 0 || layers[0] is not InputLayer) | |||||
| { | |||||
| if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer") | |||||
| { | |||||
| layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>())); | |||||
| } | |||||
| else if (config["layers"][0]["config"]["batch_input_shape"] is not null) | |||||
| { | |||||
| // TODO: implement it | |||||
| } | |||||
| } | |||||
| // `model.__init__(layers, config["name"])` | |||||
| s.InitLayers(layers); | |||||
| s.Name = config["name"].ToObject<string>(); | |||||
| if(s.input is null || s.input.Length == 0) | |||||
| { | |||||
| var first_layer = _get_child_layer_node_ids(model_id)[0]; | |||||
| var input_specs = _infer_inputs(first_layer); | |||||
| var input_shapes = _infer_inputs(first_layer, true); | |||||
| // `model._set_inputs(input_specs)` | |||||
| // skip the check of input_specs is Dictionary | |||||
| if (!s.Built) | |||||
| { | |||||
| s.build(input_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| // skip the parameter `created_layers`. | |||||
| var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>()); | |||||
| // skip the `model.__init__` | |||||
| (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>()); | |||||
| (model as Functional).connect_ancillary_layers(created_layers); | |||||
| } | } | ||||
| _set_network_attributes_from_metadata(model); | |||||
| _unblock_model_reconstruction(model_id, model); | |||||
| } | } | ||||
| void _load_layer(int node_id, string identifier, string metadata_json) | |||||
| private void _set_network_attributes_from_metadata(Model revived_object) | |||||
| { | |||||
| // TODO: implement it. | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs the final steps of loading Keras Layers from config. | |||||
| /// </summary> | |||||
| /// <param name="layers"></param> | |||||
| private void _finalize_config_layers(List<Layer> layers) | |||||
| { | |||||
| foreach(var layer in layers) | |||||
| { | |||||
| if (_is_graph_network(layer)) | |||||
| { | |||||
| _restore_layer_unconditional_losses(layer); | |||||
| } | |||||
| _restore_layer_activation_loss(layer); | |||||
| _restore_layer_metrics(layer); | |||||
| // TODO(Rinne): deal with RNN. | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Runs the final steps of loading Keras Layers from SavedModel. | |||||
| /// </summary> | |||||
| /// <param name="layers"></param> | |||||
| private void _finalize_saved_model_layers(List<Layer> layers) | |||||
| { | |||||
| foreach(var layer in layers) | |||||
| { | |||||
| // TODO(Rinne): deal with `RevivedNetwork`. | |||||
| _restore_layer_unconditional_losses(layer); | |||||
| _restore_layer_activation_loss(layer); | |||||
| _restore_layer_metrics(layer); | |||||
| } | |||||
| } | |||||
| private void _restore_layer_unconditional_losses(Layer layer) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| } | |||||
| private void _restore_layer_activation_loss(Layer layer) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| } | |||||
| private void _restore_layer_metrics(Layer layer) | |||||
| { | |||||
| // TODO(Rinne): implement it. | |||||
| } | |||||
| /// <summary> | |||||
| /// Removes layer from blocking model reconstruction. | |||||
| /// </summary> | |||||
| /// <param name="layer_id"></param> | |||||
| /// <param name="layer"></param> | |||||
| private void _unblock_model_reconstruction(int layer_id, Layer layer) | |||||
| { | |||||
| foreach(var depencency in model_layer_ids_dependencies) | |||||
| { | |||||
| var layer_ids = depencency.Value.Item2; | |||||
| var layers = model_layer_dependencies.SetDefault(depencency.Key, | |||||
| (depencency.Value.Item1, new Layer[depencency.Value.Item2.Length])).Item2; | |||||
| if (!layer_ids.Contains(layer_id)) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| layers[Array.IndexOf(layer_ids, layer_id)] = layer; | |||||
| if (layers.All(x => x is not null)) | |||||
| { | |||||
| _models_to_reconstruct.Add(depencency.Key); | |||||
| } | |||||
| } | |||||
| } | |||||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | |||||
| { | { | ||||
| metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | ||||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | ||||
| _revive_from_config(identifier, metadata, node_id); | |||||
| if (loaded_nodes.ContainsKey(node_id)) | |||||
| { | |||||
| var (node, setter) = loaded_nodes[node_id]; | |||||
| _maybe_add_serialized_attributes(node as Layer, metadata); | |||||
| var config = metadata.Config; | |||||
| if(_is_graph_network(node as Layer) && generic_utils.validate_config(config)) | |||||
| { | |||||
| Debug.Assert(node is Model); | |||||
| var child_nodes = _get_child_layer_node_ids(node_id); | |||||
| model_layer_ids_dependencies[node_id] = (node as Model, child_nodes); | |||||
| if(child_nodes is null || child_nodes.Length == 0) | |||||
| { | |||||
| _models_to_reconstruct.Add(node_id); | |||||
| } | |||||
| } | |||||
| return (node, setter); | |||||
| } | |||||
| else | |||||
| { | |||||
| var (obj, setter) = _revive_from_config(identifier, metadata, node_id); | |||||
| if (obj is null) | |||||
| { | |||||
| (obj, setter) = _revive_custom_object(identifier, metadata); | |||||
| } | |||||
| Debug.Assert(obj is Layer); | |||||
| _maybe_add_serialized_attributes(obj as Layer, metadata); | |||||
| return (obj, setter); | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -59,11 +319,32 @@ namespace Tensorflow.Keras.Saving | |||||
| /// <param name="identifier"></param> | /// <param name="identifier"></param> | ||||
| /// <param name="metadata"></param> | /// <param name="metadata"></param> | ||||
| /// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
| void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||||
| private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||||
| { | { | ||||
| var obj = _revive_graph_network(identifier, metadata, node_id); | |||||
| obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||||
| Trackable obj; | |||||
| if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | |||||
| { | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| } | |||||
| else | |||||
| { | |||||
| obj = _revive_graph_network(identifier, metadata, node_id); | |||||
| obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||||
| } | |||||
| if(obj is null) | |||||
| { | |||||
| return (null, null); | |||||
| } | |||||
| var setter = _config_node_setter(_revive_setter); | |||||
| _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); | _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); | ||||
| return (obj, setter); | |||||
| } | |||||
| private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | |||||
| { | |||||
| // TODO: implement it. | |||||
| throw new NotImplementedException(); | |||||
| } | } | ||||
| Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | ||||
| @@ -71,6 +352,12 @@ namespace Tensorflow.Keras.Saving | |||||
| var config = metadata.Config; | var config = metadata.Config; | ||||
| var class_name = metadata.ClassName; | var class_name = metadata.ClassName; | ||||
| Model model = null; | Model model = null; | ||||
| if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") | |||||
| { | |||||
| return null; | |||||
| } | |||||
| if (class_name == "Sequential") | if (class_name == "Sequential") | ||||
| { | { | ||||
| model = new Sequential(new SequentialArgs | model = new Sequential(new SequentialArgs | ||||
| @@ -78,34 +365,83 @@ namespace Tensorflow.Keras.Saving | |||||
| Name = config.GetValue("name").ToString() | Name = config.GetValue("name").ToString() | ||||
| }); | }); | ||||
| } | } | ||||
| else if (class_name == "Functional") | |||||
| else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | |||||
| { | { | ||||
| throw new NotImplementedException(""); | |||||
| model = model = new Sequential(new SequentialArgs | |||||
| { | |||||
| Name = class_name | |||||
| }); | |||||
| } | |||||
| else | |||||
| { | |||||
| // TODO: implement it. | |||||
| throw new NotImplementedException("Not implemented"); | |||||
| } | } | ||||
| if (!metadata.IsGraphNetwork) | |||||
| return null; | |||||
| // Record this model and its layers. This will later be used to reconstruct | // Record this model and its layers. This will later be used to reconstruct | ||||
| // the model. | // the model. | ||||
| var layers = _get_child_layer_node_ids(node_id); | var layers = _get_child_layer_node_ids(node_id); | ||||
| model_layer_dependencies[node_id] = (model, layers); | |||||
| model_layer_ids_dependencies[node_id] = (model, layers); | |||||
| if(layers is null || layers.Length == 0) | |||||
| { | |||||
| _models_to_reconstruct.Add(node_id); | |||||
| } | |||||
| return model; | return model; | ||||
| } | } | ||||
| Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||||
| Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||||
| { | { | ||||
| var config = metadata.Config; | var config = metadata.Config; | ||||
| var class_name = metadata.ClassName; | var class_name = metadata.ClassName; | ||||
| var shared_object_id = metadata.SharedObjectId; | var shared_object_id = metadata.SharedObjectId; | ||||
| var must_restore_from_config = metadata.MustRestoreFromConfig; | var must_restore_from_config = metadata.MustRestoreFromConfig; | ||||
| var obj = class_name switch | |||||
| { | |||||
| "Resizing" => Resizing.from_config(config), | |||||
| _ => throw new NotImplementedException("") | |||||
| }; | |||||
| var obj = generic_utils.deserialize_keras_object(class_name, config); | |||||
| obj.Name = metadata.Name; | |||||
| // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | |||||
| var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | ||||
| return null; | |||||
| if (!built) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| return obj; | |||||
| } | |||||
| private void _revive_setter(object layer, object name, object value) | |||||
| { | |||||
| Debug.Assert(name is string); | |||||
| Debug.Assert(layer is Layer); | |||||
| if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||||
| { | |||||
| if(value is Trackable) | |||||
| { | |||||
| (layer as Layer)._track_trackable(value as Trackable, name as string); | |||||
| } | |||||
| if((layer as Layer).SerializedAttributes is null) | |||||
| { | |||||
| (layer as Layer).SerializedAttributes = new JObject(); | |||||
| } | |||||
| (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); | |||||
| } | |||||
| else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||||
| { | |||||
| (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); | |||||
| } | |||||
| else | |||||
| { | |||||
| var properties = layer.GetType().GetProperties(); | |||||
| foreach(var p in properties) | |||||
| { | |||||
| if(p.Name == name as string && p.GetValue(layer) is not null) | |||||
| { | |||||
| return; | |||||
| } | |||||
| } | |||||
| Loader.setattr(layer, name, value); | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -143,34 +479,186 @@ namespace Tensorflow.Keras.Saving | |||||
| /// <param name="obj"></param> | /// <param name="obj"></param> | ||||
| /// <param name="proto"></param> | /// <param name="proto"></param> | ||||
| /// <param name="node_id"></param> | /// <param name="node_id"></param> | ||||
| void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) | |||||
| void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) | |||||
| { | { | ||||
| if (_traversed_nodes_from_config.Contains(node_id)) | if (_traversed_nodes_from_config.Contains(node_id)) | ||||
| return; | return; | ||||
| var parent_path = _node_paths[node_id]; | var parent_path = _node_paths[node_id]; | ||||
| _traversed_nodes_from_config.Add(node_id); | _traversed_nodes_from_config.Add(node_id); | ||||
| if (!obj.Built) | |||||
| obj._maybe_initialize_trackable(); | |||||
| if(obj is Layer layer && !layer.Built) | |||||
| { | { | ||||
| var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||||
| _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata); | |||||
| _try_build_layer(layer, node_id, metadata.BuildInputShape); | |||||
| } | |||||
| List<(Trackable, int, string)> children = new(); | |||||
| foreach(var refer in proto.Children) | |||||
| { | |||||
| var obj_child = obj._lookup_dependency(refer.LocalName); | |||||
| children.Add((obj_child, refer.NodeId, refer.LocalName)); | |||||
| } | |||||
| var metric_list_node_id = _search_for_child_node(node_id, new string[] { | |||||
| Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||||
| }); | |||||
| if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | |||||
| { | |||||
| var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); | |||||
| foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) | |||||
| { | |||||
| if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) | |||||
| { | |||||
| var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; | |||||
| children.Add((metric, refer.NodeId, metric_path)); | |||||
| } | |||||
| } | |||||
| } | |||||
| foreach(var (obj_child, child_id, child_name) in children) | |||||
| { | |||||
| if(obj_child is null) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| var child_proto = _proto.Nodes[child_id]; | |||||
| // skip the check for registered identifier | |||||
| Action<object, object, object> setter; | |||||
| if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||||
| { | |||||
| setter = _revive_setter; | |||||
| } | |||||
| else | |||||
| { | |||||
| setter = Loader.setattr; | |||||
| } | |||||
| if (loaded_nodes.ContainsKey(child_id)) | |||||
| { | |||||
| // skip the logging.warning | |||||
| continue; | |||||
| } | |||||
| if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) | |||||
| { | |||||
| (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; | |||||
| } | |||||
| if(obj_child is TrackableDataStructure) | |||||
| { | |||||
| setter = (x, y, z) => { }; | |||||
| } | |||||
| var child_path = $"{parent_path}.{child_name}"; | |||||
| _node_paths[child_id] = child_path; | |||||
| _add_children_recreated_from_config(obj_child, child_proto, child_id); | |||||
| loaded_nodes[child_id] = (obj_child, setter); | |||||
| } | } | ||||
| } | } | ||||
| bool _try_build_layer(Model obj, int node_id, Shape build_input_shape) | |||||
| private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||||
| { | { | ||||
| if (obj.Built) | if (obj.Built) | ||||
| return true; | return true; | ||||
| if(build_input_shape is null) | |||||
| { | |||||
| build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||||
| } | |||||
| if(build_input_shape is not null) | |||||
| { | |||||
| obj.build(build_input_shape); | |||||
| // In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. | |||||
| // On the one hand, C# does not support call a method from specified parent class. | |||||
| // On the other hand, currently All class derived from Layer call `Layer.Build` or | |||||
| // move the implementation of `Layer.build` to its own `build` method. | |||||
| // Therefore we do not call it here. | |||||
| // However, it's still quite risky once in the future a certain class derived from | |||||
| // `Layer` does not call `Layer.build`. | |||||
| return true; | |||||
| } | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||||
| /// <summary> | |||||
| /// Infers input shape of layer from SavedModel functions. | |||||
| /// </summary> | |||||
| /// <param name="layer_node_id"></param> | |||||
| /// <param name="convert_to_shapes"></param> | |||||
| /// <returns></returns> | |||||
| private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||||
| { | { | ||||
| if (obj.Built) | |||||
| return true; | |||||
| var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | |||||
| if(call_fn_id is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; | |||||
| if(concrete_functions is null) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| var call_fn_name = concrete_functions[0]; | |||||
| var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| } | |||||
| private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||||
| { | |||||
| if(path_to_child is null || path_to_child.Count() == 0) | |||||
| { | |||||
| return parent_id; | |||||
| } | |||||
| foreach(var child in _proto.Nodes[parent_id].Children) | |||||
| { | |||||
| if(child.LocalName == path_to_child.First()) | |||||
| { | |||||
| return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); | |||||
| } | |||||
| } | |||||
| return null; | |||||
| } | |||||
| private bool _is_graph_network(Layer layer) | |||||
| { | |||||
| // TODO: deal with `RevivedLayer` | |||||
| if(layer is Functional) | |||||
| { | |||||
| return (layer as Functional).IsGraphNetwork || layer is Sequential; | |||||
| } | |||||
| return false; | return false; | ||||
| } | } | ||||
| private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) | |||||
| { | |||||
| // TODO: deal with `RevivedLayer` | |||||
| } | |||||
| /// <summary> | |||||
| /// Creates edges for nodes that are recreated from config. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| private Action<object, object, object> _config_node_setter(Action<object, object, object> setter) | |||||
| { | |||||
| void setattr_wrapper(object obj, object name, object value) | |||||
| { | |||||
| Debug.Assert(obj is Trackable); | |||||
| Debug.Assert(name is string); | |||||
| if((obj as Trackable)._lookup_dependency(name as string) is null) | |||||
| { | |||||
| setter(obj, name, value); | |||||
| } | |||||
| } | |||||
| return setattr_wrapper; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public partial class KerasSavedModelUtils | public partial class KerasSavedModelUtils | ||||
| { | { | ||||
| public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||||
| public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||||
| SaveOptions? options, bool save_traces = true) | SaveOptions? options, bool save_traces = true) | ||||
| { | { | ||||
| if (!overwrite && File.Exists(filepath)) | if (!overwrite && File.Exists(filepath)) | ||||
| @@ -0,0 +1,96 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.IO; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Train; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| public class KerasLoadModelUtils | |||||
| { | |||||
| /// <summary> | |||||
| /// Corresponding to keras/saving/save.py/load_model | |||||
| /// </summary> | |||||
| /// <param name="filepath"></param> | |||||
| /// <param name="custom_objects"></param> | |||||
| /// <param name="compile"></param> | |||||
| /// <param name="options"></param> | |||||
| /// <returns></returns> | |||||
| public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||||
| bool compile = true, LoadOptions? options = null) | |||||
| { | |||||
| using (SharedObjectSavingScope.Enter()) | |||||
| { | |||||
| using (LoadContext.load_context(options)) | |||||
| { | |||||
| if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||||
| { | |||||
| throw new IOException($"No file or directory found at {filepath}."); | |||||
| } | |||||
| if (Directory.Exists(filepath)) | |||||
| { | |||||
| return load(filepath, compile, options); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| public static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||||
| { | |||||
| SavedMetadata metadata = new SavedMetadata(); | |||||
| var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||||
| var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
| string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||||
| if (File.Exists(path_to_metadata_pb)) | |||||
| { | |||||
| metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| } | |||||
| if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||||
| { | |||||
| return Loader.load(path, options: options) as Model; | |||||
| } | |||||
| var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
| keras_loader.load_layers(compile: compile); | |||||
| Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||||
| nodes_to_load["root"] = (null, null); | |||||
| foreach(var item in keras_loader.LoadedNodes) | |||||
| { | |||||
| nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||||
| } | |||||
| var loaded = Loader.load_partial(path, nodes_to_load, options); | |||||
| keras_loader.finalize_objects(); | |||||
| // keras_loader.del_tracking(); | |||||
| var model = loaded["root"]; | |||||
| if(model is Model && compile) | |||||
| { | |||||
| // TODO: implement it. | |||||
| } | |||||
| if (!tf.Context.executing_eagerly()) | |||||
| { | |||||
| // TODO: implement it. | |||||
| } | |||||
| return model; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,69 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using System.Threading; | |||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| // TODO: remove this class to common project. | |||||
| public class ContextHandler: IDisposable | |||||
| { | |||||
| public Action<bool> DisposeCallBack { get; set; } | |||||
| public void Dispose() | |||||
| { | |||||
| DisposeCallBack.Invoke(true); | |||||
| } | |||||
| } | |||||
| public class LoadContext | |||||
| { | |||||
| private bool _entered_load_context; | |||||
| private LoadOptions? _load_options; | |||||
| private static ThreadLocal<LoadContext> _load_context = new(); | |||||
| private LoadContext() | |||||
| { | |||||
| _entered_load_context = false; | |||||
| _load_options = null; | |||||
| } | |||||
| public void set_load_options(LoadOptions load_options) | |||||
| { | |||||
| _load_options = load_options; | |||||
| _entered_load_context = true; | |||||
| } | |||||
| private void clear_load_options() | |||||
| { | |||||
| _load_options = null; | |||||
| _entered_load_context = false; | |||||
| } | |||||
| private LoadOptions? load_options() | |||||
| { | |||||
| return _load_options; | |||||
| } | |||||
| public static ContextHandler load_context(LoadOptions? load_options) | |||||
| { | |||||
| if(_load_context.Value is null) | |||||
| { | |||||
| _load_context.Value = new LoadContext(); | |||||
| } | |||||
| _load_context.Value.set_load_options(load_options); | |||||
| return new ContextHandler() | |||||
| { | |||||
| DisposeCallBack = _ => _load_context.Value.clear_load_options() | |||||
| }; | |||||
| } | |||||
| public static LoadOptions? get_load_option() | |||||
| { | |||||
| return _load_context.Value.load_options(); | |||||
| } | |||||
| public static bool in_load_context() | |||||
| { | |||||
| return _load_context.Value._entered_load_context; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -22,12 +22,16 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
| { | { | ||||
| public class generic_utils | public class generic_utils | ||||
| { | { | ||||
| private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; | |||||
| /// <summary> | /// <summary> | ||||
| /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -51,6 +55,21 @@ namespace Tensorflow.Keras.Utils | |||||
| return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | ||||
| } | } | ||||
| public static Layer deserialize_keras_object(string class_name, JObject config) | |||||
| { | |||||
| return class_name switch | |||||
| { | |||||
| "Sequential" => new Sequential(config.ToObject<SequentialArgs>()), | |||||
| "InputLayer" => new InputLayer(config.ToObject<InputLayerArgs>()), | |||||
| "Flatten" => new Flatten(config.ToObject<FlattenArgs>()), | |||||
| "ELU" => new ELU(config.ToObject<ELUArgs>()), | |||||
| "Dense" => new Dense(config.ToObject<DenseArgs>()), | |||||
| "Softmax" => new Softmax(config.ToObject<SoftmaxArgs>()), | |||||
| _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + | |||||
| $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") | |||||
| }; | |||||
| } | |||||
| public static string to_snake_case(string name) | public static string to_snake_case(string name) | ||||
| { | { | ||||
| return string.Concat(name.Select((x, i) => | return string.Concat(name.Select((x, i) => | ||||
| @@ -60,5 +79,15 @@ namespace Tensorflow.Keras.Utils | |||||
| x.ToString(); | x.ToString(); | ||||
| })).ToLower(); | })).ToLower(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Determines whether config appears to be a valid layer config. | |||||
| /// </summary> | |||||
| /// <param name="config"></param> | |||||
| /// <returns></returns> | |||||
| public static bool validate_config(JObject config) | |||||
| { | |||||
| return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,45 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| using Tensorflow.Keras.Losses; | |||||
| using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.Optimizers; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||||
| [TestClass] | |||||
| public class SequentialModelLoad | |||||
| { | |||||
| [TestMethod] | |||||
| public void SimpleModelFromSequential() | |||||
| { | |||||
| var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); | |||||
| Debug.Assert(model is Model); | |||||
| var m = model as Model; | |||||
| m.summary(); | |||||
| m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
| var data_loader = new MnistModelLoader(); | |||||
| var num_epochs = 1; | |||||
| var batch_size = 50; | |||||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
| { | |||||
| TrainDir = "mnist", | |||||
| OneHot = false, | |||||
| ValidationSize = 50000, | |||||
| }).Result; | |||||
| m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
| } | |||||
| } | |||||
| @@ -63,6 +63,8 @@ public class SequentialModelTest | |||||
| keras.layers.Softmax(1) | keras.layers.Softmax(1) | ||||
| }); | }); | ||||
| model.summary(); | |||||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | ||||
| var data_loader = new MnistModelLoader(); | var data_loader = new MnistModelLoader(); | ||||