* Add CheckpointReader and corresponding C APIs. * Add essential components of SavedModel format loading. * Add checkpoint reading for SavedModel format loading. * Revise customized json converters. * Add support for loading models from python. * Fix the duplicated weights in Keras.Model. * Add alexnet loading test and check for loaded weights. * Fix ci error caused by branch merge. * Resolve the comments and errors. * Fix the stucking of training when loading model. * Fix the stucking of training when loading model. * fix intptr. --------- Co-authored-by: Haiping Chen <haiping008@gmail.com>tags/v0.100.4-load-saved-model
| @@ -149,4 +149,22 @@ public static class CheckPointUtils | |||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | |||
| // } | |||
| } | |||
    /// <summary>
    /// Traverse the object graph and list all accessible objects.
    /// </summary>
    /// <param name="graph_view">The graph view whose root the traversal starts from.</param>
    /// <returns>All trackable objects reachable from the root of <paramref name="graph_view"/>.</returns>
    public static IList<Trackable> list_objects(ObjectGraphView graph_view)
    {
        return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
    }
| internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list) | |||
| { | |||
| return full_list.TakeWhile(x => | |||
| { | |||
| var saveables = x.gather_saveables_for_checkpoint(); | |||
| return saveables is not null && saveables.Count > 0; | |||
| }); | |||
| } | |||
| } | |||
| @@ -0,0 +1,100 @@ | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Checkpoint | |||
| { | |||
    /// <summary>
    /// Safe handle wrapping the native TF checkpoint reader; the native object is
    /// released via <c>TF_DeleteCheckpointReader</c> when the handle is disposed or finalized.
    /// </summary>
    sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
    {
        public SafeCheckpointReaderHandle(): base()
        {
        }

        public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
        {
        }

        protected override bool ReleaseHandle()
        {
            // Free the native reader, then zero the raw handle so a double release is harmless.
            c_api.TF_DeleteCheckpointReader(handle);
            SetHandle(IntPtr.Zero);
            return true;
        }
    }
    /// <summary>
    /// Managed wrapper over the native TF checkpoint reader. On construction it eagerly
    /// caches the dtype and shape of every variable stored in the checkpoint.
    /// </summary>
    public class CheckpointReader
    {
        private SafeCheckpointReaderHandle _handle;
        // Variable name -> dtype; populated once by ReadAllShapeAndType in the constructor.
        public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
        // Variable name -> shape; populated once by ReadAllShapeAndType in the constructor.
        public Dictionary<string, Shape> VariableToShapeMap { get; set; }

        /// <summary>
        /// Opens the checkpoint at <paramref name="filename"/> and caches all variable metadata.
        /// </summary>
        /// <param name="filename">Checkpoint file prefix/path understood by the native reader.</param>
        public CheckpointReader(string filename)
        {
            Status status = new Status();
            _handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
            // Throws if the native reader could not be created (e.g. missing file).
            status.Check(true);
            ReadAllShapeAndType();
        }

        /// <summary>
        /// Returns non-zero when the checkpoint contains a tensor named <paramref name="name"/>.
        /// </summary>
        public int HasTensor(string name)
        {
            return c_api.TF_CheckpointReaderHasTensor(_handle, name);
        }

        /// <summary>
        /// Get the variable name.
        /// </summary>
        /// <param name="index">Index of the variable inside the checkpoint (0 .. Size()-1).</param>
        /// <returns></returns>
        public string GetVariable(int index)
        {
            return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
        }

        // Number of variables stored in the checkpoint.
        public int Size()
        {
            return c_api.TF_CheckpointReaderSize(_handle);
        }

        public TF_DataType GetVariableDataType(string name)
        {
            return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
        }

        // Queries the rank first, then fills a dims buffer via the native shape call.
        public Shape GetVariableShape(string name)
        {
            int num_dims = GetVariableNumDims(name);
            long[] dims = new long[num_dims];
            Status status = new Status();
            c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
            status.Check(true);
            return new Shape(dims);
        }

        public int GetVariableNumDims(string name)
        {
            return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
        }

        // NOTE(review): the `dtype` argument is not passed to the native call and is
        // currently ignored — confirm whether it is intended only as caller documentation.
        public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
        {
            Status status = new Status();
            var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
            status.Check(true);
            return new Tensor(tensor);
        }

        // Walks every variable once and fills both metadata maps.
        private void ReadAllShapeAndType()
        {
            VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
            VariableToShapeMap = new Dictionary<string, Shape>();
            int size = Size();
            for(int i = 0; i < size; i++)
            {
                var name = GetVariable(i);
                var shape = GetVariableShape(name);
                var dtype = GetVariableDataType(name);
                VariableToDataTypeMap[name] = dtype;
                VariableToShapeMap[name] = shape;
            }
        }
    }
| } | |||
| @@ -175,9 +175,9 @@ public static class SaveUtilV1 | |||
| { | |||
| var name = factory_data.name; | |||
| var key = factory_data.checkpoint_key; | |||
| var maybe_saveable = factory_data.factory; | |||
| var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory); | |||
| // TODO: oneflow python has a process with callable `saveable_factory`. | |||
| // TODO: tensorflow python has a process with callable `saveable_factory`. | |||
| List<MySaveableObject> saveables = new(); | |||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
| { | |||
| @@ -217,7 +217,7 @@ public static class SaveUtilV1 | |||
/// <summary>
/// Bundles a saveable factory with the attribute name and checkpoint key it is saved under.
/// </summary>
/// <param name="factory">Factory that, given a checkpoint key, creates the saveable object.</param>
/// <param name="name">The attribute name of the saveable.</param>
/// <param name="checkpoint_key">The full key of this saveable inside the checkpoint.</param>
// BUGFIX: the source contained two conflicting `factory` parameter declarations (merge/diff
// residue). Keep the Func<string, ...> form, which matches how callers invoke the factory.
public record class CheckpointFactoryData
(
    Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
    string name,
    string checkpoint_key
);
| @@ -0,0 +1,27 @@ | |||
| using System.Runtime.InteropServices; | |||
| using Tensorflow.Checkpoint; | |||
| namespace Tensorflow | |||
| { | |||
    /// <summary>
    /// P/Invoke bindings for the native TensorFlow checkpoint-reader C API.
    /// </summary>
    public unsafe partial class c_api
    {
        // Creates a native checkpoint reader for the checkpoint at `filename`;
        // errors are reported through `status`.
        [DllImport(TensorFlowLibName)]
        internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);

        // Takes a raw IntPtr (not the safe handle) because it is invoked from
        // SafeCheckpointReaderHandle.ReleaseHandle during handle teardown.
        [DllImport(TensorFlowLibName)]
        internal static extern void TF_DeleteCheckpointReader(IntPtr reader);

        // Returns non-zero when the checkpoint contains a tensor named `name`.
        [DllImport(TensorFlowLibName)]
        internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);

        // Returns a pointer to the variable name at `index`; read it with c_api.StringPiece.
        [DllImport(TensorFlowLibName)]
        internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);

        // Number of variables stored in the checkpoint.
        [DllImport(TensorFlowLibName)]
        internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);

        [DllImport(TensorFlowLibName)]
        internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);

        // Fills `dims` (length `num_dims`) with the variable's shape; caller sizes the
        // buffer via TF_CheckpointReaderGetVariableNumDims first.
        [DllImport(TensorFlowLibName)]
        internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);

        [DllImport(TensorFlowLibName)]
        internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);

        [DllImport(TensorFlowLibName)]
        internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
    }
| } | |||
| @@ -6,8 +6,12 @@ using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Operations; | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow.Checkpoint; | |||
| @@ -21,8 +25,20 @@ public class TrackableSaver | |||
| private TrackableObjectGraph _last_save_object_graph; | |||
| private Tensor? _object_graph_feed_tensor = null; | |||
| private Tensor? _file_prefix_feed_tensor = null; | |||
| private Tensor? _file_prefix_placeholder = null; | |||
| private Dictionary<Trackable, Trackable>? _object_map = null; | |||
| private object? _cache = null; | |||
| public Tensor? FilePrefixPlaceHolder | |||
| { | |||
| get | |||
| { | |||
| return _file_prefix_placeholder; | |||
| } | |||
| set | |||
| { | |||
| _file_prefix_placeholder = value; | |||
| } | |||
| } | |||
| public TrackableSaver(ObjectGraphView graph_view) | |||
| { | |||
| _graph_view = graph_view; | |||
| @@ -192,4 +208,366 @@ public class TrackableSaver | |||
| return save_path; | |||
| } | |||
| } | |||
| public LoadStatus restore(string? save_path, CheckpointOptions? options = null) | |||
| { | |||
| if (options is null) | |||
| { | |||
| options = new CheckpointOptions(); | |||
| } | |||
| if(save_path is null) | |||
| { | |||
| return new InitializationOnlyStatus(_graph_view, ops.uid()); | |||
| } | |||
| CheckpointReader reader = new CheckpointReader(save_path); | |||
| bool graph_building = tf.Context.executing_eagerly(); | |||
| Dictionary<string, TF_DataType> dtype_map = null; | |||
| if (!graph_building) | |||
| { | |||
| dtype_map = reader.VariableToDataTypeMap; | |||
| } | |||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||
| Dictionary<Tensor, string> file_prefix_feed_dict; | |||
| Tensor file_prefix_tensor; | |||
| if (graph_building) | |||
| { | |||
| if(_file_prefix_placeholder is null) | |||
| { | |||
| tf.device("/cpu:0"); | |||
| _file_prefix_placeholder = constant_op.constant("model"); | |||
| } | |||
| file_prefix_tensor = _file_prefix_placeholder; | |||
| file_prefix_feed_dict = new(); | |||
| file_prefix_feed_dict[_file_prefix_placeholder] = save_path; | |||
| } | |||
| else | |||
| { | |||
| tf.device("/cpu:0"); | |||
| file_prefix_tensor = constant_op.constant(save_path); | |||
| file_prefix_feed_dict = null; | |||
| } | |||
| TrackableObjectGraph object_graph_proto = new(); | |||
| if(object_graph_string.ndim > 0) | |||
| { | |||
| object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||
| } | |||
| else | |||
| { | |||
| object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); | |||
| } | |||
| CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( | |||
| object_graph_proto: object_graph_proto, | |||
| save_path: save_path, | |||
| save_path_tensor: file_prefix_tensor, | |||
| reader: reader, | |||
| restore_op_cache: null, | |||
| graph_view: _graph_view, | |||
| options: options, | |||
| saveables_cache: null | |||
| ); | |||
| new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root); | |||
| if(_graph_view.AttachedDependencies is not null) | |||
| { | |||
| foreach(var refer in _graph_view.AttachedDependencies) | |||
| { | |||
| if(refer.Name == "root") | |||
| { | |||
| continue; | |||
| } | |||
| int? proto_id = null; | |||
| // Find proto ID of attached dependency (if it is in the proto). | |||
| foreach (var proto_refer in object_graph_proto.Nodes[0].Children) | |||
| { | |||
| if(proto_refer.LocalName == refer.Name) | |||
| { | |||
| proto_id = proto_refer.NodeId; | |||
| break; | |||
| } | |||
| } | |||
| if (proto_id is null) | |||
| { | |||
| continue; | |||
| } | |||
| // Object has already been restored. This can happen when there's an | |||
| // indirect connection from the attached object to the root. | |||
| if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value)) | |||
| { | |||
| continue; | |||
| } | |||
| new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer); | |||
| } | |||
| } | |||
| return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view); | |||
| } | |||
| } | |||
/// <summary>
/// Holds the status of an object-based checkpoint load: the object-graph proto,
/// which proto nodes have been matched to live objects, pending (deferred) slot
/// variable restorations, and the restore ops built so far.
/// </summary>
public class CheckpointRestoreCoordinator
{
    private CheckpointOptions _options;
    private TrackableObjectGraph _object_graph_proto;
    private int _restore_uid;
    private HashSet<int> _matched_proto_ids;
    private Tensor _save_path_tensor;
    private string _save_path_string;
    private CheckpointReader _reader;
    private Dictionary<string, TF_DataType> _dtype_map;
    private Dictionary<string, Shape> _shape_map;
    private ObjectGraphView _graph_view;
    private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations;
    private bool _expect_partial_attr;
    private List<Operation> _restore_ops;
    private List<Trackable> _all_trackables;
    private Dictionary<int, Trackable> _object_by_proto_id;
    private Dictionary<string, Operation> _restore_ops_by_name;
    private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations;
    private Dictionary<int, IList<string>> _unused_attributes;

    /// <summary>
    /// Builds a coordinator for a single checkpoint load and pre-indexes all slot
    /// variable references from the object-graph proto by their original variable node.
    /// </summary>
    public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
        CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
    {
        // TODO(Rinne): cache.
        _options = options;
        _object_graph_proto = object_graph_proto;
        _restore_uid = ops.uid();
        _save_path_tensor = save_path_tensor;
        _save_path_string = save_path;
        _reader = reader;
        if (_reader is null)
        {
            _reader = new CheckpointReader(save_path);
        }
        _dtype_map = _reader.VariableToDataTypeMap;
        _shape_map = _reader.VariableToShapeMap;
        _graph_view = graph_view;
        _restore_ops = new List<Operation>();
        _restore_ops_by_name = new Dictionary<string, Operation>();
        _all_trackables = new List<Trackable>();
        _matched_proto_ids = new HashSet<int>();
        _object_by_proto_id = new Dictionary<int, Trackable>();
        _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>();
        _deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>();
        // BUGFIX: `_unused_attributes` was never initialized, so consumers calling
        // `UnusedAttributes.SetDefault(...)` (e.g. CheckpointPosition) hit a
        // NullReferenceException on the first unused attribute.
        _unused_attributes = new Dictionary<int, IList<string>>();
        _expect_partial_attr = false;
        for (int i = 0; i < _object_graph_proto.Nodes.Count; i++)
        {
            var node = _object_graph_proto.Nodes[i];
            foreach (var slot_reference in node.SlotVariables)
            {
                // Index slot restorations by the node id of the variable they belong to.
                _slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>())
                    .Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
            }
        }
        // skip the deleter and cache.
    }

    // When true, warnings about unconsumed parts of the checkpoint are suppressed.
    public bool ExpectPartial
    {
        get
        {
            return _expect_partial_attr;
        }
        set
        {
            _expect_partial_attr = value;
        }
    }

    /// <summary>
    /// Corresponding to `all_python_objects` of tensorflow python
    /// </summary>
    public List<Trackable> AllTrackables => _all_trackables;
    public HashSet<int> MatchedProtoIds => _matched_proto_ids;
    public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
    public int RestoreUid => _restore_uid;
    public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
    public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations;
    public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations;
    public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name;
    public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes;

    /// <summary>
    /// Records restore ops produced while walking the object graph.
    /// </summary>
    public void new_restore_ops(IEnumerable<Operation> new_ops)
    {
        _restore_ops.AddRange(new_ops);
        // skip the callback.
    }

    /// <summary>
    /// Builds restore ops for the given saveables against this coordinator's save path.
    /// </summary>
    /// <param name="tensor_saveables">Checkpoint key -> saveable (variable or saveable object).</param>
    /// <param name="positions">Positions restored via `Trackable._restore_from_tensors`; not yet supported.</param>
    /// <param name="registered_savers">Unused; reserved for registered-saver support.</param>
    /// <returns>The restore operations created (empty under eager execution).</returns>
    public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
    {
        List<Operation> restore_ops = new();
        // The `_restore_from_tensors` path is not implemented yet; fail loudly if reached.
        foreach (var position in positions)
        {
            var key = position.ObjectProto.Attributes[0].CheckpointKey;
            throw new NotImplementedException();
        }

        // Only plain resource variables are supported as saveables for now.
        Dictionary<string, BaseResourceVariable> variable_dict = new();
        foreach (var item in tensor_saveables)
        {
            if (item.Value.TryGet<BaseResourceVariable>(out var variable))
            {
                variable_dict[item.Key] = variable;
            }
            else
            {
                throw new TypeError();
            }
        }

        if (tensor_saveables is not null && tensor_saveables.Count > 0)
        {
            var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict);
            var new_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options);
            if (!tf.Context.executing_eagerly())
            {
                // In graph mode remember each op by checkpoint key so it can be reused.
                foreach (var item in new_ops)
                {
                    restore_ops.Add(item.Value);
                    Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key));
                    _restore_ops_by_name[item.Key] = item.Value;
                }
            }
        }
        return restore_ops;
    }
}
/// <summary>
/// Abstract base for the status object returned by checkpoint restore operations.
/// </summary>
public abstract class LoadStatus
{
    /// <summary>Asserts that the whole checkpoint has been consumed.</summary>
    public abstract LoadStatus assert_consumed();
    /// <summary>Asserts that every existing object was matched with a checkpointed value.</summary>
    public abstract LoadStatus assert_existing_objects_matched();
    /// <summary>Asserts that something other than the root object was matched.</summary>
    public abstract LoadStatus assert_nontrivial_match();
    /// <summary>Runs the restore ops, if any, in the given session.</summary>
    public abstract LoadStatus run_restore_ops(Session? session = null);
    /// <summary>Runs initialization or restore ops, whichever is appropriate.</summary>
    public abstract void initialize_or_restore(Session? session = null);
    /// <summary>Silences warnings about incompletely-consumed checkpoints; no-op by default.</summary>
    public virtual LoadStatus expect_partial() => this;
}
/// <summary>
/// Load status returned when no checkpoint was specified (save_path is null):
/// nothing can be restored, objects can only be initialized.
/// </summary>
public class InitializationOnlyStatus : LoadStatus
{
    // Shared message for all assertions that require a checkpoint to exist.
    private const string NothingRestoredMessage =
        "No checkpoint specified (save_path=None); nothing is being restored.";

    private int _restore_uid;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
    {
        _restore_uid = restore_uid;
        _object_graph_view = object_graph_view;
        _root = object_graph_view.Root;
    }

    /// <summary>Always throws: there is no checkpoint to consume.</summary>
    public override LoadStatus assert_consumed()
        => throw new AssertionError(NothingRestoredMessage);

    /// <summary>Always throws: no objects can be matched without a checkpoint.</summary>
    public override LoadStatus assert_existing_objects_matched()
        => throw new AssertionError(NothingRestoredMessage);

    /// <summary>Always throws: no match is possible without a checkpoint.</summary>
    public override LoadStatus assert_nontrivial_match()
        => throw new AssertionError(NothingRestoredMessage);

    /// <summary>Always throws: there are no restore ops without a checkpoint.</summary>
    public override LoadStatus run_restore_ops(Session? session = null)
        => throw new AssertionError("No checkpoint specified, so no restore ops are available "
            + "(save_path=None to Saver.restore).");

    /// <summary>
    /// Under eager execution this is a no-op (variables initialize on creation);
    /// graph-mode initialization is not implemented yet.
    /// </summary>
    public override void initialize_or_restore(Session? session = null)
    {
        if (tf.Context.executing_eagerly())
        {
            return;
        }
        session ??= new Session();
        var trackable_objects = CheckPointUtils.list_objects(_object_graph_view);
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}
/// <summary>
/// Load status for an actual checkpoint restore; wraps the restore coordinator
/// and the feed dict needed to run restore ops in graph mode.
/// </summary>
internal class CheckpointLoadStatus : LoadStatus
{
    private CheckpointRestoreCoordinator _checkpoint;
    private Dictionary<Tensor, string> _feed_dict;
    private ObjectGraphView _object_graph_view;
    private Trackable _root;

    public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view) : base()
    {
        _checkpoint = checkpoint;
        _feed_dict = feed_dict;
        _object_graph_view = graph_view;
        _root = graph_view.Root;
    }

    public CheckpointRestoreCoordinator Checkpoint => _checkpoint;

    public override LoadStatus assert_consumed()
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Asserts that every currently-existing trackable with checkpoint attributes
    /// was bound to a value from the checkpoint.
    /// </summary>
    public override LoadStatus assert_existing_objects_matched()
    {
        for (int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++)
        {
            var node = _checkpoint.ObjectGraphProto.Nodes[i];
            // A bound object whose update uid predates this restore never received a value.
            if (_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) &&
                trackable.UpdateUid < _checkpoint.RestoreUid)
            {
                throw new AssertionError($"Object {node} not assigned a value from checkpoint.");
            }
        }
        // NOTE(review): this mutates AllTrackables; calling the assertion twice
        // re-adds the same objects — confirm intended.
        foreach (var trackable_object in CheckPointUtils.list_objects(_object_graph_view))
        {
            // Skip empty container nodes; they carry no checkpointed values.
            if (trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0)
            {
                continue;
            }
            _checkpoint.AllTrackables.Add(trackable_object);
        }
        // Materialize once: the query was previously re-enumerated by Any()/Count()/formatting.
        var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables)
            .Except(_checkpoint.ObjectByProtoId.Values)
            .ToList();
        if (unused_trackables.Count > 0)
        {
            var num_unused_trackables = unused_trackables.Count;
            var num_variables_to_show = Math.Min(10, num_unused_trackables);
            // BUGFIX: the message previously emitted the literal python leftover
            // "{list(unused_python_objects)[:num_variables_to_show]}" instead of the
            // actual unmatched objects; interpolate the real objects instead.
            var shown = string.Join(", ", unused_trackables.Take(num_variables_to_show));
            throw new AssertionError($"Found {num_unused_trackables} objects that were " +
                $"not bound to checkpointed values, likely due to changes in the " +
                $"program. Showing {num_variables_to_show} of " +
                $"{num_unused_trackables} unmatched objects: {shown}");
        }
        return this;
    }

    public override LoadStatus assert_nontrivial_match()
    {
        throw new NotImplementedException();
    }

    public override LoadStatus expect_partial()
    {
        throw new NotImplementedException();
    }

    public override void initialize_or_restore(Session? session = null)
    {
        throw new NotImplementedException();
    }

    public override LoadStatus run_restore_ops(Session? session = null)
    {
        throw new NotImplementedException();
    }
}
| @@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint | |||
| // tf python has code `with ops.device(restore_device):` here. | |||
| tf.device(restore_device); // may be risky. | |||
| var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
| var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
| Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
| int idx = 0; | |||
| @@ -0,0 +1,331 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Checkpoint; | |||
| public class CheckpointPosition | |||
| { | |||
| private CheckpointRestoreCoordinator _checkpoint; | |||
| private int _proto_id; | |||
| private bool _skip_restore; | |||
| public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id) | |||
| { | |||
| _checkpoint = checkpoint; | |||
| _proto_id = proto_id; | |||
| _skip_restore = false; | |||
| } | |||
| public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; | |||
| public CheckpointRestoreCoordinator Checkpoint => _checkpoint; | |||
| public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id]; | |||
| public void restore(Trackable trackable) | |||
| { | |||
| using (ops.init_scope()) | |||
| { | |||
| if (bind_project(trackable)) | |||
| { | |||
| var restore_ops = _restore_descendants(); | |||
| if(restore_ops is not null && restore_ops.Count > 0) | |||
| { | |||
| _checkpoint.new_restore_ops(restore_ops); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Set a checkpoint<->object correspondence. | |||
| /// </summary> | |||
| /// <param name="trackable"></param> | |||
| /// <returns></returns> | |||
| public bool bind_project(Trackable trackable) | |||
| { | |||
| _checkpoint.AllTrackables.Add(trackable); | |||
| _checkpoint.MatchedProtoIds.Add(_proto_id); | |||
| if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) | |||
| { | |||
| // skip the `logging.warning`. | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| _checkpoint.ObjectByProtoId[_proto_id] = trackable; | |||
| return true; | |||
| } | |||
| } | |||
| public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables() | |||
| { | |||
| // skip the registered_saver | |||
| if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(), | |||
| new List<CheckpointPosition>(), null); | |||
| } | |||
| var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable); | |||
| List<Operation> existing_restore_ops; | |||
| List<CheckpointPosition> positions = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables; | |||
| if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
| { | |||
| (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); | |||
| } | |||
| else if(saveable_factories.Count > 0) | |||
| { | |||
| (existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| return (existing_restore_ops, named_saveables, positions, null); | |||
| } | |||
| public CheckpointPosition create_child_position(int node_id) | |||
| { | |||
| return new CheckpointPosition(_checkpoint, node_id); | |||
| } | |||
| public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable, | |||
| int slot_variable_id, string slot_name) | |||
| { | |||
| //CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id); | |||
| // TODO(Rinne): implement it. | |||
| return (null, null); | |||
| } | |||
| /// <summary> | |||
| /// Creates a saveable using the _serialize_to_tensor method. | |||
| /// </summary> | |||
| /// <param name="saveable_factories"></param> | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| string suffix = SaveableCompat.get_saveable_name(this.Trackable); | |||
| suffix = suffix ?? ""; | |||
| var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix; | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| throw new NotImplementedException("The restore under graph mode has not been implemented. " + | |||
| "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); | |||
| } | |||
| var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); | |||
| // skip the cache. | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new(); | |||
| dict[saveable_name] = saveable; | |||
| return (new List<Operation>(), dict); | |||
| } | |||
| private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name( | |||
| IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| if(ObjectProto.Attributes is null) | |||
| { | |||
| return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>()); | |||
| } | |||
| List<Operation> existing_restore_ops = new(); | |||
| HashSet<string> created_compat_names = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new(); | |||
| foreach (var serialized_tensor in ObjectProto.Attributes) | |||
| { | |||
| Operation existing_op; | |||
| if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey)) | |||
| { | |||
| existing_op = null; | |||
| } | |||
| else | |||
| { | |||
| existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey]; | |||
| } | |||
| if(existing_op is not null) | |||
| { | |||
| existing_restore_ops.Add(existing_op); | |||
| continue; | |||
| } | |||
| if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x))) | |||
| { | |||
| continue; | |||
| } | |||
| // TODO(Rinne): deal with cache. | |||
| var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names); | |||
| if(saveable is null) | |||
| { | |||
| _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name); | |||
| continue; | |||
| } | |||
| named_saveables[serialized_tensor.CheckpointKey] = saveable; | |||
| } | |||
| return (existing_restore_ops, named_saveables); | |||
| } | |||
| private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories, | |||
| TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names) | |||
| { | |||
| var expected_factory_name = serialized_tensor.Name; | |||
| var factory_input_name = serialized_tensor.CheckpointKey; | |||
| if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory)) | |||
| { | |||
| foreach(var item in saveable_factories) | |||
| { | |||
| var factory_name = item.Key; | |||
| var factory = item.Value; | |||
| if (expected_factory_name.StartsWith(factory_name)) | |||
| { | |||
| if(matched_factory is not null) | |||
| { | |||
| throw new ValueError($"Forward compatibility load error: Unable to load " + | |||
| "checkpoint saved in future version of TensorFlow. " + | |||
| "Please update your version of TensorFlow to the " + | |||
| "version in which the checkpoint was saved."); | |||
| } | |||
| } | |||
| matched_factory = factory; | |||
| factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name; | |||
| created_compat_names.Add(factory_name); | |||
| } | |||
| } | |||
| return matched_factory(factory_input_name); | |||
| } | |||
| private string _extract_saveable_name(string checkpoint_key) | |||
| { | |||
| var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/"; | |||
| return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length); | |||
| } | |||
/// <summary>
/// Restore the bound Trackable and dependencies (may be deferred).
/// Performs a breadth-first walk starting at this position, accumulating the
/// restore ops and saveables produced by each visited position, then issues a
/// single batched restore for everything gathered.
/// </summary>
/// <returns>All restore operations created during the traversal.</returns>
private List<Operation> _restore_descendants()
{
    Queue<(CheckpointPosition, Trackable)> visit_queue = new();
    // Seed the traversal with this position and its bound trackable.
    visit_queue.Enqueue((this, this.Trackable));
    List<Operation> restore_ops = new();
    Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
    List<CheckpointPosition> positions = new();
    // Kept outside the loop: the last visited position supplies the Checkpoint
    // used for the final batched restore below. The queue starts non-empty, so
    // this is always assigned before use.
    CheckpointPosition current_position = null;
    while (visit_queue.Count > 0)
    {
        current_position = visit_queue.Dequeue().Item1;
        var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore();
        restore_ops.AddRange(new_restore_ops);
        // Add (rather than overwrite) so a duplicate checkpoint key fails loudly.
        foreach(var item in new_tensor_saveables)
        {
            tensor_saveables.Add(item.Key, item.Value);
        }
        positions.AddRange(new_positions);
        // Enqueue children and any slot variables tied to this position.
        _queue_children_for_restoration(current_position, visit_queue);
        _queue_slot_variables(current_position, visit_queue);
    }
    // Batch-restore everything gathered during the walk.
    restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null));
    return restore_ops;
}
/// <summary>
/// Enqueues the children of <paramref name="checkpoint_position"/> for restoration.
/// Children the trackable does not yet track are recorded as deferred dependencies
/// so they can be restored once the dependency is created.
/// </summary>
private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
    var trackable = checkpoint_position.Trackable;
    foreach(var child in checkpoint_position.ObjectProto.Children)
    {
        var child_position = checkpoint_position.create_child_position(child.NodeId);
        var local_object = trackable._lookup_dependency(child.LocalName);
        var child_proto = child_position.ObjectProto;
        if(local_object is null)
        {
            // The dependency is not yet created; defer it — but only when the
            // child proto actually carries something worth restoring.
            if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any())
            {
                trackable.DeferredDependencies.SetDefault(child.LocalName, new List<CheckpointPosition>()).Add(child_position);
            }
        }
        else
        {
            // bind_project returns false when this object was already bound to a
            // different position, in which case the child is not revisited.
            if (child_position.bind_project(local_object))
            {
                visit_queue.Enqueue((child_position, local_object));
            }
        }
    }
}
/// <summary>
/// Queues restorations for slot variables associated with
/// <paramref name="checkpoint_position"/>. Handles two directions:
/// deferred restorations waiting on this optimizer, and slot restorations
/// waiting on this variable's optimizer to appear.
/// </summary>
private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
    var trackable = checkpoint_position.Trackable;
    var checkpoint = checkpoint_position.Checkpoint;
    // Case 1: this position is an optimizer for which slot restorations were
    // previously deferred; create the slot variables now.
    if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions))
    {
        checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id);
        foreach (var deferred_slot_restoration in positions)
        {
            var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position(
                trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId,
                deferred_slot_restoration.SlotName
            );
            // Null when the optimizer did not create this slot (e.g. unused slot).
            if(slot_variable_position is not null)
            {
                visit_queue.Enqueue((slot_variable_position, slot_variable));
            }
        }
    }
    // Case 2: this position is a variable with pending slot restorations keyed
    // by its optimizer's proto id.
    if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations))
    {
        checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id);
        foreach (var slot_restoration in restorations)
        {
            if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object))
            {
                // The optimizer already exists; restoring directly through it is
                // not supported yet.
                throw new NotImplementedException();
                // TODO(Rinne); implement it.
            }
            else
            {
                // Optimizer not created yet — defer until it appears (case 1 above).
                Debug.Assert(trackable is BaseResourceVariable);
                Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List<DeferredSlotVariableRestoration>())
                    .Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName));
            }
        }
    }
}
/// <summary>
/// Restores this position's trackable if it has not yet been updated by the
/// current restore pass; otherwise returns empty results.
/// </summary>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
{
    var obj = this.Trackable;
    obj._maybe_initialize_trackable();
    // Already up to date with this restore pass — nothing to do.
    if (_checkpoint.RestoreUid <= obj.UpdateUid)
    {
        return (new List<Operation>(),
            new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
            new List<CheckpointPosition>(), null);
    }
    var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables();
    // Mark the trackable as restored by this pass before returning.
    obj.UpdateUid = _checkpoint.RestoreUid;
    return (restore_ops, tensor_saveables, positions, registered_savers);
}
| } | |||
/// <summary>
/// A slot-variable restoration that must wait until its optimizer object is
/// created; queued in <c>Checkpoint.DeferredSlotRestorations</c>.
/// </summary>
/// <param name="OriginalVariable">The variable the slot belongs to.</param>
/// <param name="SlotVariableId">Node id of the slot variable in the object graph proto.</param>
/// <param name="SlotName">Name of the slot (e.g. "m", "v").</param>
public record class DeferredSlotVariableRestoration(
    BaseResourceVariable OriginalVariable,
    int SlotVariableId,
    string SlotName
);
| @@ -10,7 +10,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| internal class execute | |||
| internal static class execute | |||
| { | |||
| public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||
| { | |||
| @@ -27,5 +27,9 @@ namespace Tensorflow.Eager | |||
| return tensors; | |||
| } | |||
/// <summary>
/// Whether eager execution must record operations for gradient computation.
/// Gradient recording is not supported on this path yet, so this is always false.
/// </summary>
public static bool must_record_gradient() => false;
| } | |||
| } | |||
| @@ -13,8 +13,8 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public class ConcreteFunction: Trackable | |||
| { | |||
| FuncGraph func_graph; | |||
| ForwardBackwardCall forward_backward; | |||
| internal FuncGraph func_graph; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| @@ -23,6 +23,8 @@ namespace Tensorflow.Functions | |||
| public Tensor[] Outputs; | |||
| public Type ReturnType; | |||
| public TensorSpec[] OutputStructure; | |||
| public IEnumerable<string> ArgKeywords { get; set; } | |||
| public long NumPositionArgs { get; set; } | |||
| public ConcreteFunction(string name) | |||
| { | |||
| @@ -163,6 +165,15 @@ namespace Tensorflow.Functions | |||
| return flat_outputs; | |||
| } | |||
/// <summary>
/// Registers this concrete function with a graph. In graph mode, defaults to
/// the current default graph when none is given. Currently a stub: the graph
/// is resolved but nothing is registered yet.
/// </summary>
/// <param name="g">Target graph; ignored in eager mode.</param>
public void AddTograph(Graph? g = null)
{
    if(!tf.Context.executing_eagerly() && g is null)
    {
        g = ops.get_default_graph();
    }
    // TODO(Rinne); complete it with `_delayed_rewrite_functions`.
}
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| @@ -16,8 +16,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.IO | |||
| { | |||
| @@ -63,5 +65,15 @@ namespace Tensorflow.IO | |||
| dirs.AddRange(Directory.GetFiles(dir)); | |||
| return dirs.ToArray(); | |||
| } | |||
/// <summary>
/// Joins path components into a single path, mirroring `tf.io.gfile.join`.
/// </summary>
/// <param name="paths">One or more path components; at least one is required.</param>
/// <returns>The combined filesystem path.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when the first component carries a URL scheme (e.g. "gs://bucket").
/// </exception>
public string join(params string[] paths)
{
    Debug.Assert(paths.Length >= 1);
    // Search for "://" only past the first character so a leading separator is
    // not mistaken for a scheme; the Length guard fixes a crash where an empty
    // first component made Substring(1) throw ArgumentOutOfRangeException.
    if (paths[0].Length > 1 && paths[0].Substring(1).Contains("://"))
    {
        throw new NotImplementedException("The combination of urls has not been implemented.");
    }
    return Path.Combine(paths);
}
| } | |||
| } | |||
| @@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common | |||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
| { | |||
| var axis = serializer.Deserialize(reader, typeof(long[])); | |||
| int[]? axis; | |||
| if(reader.ValueType == typeof(long)) | |||
| { | |||
| axis = new int[1]; | |||
| axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | |||
| } | |||
| else | |||
| { | |||
| axis = serializer.Deserialize(reader, typeof(int[])) as int[]; | |||
| } | |||
| if (axis is null) | |||
| { | |||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||
| @@ -0,0 +1,36 @@ | |||
| using Newtonsoft.Json.Linq; | |||
| using Newtonsoft.Json; | |||
| namespace Tensorflow.Keras.Common | |||
| { | |||
/// <summary>
/// JSON converter for <see cref="TF_DataType"/>: serializes to the numpy dtype
/// name and accepts either a dtype name string or a raw integer when reading.
/// </summary>
public class CustomizedDTypeJsonConverter : JsonConverter
{
    public override bool CanConvert(Type objectType) => objectType == typeof(TF_DataType);

    public override bool CanRead => true;

    public override bool CanWrite => true;

    public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
    {
        // Emit the numpy-style dtype name (e.g. "float32") as a JSON string.
        JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)).WriteTo(writer);
    }

    public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
    {
        // Accept both representations: a dtype name, or the enum's integer value.
        return reader.ValueType == typeof(string)
            ? dtypes.tf_dtype_from_name((string)serializer.Deserialize(reader, typeof(string)))
            : (TF_DataType)serializer.Deserialize(reader, typeof(int));
    }
}
| } | |||
| @@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common | |||
| { | |||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||
| } | |||
| if(values.Length != 3) | |||
| if(values.Length == 1) | |||
| { | |||
| var array = values[0] as JArray; | |||
| if(array is null) | |||
| { | |||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
| } | |||
| values = array.ToObject<object[]>(); | |||
| } | |||
| if (values.Length < 3) | |||
| { | |||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
| } | |||
| @@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common | |||
| { | |||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | |||
| } | |||
| if (values[1] is not int) | |||
| int nodeIndex; | |||
| int tensorIndex; | |||
| if (values[1] is long) | |||
| { | |||
| nodeIndex = (int)(long)values[1]; | |||
| } | |||
| else if (values[1] is int) | |||
| { | |||
| nodeIndex = (int)values[1]; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | |||
| } | |||
| if (values[2] is not int) | |||
| if (values[2] is long) | |||
| { | |||
| tensorIndex = (int)(long)values[2]; | |||
| } | |||
| else if (values[1] is int) | |||
| { | |||
| tensorIndex = (int)values[2]; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | |||
| } | |||
| return new NodeConfig() | |||
| { | |||
| Name = values[0] as string, | |||
| NodeIndex = (int)values[1], | |||
| TensorIndex = (int)values[2] | |||
| NodeIndex = nodeIndex, | |||
| TensorIndex = tensorIndex | |||
| }; | |||
| } | |||
| } | |||
| @@ -51,10 +51,28 @@ namespace Tensorflow.Keras.Common | |||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
| { | |||
| var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
| if(dims is null) | |||
| long?[] dims; | |||
| try | |||
| { | |||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||
| dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
| } | |||
| catch (JsonSerializationException ex) | |||
| { | |||
| if (reader.Value.Equals("class_name")) | |||
| { | |||
| reader.Read(); | |||
| reader.Read(); | |||
| reader.Read(); | |||
| dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||
| } | |||
| else | |||
| { | |||
| throw ex; | |||
| } | |||
| } | |||
| if (dims is null) | |||
| { | |||
| return null; | |||
| } | |||
| long[] convertedDims = new long[dims.Length]; | |||
| for(int i = 0; i < dims.Length; i++) | |||
| @@ -19,6 +19,7 @@ namespace Tensorflow.Keras | |||
| List<IVariableV1> TrainableVariables { get; } | |||
| List<IVariableV1> TrainableWeights { get; } | |||
| List<IVariableV1> NonTrainableWeights { get; } | |||
| List<IVariableV1> Weights { get; } | |||
| Shape OutputShape { get; } | |||
| Shape BatchInputShape { get; } | |||
| TensorShapeConfig BuildInputShape { get; } | |||
| @@ -1,8 +1,11 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
| namespace Tensorflow.Keras.Saving | |||
| { | |||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow.ModelSaving | |||
| { | |||
| @@ -71,6 +71,7 @@ namespace Tensorflow | |||
| public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | |||
| public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | |||
| public List<IVariableV1> Weights => throw new NotImplementedException(); | |||
| public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | |||
| public Shape OutputShape => throw new NotImplementedException(); | |||
| @@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations | |||
| /// | |||
| /// Callers must ensure all the named tensors are indeed stored in the checkpoint. | |||
| /// </remarks> | |||
| public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||
| public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| Dictionary<string, object> attrs = new(); | |||
| attrs["dtypes"] = dtypes; | |||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
| "RestoreV2", name, prefix, tensor_names, shape_and_slices | |||
| ) | |||
| { attrs = attrs }); | |||
| return result; | |||
| } | |||
| catch (Exception) | |||
| { | |||
| try | |||
| { | |||
| return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx); | |||
| } | |||
| catch (Exception) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["prefix"] = prefix; | |||
| dict["tensor_names"] = tensor_names; | |||
| @@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations | |||
| return (tensors); | |||
| } | |||
/// <summary>
/// Eager fallback for the RestoreV2 op: converts the string arguments to
/// tensors and executes the op directly via quick_execute.
/// </summary>
/// <param name="prefix">Checkpoint path prefix tensor.</param>
/// <param name="tensor_names">Names of the tensors to restore.</param>
/// <param name="shape_and_slices">Shape/slice specs, one per tensor name ("" for whole tensors).</param>
/// <param name="dtypes">Expected dtype of each restored tensor; its length is the op's output count.</param>
/// <param name="name">Op name.</param>
/// <param name="ctx">Eager context to execute in.</param>
/// <returns>The restored tensors, one per requested name.</returns>
public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx)
{
    // Normalize all inputs to string tensors as RestoreV2 expects.
    prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
    var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
    var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);
    object[] attrs = new object[] { "dtypes", dtypes };
    Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor };
    var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name);
    if (execute.must_record_gradient())
    {
        // TODO(Rinne); record the gradient
    }
    return result;
}
| /// <summary> | |||
| /// Reverses specific dimensions of a tensor. | |||
| /// </summary> | |||
| @@ -62,6 +62,7 @@ namespace Tensorflow | |||
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
| { | |||
| // Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`. | |||
| var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||
| return _op.outputs; | |||
| @@ -17,8 +17,8 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| @@ -1,9 +1,13 @@ | |||
| namespace Tensorflow | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Keras.Common; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | |||
| /// The enum values here are identical to corresponding values in types.proto. | |||
| /// </summary> | |||
| [JsonConverter(typeof(CustomizedDTypeJsonConverter))] | |||
| public enum TF_DataType | |||
| { | |||
| DtInvalid = 0, | |||
| @@ -159,7 +159,10 @@ namespace Tensorflow | |||
| "uint32" => TF_DataType.TF_UINT32, | |||
| "int64" => TF_DataType.TF_INT64, | |||
| "uint64" => TF_DataType.TF_UINT64, | |||
| "float16" => TF_DataType.TF_BFLOAT16, | |||
| "float32" => TF_DataType.TF_FLOAT, | |||
| "single" => TF_DataType.TF_FLOAT, | |||
| "float64" => TF_DataType.TF_DOUBLE, | |||
| "double" => TF_DataType.TF_DOUBLE, | |||
| "complex" => TF_DataType.TF_COMPLEX128, | |||
| "string" => TF_DataType.TF_STRING, | |||
| @@ -39,6 +39,24 @@ namespace Tensorflow | |||
| _op = value; | |||
| } | |||
| } | |||
/// <summary>
/// The wrapped object viewed as a variable. Getting throws <see cref="TypeError"/>
/// when the underlying value is not a <see cref="BaseResourceVariable"/>.
/// </summary>
public BaseResourceVariable variable
{
    get => _op.TryGet<BaseResourceVariable>(out var v)
        ? v
        : throw new TypeError("The _op is not a variable.");
    set => _op = value;
}
| public SaveSpec[] specs; | |||
| public string name; | |||
| public string device; | |||
| @@ -0,0 +1,23 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
/// <summary>
/// Options controlling how a SavedModel is loaded; counterpart of Python's
/// `tf.saved_model.LoadOptions`.
/// </summary>
public record class LoadOptions
{
    /// <summary>Whether a checkpoint missing some variables is acceptable.</summary>
    public bool allow_partial_checkpoint;
    /// <summary>Device to use for checkpoint I/O (null = default).</summary>
    public string experimental_io_device;
    /// <summary>If true, skip restoring the checkpoint entirely.</summary>
    public bool experimental_skip_checkpoint;
    /// <summary>Policy for how variables are loaded.</summary>
    public VariablePolicy experimental_variable_policy;

    /// <summary>
    /// Creates load options. The policy is given as a string (e.g.
    /// "save_variable_devices") and converted via <see cref="VariablePolicy.from_obj"/>;
    /// null maps to the default policy.
    /// </summary>
    public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null,
        bool experimental_skip_checkpoint = false, string experimental_variable_policy = null)
    {
        this.allow_partial_checkpoint = allow_partial_checkpoint;
        this.experimental_io_device = experimental_io_device;
        this.experimental_skip_checkpoint = experimental_skip_checkpoint;
        this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy);
    }
}
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.Train; | |||
| using System; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow; | |||
| @@ -14,4 +15,10 @@ public class RevivedTypes | |||
| // TODO: complete the implementation. | |||
| return null; | |||
| } | |||
/// <summary>
/// Deserializes a revived-type proto into a trackable and its setter.
/// Currently a stub that always returns null (no revived types registered).
/// </summary>
public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto)
{
    // TODO: complete the implementation.
    return null;
}
| } | |||
| @@ -2,7 +2,7 @@ | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.ModelSaving | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Options for saving to SavedModel. | |||
| @@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving | |||
| public bool save_variable_devices() | |||
| { | |||
| return this != VariablePolicy.None; | |||
| return this != None; | |||
| } | |||
| /// <summary> | |||
| @@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving | |||
| /// <returns></returns> | |||
| public static VariablePolicy from_obj(object obj) | |||
| { | |||
| if (obj is null) return VariablePolicy.None; | |||
| if (obj is null) return None; | |||
| if (obj is VariablePolicy) return (VariablePolicy)obj; | |||
| var key = obj.ToString().ToLower(); | |||
| return key switch | |||
| { | |||
| null => VariablePolicy.None, | |||
| "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||
| "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||
| null => None, | |||
| "save_variable_devices" => SAVE_VARIABLE_DEVICES, | |||
| "expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES, | |||
| _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | |||
| }; | |||
| } | |||
| @@ -5,7 +5,6 @@ using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Functions; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| /// <summary> | |||
| /// A class wraps a concrete function to handle different distributed contexts. | |||
| /// </summary> | |||
/// <summary>
/// A class wraps a concrete function to handle different distributed contexts.
/// </summary>
internal class WrapperFunction: ConcreteFunction
{
    /// <summary>
    /// Builds a wrapper sharing the source function's graph and copying its
    /// outputs, return type, output structure and argument keywords.
    /// </summary>
    public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
    {
        this.forward_backward = concrete_function.forward_backward;
        this.Outputs = concrete_function.Outputs;
        this.ReturnType = concrete_function.ReturnType;
        this.OutputStructure = concrete_function.OutputStructure;
        this.ArgKeywords = concrete_function.ArgKeywords;
        // NOTE(review): NumPositionArgs is not copied although ArgKeywords is —
        // confirm this omission is intentional.
    }
}
| } | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
public static class function_deserialization
{
    /// <summary>
    /// Sets up a function that was not part of an object hierarchy ("bare"):
    /// looks up the named concrete function, attaches its argument keywords and
    /// allowed positional-argument count, and registers it with the graph.
    /// </summary>
    /// <param name="saved_bare_concrete_function">Proto describing the bare function.</param>
    /// <param name="concrete_functions">All concrete functions loaded from the SavedModel, by name.</param>
    /// <returns>The configured concrete function.</returns>
    public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
        IDictionary<string, ConcreteFunction> concrete_functions)
    {
        var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName];
        concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
        concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;
        // NOTE(review): the deserialized spec is currently unused — presumably it
        // should be attached to `concrete_function`; confirm before relying on it.
        var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
        concrete_function.AddTograph();
        return concrete_function;
    }

    /// <summary>
    /// Converts a method-style FunctionSpec proto to a non-method spec.
    /// Currently a shallow copy of the relevant fields.
    /// </summary>
    private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
    {
        // TODO(Rinne); revise the implementation.
        return new FunctionSpec()
        {
            Fullargspec = function_spec_proto.Fullargspec,
            IsMethod = function_spec_proto.IsMethod,
            InputSignature = function_spec_proto.InputSignature,
            JitCompile = function_spec_proto.JitCompile
        };
    }
}
| } | |||
| @@ -0,0 +1,641 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Net.Sockets; | |||
| using System.Text; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Variables; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// Helper class to load an object-based SavedModel. | |||
| /// </summary> | |||
| public partial class Loader | |||
| { | |||
| private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def; | |||
| private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes; | |||
| private SavedObjectGraph _proto; | |||
| private string _export_dir; | |||
| private CheckpointOptions _checkpoint_options; | |||
| private LoadOptions _save_options; | |||
| private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters; | |||
| private Dictionary<string, int>? _node_path_to_id; | |||
| private List<int>? _filtered_nodes; | |||
| private List<int> _ordered_node_ids; | |||
| private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
| private List<Trackable> _nodes; | |||
| private Dictionary<int, Action<object, object, object>> _node_setters; | |||
/// <summary>
/// Builds a loader for an object-based SavedModel: resolves node filters to
/// ids, orders nodes so dependencies come first, instantiates all nodes, and
/// (unless skipped by options) restores the checkpoint.
/// </summary>
/// <param name="object_graph_proto">The SavedObjectGraph of the model.</param>
/// <param name="saved_model_proto">The full SavedModel proto; only the first MetaGraph is used.</param>
/// <param name="export_dir">Directory the SavedModel was exported to.</param>
/// <param name="ckpt_options">Options for checkpoint reading.</param>
/// <param name="save_options">Load options (e.g. skip-checkpoint).</param>
/// <param name="filters">Optional map from node paths ("root.layer...") to pre-created objects and setters.</param>
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir,
    CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters)
{
    var meta_graph = saved_model_proto.MetaGraphs[0];
    _asset_file_def = meta_graph.AssetFileDef;
    // Per-op attribute map, used later when rebuilding operations.
    _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
    _proto = object_graph_proto;
    _export_dir = export_dir;
    // TODO: `this._concrete_functions` and `this._restored_concrete_functions`
    _checkpoint_options = ckpt_options;
    _save_options = save_options;
    // TODO: `this._pretty_printer`
    _node_filters = filters;
    _node_path_to_id = _convert_node_paths_to_ints();
    _loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
    // Seed loaded nodes with the user-supplied filter objects.
    foreach(var filter in filters)
    {
        _loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
    }
    _filtered_nodes = _retrieve_all_filtered_nodes();
    _ordered_node_ids = _generate_ordered_node_ids();

    _load_all();

    if (!save_options.experimental_skip_checkpoint)
    {
        _restore_checkpoint();
    }
    foreach(var node in _nodes)
    {
        // skip the process of `CapturableResource`.
    }
}
/// <summary>
/// Maps all string node paths in node_filters to the int node ids.
/// Each path must start with "root" and is resolved one segment at a time
/// through the object graph's child references.
/// </summary>
/// <returns>Path-to-id map, or null when no filters were supplied.</returns>
private Dictionary<string, int>? _convert_node_paths_to_ints()
{
    if( _node_filters is null)
    {
        return null;
    }
    Dictionary<string, int> path_to_int = new();
    foreach(var node_id in _node_filters.Keys)
    {
        int int_node_id;
        var node_path = node_id.Split('.');
        if (node_path[0] != "root")
        {
            throw new ValueError($"When passing string identifiers to node_filters, the first name" +
                $" must be root. Received {node_path[0]}.");
        }
        // Walk from the root (id 0) down the path, one child name per step.
        int_node_id = 0;
        for(int i = 0; i < node_path.Length - 1; i++)
        {
            var name = node_path[i + 1];
            // The joined prefix is only used for the error message on failure.
            int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1)));
        }
        path_to_int[node_id] = int_node_id;
    }
    return path_to_int;
}
/// <summary>
/// Resolves the child named <paramref name="child_name"/> of the node with id
/// <paramref name="node_id"/> in the object graph proto.
/// </summary>
/// <param name="path">Path prefix used only in the error message.</param>
/// <exception cref="ValueError">No child with that name exists.</exception>
private int _find_node_child(int node_id, string child_name, string path)
{
    var match = _proto.Nodes[node_id].Children
        .FirstOrDefault(c => c.LocalName == child_name);
    if (match is null)
    {
        throw new ValueError($"Unable to find node {path}.");
    }
    return match.NodeId;
}
/// <summary>
/// Expands the user-supplied node filters to the full set of reachable node
/// ids (children of filtered nodes are included transitively), populating
/// _node_path_to_id and _loaded_nodes along the way.
/// </summary>
/// <returns>
/// The filtered node ids, or null when there are no filters or the root is
/// included (in which case everything is loaded anyway).
/// </returns>
private List<int>? _retrieve_all_filtered_nodes()
{
    if(_node_filters is null)
    {
        return null;
    }

    HashSet<int> all_filtered_nodes = new();
    Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys);

    while(nodes_to_visit.Count > 0)
    {
        var node_path = nodes_to_visit.Dequeue();
        var node_id = _node_path_to_id[node_path];
        if (all_filtered_nodes.Contains(node_id))
        {
            continue;
        }
        all_filtered_nodes.Add(node_id);

        // The node may already have a user-created object attached.
        Trackable node = null;
        Action<object, object, object> setter = null;
        if(_loaded_nodes.TryGetValue(node_id, out var res))
        {
            (node, setter) = res;
        }
        if(node is not null)
        {
            node._maybe_initialize_trackable();
        }

        foreach(var refer in _proto.Nodes[node_id].Children)
        {
            Trackable children_object = null;
            if(_loaded_nodes.TryGetValue(refer.NodeId, out var result))
            {
                children_object = result.Item1;
            }
            // See if node already tracks the child reference, in which case add the child to the loaded_nodes dict.
            if(children_object is null && node is not null)
            {
                children_object = node._lookup_dependency(refer.LocalName);
                if(children_object is TrackableDataStructure)
                {
                    // TODO: set setter as lambda.

                    _loaded_nodes[refer.NodeId] = (children_object, setter);
                }
            }
            // Record the child path and keep expanding breadth-first.
            string child_path = $"{node_path}.{refer.LocalName}";
            _node_path_to_id[child_path] = refer.NodeId;
            nodes_to_visit.Enqueue(child_path);
        }
    }

    // Root included means a full (unfiltered) load.
    if (all_filtered_nodes.Contains(0))
    {
        return null;
    }
    return all_filtered_nodes.ToList();
}
/// <summary>
/// Orders the node ids so that dependencies appear first.
/// Builds a dependency map (including slot-variable ordering constraints:
/// optimizer and original variable before the slot variable) and topologically
/// sorts it.
/// </summary>
/// <returns>Node ids in dependency order.</returns>
/// <exception cref="ValueError">
/// A required dependency is excluded by the filter, or the dependency graph
/// contains a cycle.
/// </exception>
private List<int> _generate_ordered_node_ids()
{
    List<int> unordered_ids;
    if(_filtered_nodes is null)
    {
        unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList();
    }
    else
    {
        unordered_ids = new List<int>(_filtered_nodes);
    }

    Dictionary<int, List<int>> dependency_map = new();
    foreach(var node_id in unordered_ids)
    {
        var deps = dependency_map.SetDefault(node_id, new List<int>());
        // Nodes supplied by the user (filters) are already created; they need
        // no dependency edges of their own.
        if (_loaded_nodes.ContainsKey(node_id))
        {
            continue;
        }

        var proto = _proto.Nodes[node_id];
        foreach(var dep in _get_node_dependencies(proto).Values.Distinct())
        {
            deps.Add(dep);
            if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
            {
                // TODO: add info with `_pretty_printer`.
                throw new ValueError($"Unable to partially load SavedModel since the specified filter " +
                    $"does not include all required objects for loading (e.g. " +
                    $"variables used in functions or deserialization dependencies). " +
                    $"Please include this path in the filter: {dep}");
            }
        }

        // Slot variables depend on their optimizer and original variable, and
        // on the previous slot so add_slot calls happen in saved order.
        int? prev_slot = null;
        foreach(var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            // The optimizer and original variable must be created before the slot
            // variable, since the slot variable is generated using the Optimizer's
            // add_slot API.
            var slot_deps = dependency_map[slot_variable_node_id];
            slot_deps.Add(node_id);
            slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);

            if(prev_slot is not null)
            {
                slot_deps.Add(prev_slot.Value);
            }
            prev_slot = slot_variable_node_id;
        }
    }
    try
    {
        return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
    }
    catch (TrackableUtils.CyclicDependencyError)
    {
        // Fixed the concatenated message, which previously ran words together
        // ("dependenciesin", "pleasefile").
        throw new ValueError("Encountered a cycle in the deserialization dependencies " +
            "in the SavedModel. This is extremely unexpected, please " +
            "file a bug and make sure you are not manually modifying the SavedModel.");
    }
}
/// <summary>
/// Returns a dictionary of all dependencies of an object.
/// Includes explicit proto dependencies plus, depending on the node kind,
/// bound inputs of its concrete functions or its "_create_resource" child.
/// </summary>
/// <param name="proto">The saved object whose dependencies are collected.</param>
/// <returns>Map from dependency key (name or bound-input id) to node id.</returns>
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto)
{
    Dictionary<Maybe<string, int>, int> dependencies = new();
    foreach (var reference in proto.Dependencies)
    {
        dependencies[reference.LocalName] = reference.NodeId;
    }

    switch (proto.KindCase)
    {
        case SavedObject.KindOneofCase.Function:
            // Every bound input of every concrete function is a dependency.
            foreach (var function_name in proto.Function.ConcreteFunctions)
            {
                foreach (var bound_input in _proto.ConcreteFunctions[function_name].BoundInputs)
                {
                    dependencies[bound_input] = bound_input;
                }
            }
            break;

        case SavedObject.KindOneofCase.BareConcreteFunction:
        {
            var function_name = proto.BareConcreteFunction.ConcreteFunctionName;
            foreach (var bound_input in _proto.ConcreteFunctions[function_name].BoundInputs)
            {
                dependencies[bound_input] = bound_input;
            }
            break;
        }

        case SavedObject.KindOneofCase.Resource:
            // Resources depend on their "_create_resource" function child.
            foreach (var child in proto.Children)
            {
                if (child.LocalName == "_create_resource")
                {
                    dependencies["_create_resource"] = child.NodeId;
                }
            }
            break;
    }
    return dependencies;
}
/// <summary>
/// Loads all nodes and functions from the SavedModel and their edges.
/// </summary>
private void _load_all()
{
    // Order matters: objects must exist before edges can connect them, and
    // the checkpoint save/restore functions are attached to created nodes last.
    _load_nodes();
    _load_edges();
    _setup_remaining_functions();
    _load_checkpoint_save_and_restore_functions();
}
/// <summary>
/// Restores the checkpoint-related save/restore functions to all nodes.
/// </summary>
private void _load_checkpoint_save_and_restore_functions()
{
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        var trackable = get(node_id);
        // `Function`/`ConcreteFunction` restoration is currently skipped, so
        // the recreated node may be absent.
        if (trackable is null)
        {
            continue;
        }
        var saveable_objects = proto.SaveableObjects;
        if (saveable_objects.Keys.Count == 1 && saveable_objects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
        {
            // Restore Trackable serialize- and restore-from-tensor functions.
            Debug.Assert(saveable_objects.Count == 1);
            var serialized = saveable_objects.Values.First();
            var save_fn_id = serialized.SaveFunction;
            var restore_fn_id = serialized.RestoreFunction;
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
        }
        // Restore legacy SaveableObject functions: pair each attribute name
        // with its (save, restore) function nodes.
        Dictionary<string, (Trackable, Trackable)> fn_pairs = new();
        foreach (var entry in saveable_objects)
        {
            fn_pairs[entry.Key] = (get(entry.Value.SaveFunction), get(entry.Value.RestoreFunction));
        }
        trackable.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(fn_pairs, null);
    }
}
/// <summary>
/// Load all saved objects.
/// </summary>
private void _load_nodes()
{
    // `nodes` maps from node ids to recreated objects
    // `node_setters` maps from node ids to setter functions
    // (same signature as setattr) for setting children.
    var (nodes, node_setters) = _initialize_loaded_nodes();
    // Map from slot-variable node id to (owning optimizer node id, slot proto),
    // collected up front so slot variables can be recognized in the main pass.
    Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)>
        slot_variable_node_ids = new();
    foreach(var (node_id, proto) in _iter_all_nodes())
    {
        foreach(var slot_variable_proto in proto.SlotVariables)
        {
            var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
            slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto);
        }
    }
    // Re-create everything. `_iter_all_nodes` yields in dependency order, so
    // dependencies are already present in `nodes` when a node is recreated.
    foreach (var (node_id, proto) in _iter_all_nodes())
    {
        if (nodes.ContainsKey(node_id))
        {
            // Already supplied by the caller (e.g. via node filters); keep it.
            continue;
        }
        else if (slot_variable_node_ids.ContainsKey(node_id))
        {
            // Use the public Optimizer interface when creating slot variables.
            var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
            var optimizer_object = nodes[optimizer_node_id];
            var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];
            // TODO: implement it.
            throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
                " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        }
        else
        {
            // skip the function and concrete function.
            if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
            {
                nodes[node_id] = null;
                node_setters[node_id] = null;
                continue;
            }
            var (node, setter) = _recreate(proto, node_id, nodes);
            nodes[node_id] = node;
            node_setters[node_id] = setter;
        }
    }
    // Guarantee a root object even when node 0 was filtered out or absent.
    if (!nodes.ContainsKey(0))
    {
        nodes[0] = _recreate_base_user_object().Item1;
    }
    _nodes = new List<Trackable>();
    for(int i = 0; i < _proto.Nodes.Count; i++)
    {
        _nodes.Add(nodes[i]);
    }
    _node_setters = node_setters;
}
/// <summary>
/// Load state from checkpoint into the deserialized objects.
/// </summary>
private void _restore_checkpoint()
{
    var variables_path = SavedModelUtils.get_variables_path(_export_dir);
    var saver = new TrackableSaver(new ObjectGraphView(get(0)));
    // NOTE(review): the scope returned by `tf.device` is discarded here, so
    // it is unclear whether the CPU placement actually takes effect — confirm
    // against the device-scope API.
    tf.device("CPU");
    saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
    LoadStatus load_status;
    if (_save_options.allow_partial_checkpoint)
    {
        // Partial loads tolerate objects in the checkpoint that have no
        // matching object in the restored graph.
        load_status = saver.restore(variables_path, _checkpoint_options).expect_partial();
        load_status.assert_nontrivial_match();
    }
    else
    {
        load_status = saver.restore(variables_path, _checkpoint_options);
        load_status.assert_existing_objects_matched();
    }
    // NOTE(review): `as` yields null on a type mismatch, which would surface
    // below as a NullReferenceException; a direct cast would fail more clearly.
    var ckpt = (load_status as CheckpointLoadStatus).Checkpoint;
    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("The checkpoint restore has not supported graph mode. " +
            "Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
}
/// <summary>
/// Adds edges from objects to other objects and functions.
/// </summary>
private void _load_edges()
{
    foreach (var (node_id, object_proto) in _iter_all_nodes())
    {
        _add_object_graph_edges(object_proto, node_id);
    }
    // When the root is among the filtered nodes, the filtered children must be
    // re-attached to it; that path relies on Python-style hasattr/setattr
    // semantics and is not implemented yet.
    if (_filtered_nodes is null || !_filtered_nodes.Contains(0))
    {
        return;
    }
    var root = get(0);
    foreach (var node_path in _node_filters.Keys)
    {
        var loaded_node = _nodes[_node_path_to_id[node_path]];
        var path = node_path.Split('.');
        var current_node = root;
        foreach (var name in path.Skip(1).Take(path.Length - 2))
        {
            // `hasattr` and `setattr` is used here
            throw new NotImplementedException();
        }
        // `hasattr` and `setattr` is used here
        throw new NotImplementedException();
    }
}
/// <summary>
/// Intended to restore functions not handled during node creation; currently a no-op.
/// </summary>
private void _setup_remaining_functions()
{
    // TODO: implement it with concrete functions.
}
/// <summary>
/// Returns the recreated trackable at the given node id.
/// </summary>
public Trackable get(int node_id) => _nodes[node_id];
/// <summary>
/// Returns the recreated trackable addressed by a dotted node path.
/// </summary>
public Trackable get(string node_id) => get(_node_path_to_id[node_id]);
/// <summary>
/// Adds edges from an object to its children.
/// </summary>
/// <param name="proto">The saved object whose child references are wired up.</param>
/// <param name="node_id">Id of the object in the loaded node table.</param>
private void _add_object_graph_edges(SavedObject proto, int node_id)
{
    var obj = _nodes[node_id];
    // `Function`/`ConcreteFunction` restoration is skipped, so the node may be
    // null. The check is loop-invariant — bail out once instead of re-testing
    // it for every child as before.
    if (obj is null)
    {
        return;
    }
    var setter = _node_setters[node_id];
    foreach (var refer in proto.Children)
    {
        setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
        // skip the process of "__call__"
    }
}
/// <summary>
/// Seeds the node and setter tables from externally supplied (filtered) nodes.
/// </summary>
private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{
    Dictionary<int, Trackable> nodes = new();
    Dictionary<int, Action<object, object, object>> setters = new();
    foreach (var pair in _loaded_nodes)
    {
        nodes[pair.Key] = pair.Value.Item1;
        setters[pair.Key] = pair.Value.Item2;
    }
    return (nodes, setters);
}
/// <summary>
/// Yields (node id, proto) pairs in dependency order; evaluation stays deferred.
/// </summary>
private IEnumerable<(int, SavedObject)> _iter_all_nodes()
    => _ordered_node_ids.Select(node_id => (node_id, _proto.Nodes[node_id]));
/// <summary>
/// Recreates one object: resolves its dependency ids to already-built
/// trackables, then delegates to the default re-creation path.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{
    // skip the registered classes.
    var dependencies = _get_node_dependencies(proto)
        .ToDictionary(item => item.Key, item => nodes[item.Value]);
    return _recreate_default(proto, node_id, dependencies);
}
/// <summary>
/// Creates an object from a SavedObject protocol buffer.
/// </summary>
/// <param name="proto">Serialized object description.</param>
/// <param name="node_id">Id of the object in the loaded node table.</param>
/// <param name="dependencies">Already-created dependencies of the object.</param>
/// <returns>The recreated object and the setter used to attach its children.</returns>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies)
{
    return proto.KindCase switch
    {
        SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
        SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
        SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
        SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
        SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
        // The switch was previously non-exhaustive: any other kind (e.g.
        // Asset, Constant, Resource) surfaced as an opaque
        // SwitchExpressionException. Fail with a descriptive message instead.
        _ => throw new NotImplementedException($"Loading SavedObject of kind `{proto.KindCase}` is not supported yet.")
    };
}
/// <summary>
/// Recreates a user object, preferring a registered revived type when one exists.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id)
{
    // skip the check of proto identifier because of lack of property.
    var revived = RevivedTypes.deserialize(proto);
    if (revived is not null)
    {
        return (revived.Item1, revived.Item2);
    }
    return _recreate_base_user_object(proto, node_id);
}
/// <summary>
/// Fallback: wraps the object in a generic `_UserObject` with the default setter.
/// </summary>
private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null)
    => (new _UserObject(), setattr);
/// <summary>
/// Re-creates an uninitialized variable from its SavedVariable proto.
/// </summary>
/// <param name="proto">Serialized variable description.</param>
/// <returns>The restored variable and the setter used to attach its children.</returns>
private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto)
{
    string name = proto.Name;
    string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>";
    // TODO(Rinne): `validate_synchronization_aggregation_trainable`
    var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable(
        proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name);
    var saved_device = proto.Device;
    var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device);
    if (load_with_device)
    {
        // NOTE(review): the scope returned by `tf.device` is discarded, so the
        // saved device placement may not actually apply — confirm against the
        // device-scope API.
        tf.device(saved_device);
    }
    // Both branches previously constructed byte-identical variables; build it once.
    var variable = new UninitializedVariable(
        shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
        dtype: (TF_DataType)proto.Dtype,
        name: name,
        trainable: trainable,
        aggregation: aggregation);
    return (variable, setattr);
}
/// <summary>
/// Intended to rebuild a ConcreteFunction from its proto; not implemented yet.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
    Dictionary<Maybe<string, int>, Trackable> dependencies)
{
    throw new NotImplementedException();
    //var fn = function_deserialization.setup_bare_concrete_function(proto, )
}
// TODO: remove this to a common class.
// Poor man's Python `setattr`: assigns `z` to the public property of `x` named
// by `y`; silently does nothing when no such property exists.
public static Action<object, object, object> setattr = (x, y, z) =>
{
    Debug.Assert(y is string);
    var target = x.GetType().GetProperties().FirstOrDefault(p => p.Name == (string)y);
    if (target is not null)
    {
        target.SetValue(x, z);
    }
    // TODO(Rinne): check if the property has been set successfully.
    //throw new ValueError($"Cannot find the property {y} of {x}.");
};
/// <summary>
/// Placeholder trackable used when a saved user object has no registered revived type.
/// </summary>
public class _UserObject: AutoTrackable
{
}
| } | |||
| } | |||
| @@ -0,0 +1,122 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Loader | |||
| { | |||
/// <summary>
/// Reads the SavedModel proto (`saved_model.pb` or `saved_model.pbtxt`) from
/// <paramref name="export_dir"/>.
/// </summary>
/// <param name="export_dir">Directory containing the SavedModel.</param>
/// <returns>The parsed <see cref="SavedModel"/> proto.</returns>
/// <exception cref="IOException">Neither proto file exists in the directory.</exception>
public static SavedModel parse_saved_model(string export_dir)
{
    var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT);
    var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB);
    SavedModel saved_model = new SavedModel();
    if (File.Exists(path_to_pb))
    {
        // `File.ReadAllBytes` reliably reads the whole file; the previous
        // single `Stream.Read` call ignored its return value and could leave
        // the buffer partially filled on a short read.
        // TODO: change to stream mode.
        saved_model.MergeFrom(File.ReadAllBytes(path_to_pb));
        return saved_model;
    }
    else if (File.Exists(path_to_pbtxt))
    {
        // Text-format SavedModel protos are not supported yet.
        throw new NotImplementedException();
    }
    else
    {
        // `Path.DirectorySeparatorChar` is the path component separator;
        // `Path.PathSeparator` (used before) separates PATH-list entries (';').
        throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.DirectorySeparatorChar}" +
            $"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}");
    }
}
// TODO: revise the type of `tags`
/// <summary>
/// Loads a SavedModel and returns its root trackable object.
/// </summary>
public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null)
    => load_partial(export_dir, null, tags, options)["root"];
/// <summary>
/// Loads a SavedModel, optionally restricted to the objects named in
/// `filters`. Returns a map from filter path (or "root") to the loaded object.
/// </summary>
/// <param name="export_dir">Directory containing the SavedModel.</param>
/// <param name="filters">Optional node-path filters with pre-built nodes/setters.</param>
/// <param name="tags">Unsupported; must be null.</param>
/// <param name="options">Load options; defaults are used when null.</param>
public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null)
{
    if (options is null)
    {
        options = new LoadOptions();
    }
    if (tags is not null)
    {
        // Tag-based MetaGraph selection is not implemented.
        throw new NotImplementedException();
    }
    var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir);
    Trackable root = null;
    Loader loader = null;
    // A single MetaGraph with an object graph marks a TF2-style SavedModel.
    if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null)
    {
        // skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)`
        var meta_graph_def = saved_model_proto.MetaGraphs[0];
        // Tensor content is serialized little-endian; swap on big-endian hosts.
        if (!BitConverter.IsLittleEndian)
        {
            SavedModelUtils.swap_function_tensor_content(meta_graph_def);
        }
        var object_graph_proto = meta_graph_def.ObjectGraphDef;
        var ckpt_options = new CheckpointOptions(options.experimental_io_device);
        tf_with(ops.init_scope(), x =>
        {
            loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
            root = loader.get(0);
            // skip the assignment of `graph_debug_info`.
        });
        // skip the assignment of `tensorflow_version`
        // skip the assignment of `tensorflow_git_version`
        // skip the process of `metrics`.
    }
    else
    {
        // TF1/Estimator SavedModels cannot be partially loaded.
        if(filters is not null && filters.Count > 0)
        {
            throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
                + " version) cannot be loaded with node filters.");
        }
        tf_with(ops.init_scope(), x =>
        {
            throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
        });
    }
    if(filters != null && filters.Count > 0)
    {
        // Return each requested node keyed by its filter path.
        return filters.Keys.ToDictionary(x => x, x => loader.get(x));
    }
    else
    {
        var res = new Dictionary<string, Trackable>();
        res["root"] = root;
        return res;
    }
}
/// <summary>
/// Parses the SavedModel proto; debug info is not implemented yet and returned as null.
/// </summary>
public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir)
{
    // TODO: implement debug info.
    return (parse_saved_model(export_dir), null);
}
| } | |||
| } | |||
| @@ -6,7 +6,6 @@ using System.Text; | |||
| using Google.Protobuf; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.Binding; | |||
| @@ -1,7 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.ModelSaving; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| @@ -68,6 +68,34 @@ namespace Tensorflow | |||
| return saveables.ToArray(); | |||
| } | |||
/// <summary>
/// Converts each named tensor into `SaveableObject`s, rejecting duplicates.
/// </summary>
/// <param name="names_to_saveables">Map from checkpoint name to the tensor to save.</param>
/// <returns>The flattened saveable objects.</returns>
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables)
{
    var saveables = new List<MySaveableObject>();
    var seen_ops = new List<Tensor>();
    // NOTE(review): the `BaseResourceVariable` overload sorts by key for a
    // deterministic order while this one does not — confirm whether
    // enumeration order matters for checkpoint layout here.
    foreach (var (name, op) in enumerate(names_to_saveables))
    {
        foreach (var converted_saveable_object in saveable_objects_for_op(op, name))
            _add_saveable(saveables, seen_ops, converted_saveable_object);
    }
    return saveables.ToArray();
}
/// <summary>
/// Converts each named variable into `SaveableObject`s in deterministic
/// (key-sorted) order, rejecting duplicates.
/// </summary>
/// <param name="names_to_saveables">Map from checkpoint name to the variable to save.</param>
/// <returns>The flattened saveable objects.</returns>
public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables)
{
    List<MySaveableObject> saveables = new();
    List<BaseResourceVariable> seen = new();
    foreach (var pair in names_to_saveables.OrderBy(x => x.Key))
    {
        foreach (var saveable in saveable_objects_for_op(pair.Value, pair.Key))
        {
            _add_saveable(saveables, seen, saveable);
        }
    }
    return saveables.ToArray();
}
| private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject | |||
| { | |||
| if (seen_ops.Contains(saveable.op)) | |||
| @@ -77,6 +105,15 @@ namespace Tensorflow | |||
| seen_ops.Add(saveable.op); | |||
| } | |||
/// <summary>
/// Appends `saveable` unless its underlying variable was already registered.
/// </summary>
/// <exception cref="ValueError">The same variable appears under two names.</exception>
private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable)
{
    // NOTE(review): duplicates are detected via `saveable.variable` but the
    // error message reads `saveable.op.OriginalVar.Name` — verify both refer
    // to the same underlying variable.
    if (seen_ops.Contains(saveable.variable))
        throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}");
    saveables.Add(saveable);
    seen_ops.Add(saveable.variable);
}
| /// <summary> | |||
| /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. | |||
| /// </summary> | |||
| @@ -136,19 +173,20 @@ namespace Tensorflow | |||
| { | |||
| full_name = name + "_" + attr; | |||
| } | |||
| if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||
| var op = factory(full_name); | |||
| if(op.TryGet<BaseResourceVariable>(out var variable)) | |||
| { | |||
| foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
| foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
| { | |||
| yield return op; | |||
| yield return v; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| var saveable = factory.GetValue<MySaveableObject>(); | |||
| foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||
| var saveable = op.GetValue<MySaveableObject>(); | |||
| foreach (var v in saveable_objects_for_op(saveable, saveable.name)) | |||
| { | |||
| yield return op; | |||
| yield return v; | |||
| } | |||
| } | |||
| } | |||
| @@ -214,20 +252,19 @@ namespace Tensorflow | |||
| return names_to_saveables; | |||
| } | |||
| public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||
| public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj) | |||
| { | |||
| // skip the process of type `PythonState` | |||
| if (trackable_has_serialize_to_tensor(obj)) | |||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| { | |||
| var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; | |||
| // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. | |||
| var tensor_dict = obj.serialize_to_tensors(); | |||
| List<SaveSpec> specs = new(); | |||
| List<string> local_names = new(); | |||
| string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; | |||
| foreach(var pair in tensor_dict) | |||
| foreach (var pair in tensor_dict) | |||
| { | |||
| var tensor_name = pair.Key; | |||
| var maybe_tensor = pair.Value; | |||
| @@ -235,9 +272,9 @@ namespace Tensorflow | |||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
| IDictionary<string, Tensor> internal_dict; | |||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| if (maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| { | |||
| internal_dict= new Dictionary<string, Tensor>(); | |||
| internal_dict = new Dictionary<string, Tensor>(); | |||
| internal_dict[""] = tensor; | |||
| } | |||
| else | |||
| @@ -245,13 +282,18 @@ namespace Tensorflow | |||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
| } | |||
| foreach(var item in internal_dict) | |||
| foreach (var item in internal_dict) | |||
| { | |||
| specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
| } | |||
| } | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new(); | |||
| res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
| return new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
| } | |||
| if (trackable_has_serialize_to_tensor(obj)) | |||
| { | |||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; | |||
| return res; | |||
| } | |||
| else | |||
| @@ -333,6 +375,28 @@ namespace Tensorflow | |||
| return restored_ops; | |||
| }; | |||
| } | |||
/// <summary>
/// Returns a dict of SaveableObject factories generated from loaded fns.
/// </summary>
/// <param name="saveable_fn_by_name">Loaded (save, restore) function pairs per name.</param>
/// <param name="temp_session">Unused; reserved for session-based restoration.</param>
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
    IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{
    // Rebuilding factories from loaded save/restore functions is not
    // supported yet; only the empty case can proceed.
    if (saveable_fn_by_name.Count > 0)
    {
        throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }
    return new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
}
/// <summary>
/// Invokes the factory for `key`; `name` and `call_with_mapped_captures` are
/// currently unused.
/// </summary>
public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
    bool call_with_mapped_captures = false)
    => factory(key);
| } | |||
| public class SaveableCompatibilityConverter: Trackable | |||
| @@ -20,8 +20,8 @@ using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Train | |||
| @@ -41,9 +41,10 @@ namespace Tensorflow.Train | |||
| protected IDictionary<string, Trackable> _unconditional_dependency_names; | |||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||
| protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies; | |||
| protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||
| new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories = | |||
| new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| private bool _manual_tracking = true; | |||
| private static Trackable _none = new AutoTrackable(); | |||
| @@ -71,6 +72,18 @@ namespace Tensorflow.Train | |||
| public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } | |||
| public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } | |||
| public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } | |||
| public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies; | |||
| public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories | |||
| { | |||
| get | |||
| { | |||
| return _self_saveable_object_factories; | |||
| } | |||
| set | |||
| { | |||
| _self_saveable_object_factories = value; | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Restore-on-create for a variable be saved with this `Checkpointable`. | |||
| @@ -136,9 +149,11 @@ namespace Tensorflow.Train | |||
| _self_update_uid = -1; | |||
| _unconditional_checkpoint_dependencies = new List<TrackableReference>(); | |||
| _unconditional_dependency_names = new Dictionary<string, Trackable>(); | |||
| _unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>(); | |||
| } | |||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache) | |||
| public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, | |||
| IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
| { | |||
| _maybe_initialize_trackable(); | |||
| return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | |||
| @@ -174,10 +189,19 @@ namespace Tensorflow.Train | |||
| /// <param name="trackable"></param> | |||
| public virtual void _handle_deferred_dependencies(string name, Trackable trackable) | |||
| { | |||
| //_maybe_initialize_trackable(); | |||
| //trackable._maybe_initialize_trackable(); | |||
| // TODO: complete the implementation. | |||
| _maybe_initialize_trackable(); | |||
| trackable._maybe_initialize_trackable(); | |||
| if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies)) | |||
| { | |||
| _unconditional_deferred_dependencies.Remove(name); | |||
| foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid)) | |||
| { | |||
| checkpoint_position.restore(trackable); | |||
| } | |||
| } | |||
| // TODO(Rinne): deal with `_self_name_based_restores` | |||
| } | |||
| public virtual Trackable? _lookup_dependency(string name) | |||
| @@ -225,12 +249,19 @@ namespace Tensorflow.Train | |||
| return self_tensor_map.Keys.ToList(); | |||
| } | |||
| public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| { | |||
| Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "") | |||
| { | |||
| throw new NotImplementedException(); | |||
| //return new TrackableSaveable(this, null, name, null, null); | |||
| } | |||
| if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||
| { | |||
| // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). | |||
| throw new NotImplementedException(); | |||
| Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new(); | |||
| res[""] = create_saveable; | |||
| return res; | |||
| } | |||
| else | |||
| { | |||
| @@ -259,4 +290,6 @@ namespace Tensorflow.Train | |||
| } | |||
| public record class TrackableReference(string Name, Trackable Refer); | |||
| public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName); | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.Train; | |||
| @@ -20,9 +21,9 @@ public static class TrackableUtils | |||
| LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); | |||
| } | |||
| } | |||
| private static string _ESCAPE_CHAR = "."; | |||
| private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
| private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
| internal static string _ESCAPE_CHAR = "."; | |||
| internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; | |||
| internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; | |||
| internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; | |||
| public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) | |||
| { | |||
| @@ -5,9 +5,9 @@ using Tensorflow.Variables; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.ModelSaving; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -19,7 +19,11 @@ namespace Tensorflow | |||
| protected TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| protected string _handle_name; | |||
| protected string handle_name => _handle_name; | |||
/// <summary>
/// Name of the underlying resource handle. The setter exists so restore
/// paths (e.g. SavedModel loading) can overwrite the recorded name.
/// </summary>
public string handle_name
{
    get => _handle_name;
    set => _handle_name = value;
}
| protected string _unique_id; | |||
| public string UniqueId => _unique_id; | |||
| @@ -289,10 +293,10 @@ namespace Tensorflow | |||
| resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | |||
| } | |||
| public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint() | |||
| { | |||
| var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||
| var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; | |||
| return res; | |||
| } | |||
| @@ -238,5 +238,23 @@ namespace Tensorflow | |||
| { | |||
| return _graph_element.eval(session); | |||
| } | |||
/// <summary>
/// Resolves defaults for variable-creation arguments, mirroring Python TF's
/// `validate_synchronization_aggregation_trainable`.
/// </summary>
/// <param name="synchronization">Requested synchronization; defaults to <c>Auto</c> when null.</param>
/// <param name="aggregation">Requested aggregation; defaults to <c>None</c> when null.</param>
/// <param name="trainable">Whether the variable is trainable; when null it is inferred
/// from the synchronization (an OnRead-synchronized variable is never trainable).</param>
/// <param name="name">Variable name; currently unused, kept for parity with the Python API.</param>
/// <returns>The resolved (synchronization, aggregation, trainable) triple.</returns>
public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable(
    VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name)
{
    // Null-coalescing assignment keeps the original fall-back semantics compactly.
    aggregation ??= VariableAggregation.None;
    synchronization ??= VariableSynchronization.Auto;
    trainable ??= synchronization != VariableSynchronization.OnRead;
    return (synchronization.Value, aggregation.Value, trainable.Value);
}
| } | |||
| } | |||
| @@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| /// <param name="config"></param> | |||
| /// <returns></returns> | |||
| static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||
| { | |||
| // Layer instances created during the graph reconstruction process. | |||
| var created_layers = new Dictionary<string, ILayer>(); | |||
| created_layers = created_layers ?? new Dictionary<string, ILayer>(); | |||
| var node_index_map = new Dictionary<(string, int), int>(); | |||
| var node_count_by_layer = new Dictionary<ILayer, int>(); | |||
| var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | |||
| @@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine | |||
| layer = created_layers[layer_name]; | |||
| else | |||
| { | |||
| layer = layer_data.ClassName switch | |||
| { | |||
| "InputLayer" => InputLayer.from_config(layer_data.Config), | |||
| "Dense" => Dense.from_config(layer_data.Config), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); | |||
| created_layers[layer_name] = layer; | |||
| } | |||
| @@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine | |||
| Inputs = inputs, | |||
| Outputs = outputs | |||
| }) | |||
| { | |||
| Initialize(inputs, outputs, name); | |||
| } | |||
| internal void Initialize(Tensors inputs, Tensors outputs, string name = null) | |||
| { | |||
| _input_layers = new List<ILayer>(); | |||
| _output_layers = new List<ILayer>(); | |||
| @@ -70,7 +75,14 @@ namespace Tensorflow.Keras.Engine | |||
| this.inputs = inputs; | |||
| this.outputs = outputs; | |||
| built = true; | |||
| _buildInputShape = inputs.shape; | |||
| if(inputs.Length > 0) | |||
| { | |||
| _buildInputShape = inputs.shape; | |||
| } | |||
| else | |||
| { | |||
| _buildInputShape = new Saving.TensorShapeConfig(); | |||
| } | |||
| if (outputs.Any(x => x.KerasHistory == null)) | |||
| base_layer_utils.create_keras_history(outputs); | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine | |||
| public virtual Shape ComputeOutputShape(Shape input_shape) | |||
| => throw new NotImplementedException(""); | |||
/// <summary>
/// Collects variables from the immediate child layers, filtered by trainability.
/// </summary>
/// <param name="include_trainable">Include the children's trainable variables.</param>
/// <param name="include_non_trainable">Include the children's non-trainable variables.</param>
/// <returns>The gathered child variables; empty when both flags are false. Never null.</returns>
protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false)
{
    var collected = new List<IVariableV1>();
    // Non-recursive flatten: only direct sub-layers are inspected here.
    foreach (var child in _flatten_layers(false, false))
    {
        if (child is not Layer childLayer)
        {
            continue;
        }
        if (include_trainable && include_non_trainable)
        {
            collected.AddRange(childLayer.Variables);
        }
        else if (include_trainable)
        {
            collected.AddRange(childLayer.TrainableVariables);
        }
        else if (include_non_trainable)
        {
            collected.AddRange(childLayer.NonTrainableVariables);
        }
    }
    return collected;
}
| } | |||
| } | |||
| @@ -12,7 +12,7 @@ public abstract partial class Layer | |||
| public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
| public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||
| public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; | |||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
| { | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| @@ -66,16 +67,74 @@ namespace Tensorflow.Keras.Engine | |||
| public bool SupportsMasking { get; set; } | |||
| protected List<IVariableV1> _trainable_weights; | |||
| public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||
| public virtual List<IVariableV1> TrainableVariables => TrainableWeights; | |||
| protected List<IVariableV1> _non_trainable_weights; | |||
| public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | |||
| public List<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||
| public List<IVariableV1> Variables => Weights; | |||
/// <summary>
/// Trainable weights of this layer and its children; empty when the layer is frozen.
/// </summary>
public virtual List<IVariableV1> TrainableWeights
{
    get
    {
        if (!Trainable)
        {
            return new List<IVariableV1>();
        }
        // Distinct() guards against a child sharing a variable this layer already tracks.
        return _gather_children_variables(include_trainable: true)
            .Concat(_trainable_weights)
            .Distinct()
            .ToList();
    }
}
/// <summary>
/// Non-trainable weights of this layer and its children. When the layer itself is
/// frozen, every weight (including nominally trainable ones) is reported here.
/// </summary>
public virtual List<IVariableV1> NonTrainableWeights
{
    get
    {
        if (Trainable)
        {
            return _gather_children_variables(include_non_trainable: true)
                .Concat(_non_trainable_weights)
                .Distinct()
                .ToList();
        }
        // Frozen layer: all of its own and its children's weights count as non-trainable.
        return _gather_children_variables(true, true)
            .Concat(_trainable_weights)
            .Concat(_non_trainable_weights)
            .Distinct()
            .ToList();
    }
}
/// <summary>
/// All weights of the layer: trainable weights followed by non-trainable ones.
/// The setter assigns the given values element-wise onto the existing variables.
/// </summary>
/// <exception cref="ValueError">Thrown when the assigned list's length differs from the layer's weight count.</exception>
public virtual List<IVariableV1> Weights
{
    get
    {
        return TrainableWeights.Concat(NonTrainableWeights).ToList();
    }
    set
    {
        // Snapshot once: the getter recomputes and reallocates the list on every access.
        var existing = Weights;
        // Note: original message lacked a space after the layer name; fixed here.
        if (existing.Count != value.Count) throw new ValueError(
            $"You called `set_weights` on layer \"{this.name}\" " +
            $"with a weight list of length {len(value)}, but the layer was " +
            $"expecting {len(existing)} weights.");
        foreach (var (this_w, v_w) in zip(existing, value))
            this_w.assign(v_w, read_value: true);
    }
}
| protected int id; | |||
| public int Id => id; | |||
| protected string name; | |||
| protected string base_name; | |||
| public string Name => name; | |||
/// <summary>
/// Layer name. Settable so that loaders can restore the serialized name.
/// </summary>
public string Name
{
    get => name;
    set => name = value;
}
| protected bool computePreviousMask; | |||
| protected List<Operation> updates; | |||
| @@ -85,10 +144,11 @@ namespace Tensorflow.Keras.Engine | |||
| List<INode> inboundNodes; | |||
| public List<INode> InboundNodes => inboundNodes; | |||
| List<INode> outboundNodes; | |||
| public List<INode> OutboundNodes => outboundNodes; | |||
| public JObject SerializedAttributes { get; set; } | |||
| ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>(); | |||
| public CallContext CallContext => callContext.Value; | |||
| public Tensor[] input | |||
| @@ -117,6 +177,11 @@ namespace Tensorflow.Keras.Engine | |||
| protected List<ILayer> _self_tracked_trackables; | |||
| public Layer(LayerArgs args) | |||
| { | |||
| Initialize(args); | |||
| } | |||
| internal virtual void Initialize(LayerArgs args) | |||
| { | |||
| this.args = args; | |||
| // A stateful layer is a layer whose updates are run during inference too, | |||
| @@ -273,46 +338,9 @@ namespace Tensorflow.Keras.Engine | |||
| public int count_params() | |||
| { | |||
| if (Trainable) | |||
| return layer_utils.count_params(this, weights); | |||
| return layer_utils.count_params(this, Weights); | |||
| return 0; | |||
| } | |||
| List<IVariableV1> ILayer.TrainableWeights | |||
| { | |||
| get | |||
| { | |||
| return _trainable_weights; | |||
| } | |||
| } | |||
| List<IVariableV1> ILayer.NonTrainableWeights | |||
| { | |||
| get | |||
| { | |||
| return _non_trainable_weights; | |||
| } | |||
| } | |||
| public List<IVariableV1> weights | |||
| { | |||
| get | |||
| { | |||
| var weights = new List<IVariableV1>(); | |||
| weights.AddRange(_trainable_weights); | |||
| weights.AddRange(_non_trainable_weights); | |||
| return weights; | |||
| } | |||
| set | |||
| { | |||
| if (weights.Count() != value.Count()) throw new ValueError( | |||
| $"You called `set_weights` on layer \"{this.name}\"" + | |||
| $"with a weight list of length {len(value)}, but the layer was " + | |||
| $"expecting {len(weights)} weights."); | |||
| foreach (var (this_w, v_w) in zip(weights, value)) | |||
| this_w.assign(v_w, read_value: true); | |||
| } | |||
| } | |||
| public List<IVariableV1> Variables => weights; | |||
| public virtual IKerasConfig get_config() | |||
| => args; | |||
| @@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| using (SharedObjectSavingScope.Enter()) | |||
| { | |||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
| KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine | |||
| IVariableV1 _predict_counter; | |||
| bool _base_model_initialized; | |||
| bool stop_training; | |||
| public bool IsGraphNetwork => _is_graph_network; | |||
| public OptimizerV2 Optimizer | |||
| { | |||
| @@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine | |||
| _init_batch_counters(); | |||
| } | |||
/// <summary>
/// Re-initializes the model from layer args (used by the SavedModel loading path).
/// </summary>
/// <param name="args">Layer construction arguments forwarded to the base initializer.</param>
internal override void Initialize(LayerArgs args)
{
    // NOTE(review): batch counters are created before base initialization runs —
    // presumably so they exist while the base sets up tracking; confirm the ordering
    // is intentional (the regular constructor calls _init_batch_counters afterwards).
    _init_batch_counters();
    base.Initialize(args);
}
| void _configure_steps_per_execution(int steps_per_execution) | |||
| { | |||
| _steps_per_execution = tf.Variable(steps_per_execution, | |||
| @@ -81,10 +89,11 @@ namespace Tensorflow.Keras.Engine | |||
| public override List<ILayer> Layers | |||
| => _flatten_layers(recursive: false, include_self: false).ToList(); | |||
| public override List<IVariableV1> TrainableVariables | |||
| public override List<IVariableV1> TrainableWeights | |||
| { | |||
| get | |||
| { | |||
| // skip the assertion of weights created. | |||
| var variables = new List<IVariableV1>(); | |||
| if (!Trainable) | |||
| @@ -95,18 +104,40 @@ namespace Tensorflow.Keras.Engine | |||
| foreach (var trackable_obj in _self_tracked_trackables) | |||
| { | |||
| if (trackable_obj.Trainable) | |||
| variables.AddRange(trackable_obj.TrainableVariables); | |||
| variables.AddRange(trackable_obj.TrainableWeights); | |||
| } | |||
| foreach (var layer in _self_tracked_trackables) | |||
| variables.AddRange(_trainable_weights); | |||
| return variables.Distinct().ToList(); | |||
| } | |||
| } | |||
| public override List<IVariableV1> NonTrainableWeights | |||
| { | |||
| get | |||
| { | |||
| // skip the assertion of weights created. | |||
| var variables = new List<IVariableV1>(); | |||
| foreach (var trackable_obj in _self_tracked_trackables) | |||
| { | |||
| if (layer.Trainable) | |||
| variables.AddRange(layer.TrainableVariables); | |||
| variables.AddRange(trackable_obj.NonTrainableWeights); | |||
| } | |||
| // variables.AddRange(_trainable_weights); | |||
| if (!Trainable) | |||
| { | |||
| var trainable_variables = new List<IVariableV1>(); | |||
| foreach (var trackable_obj in _self_tracked_trackables) | |||
| { | |||
| variables.AddRange(trackable_obj.TrainableWeights); | |||
| } | |||
| variables.AddRange(trainable_variables); | |||
| variables.AddRange(_trainable_weights); | |||
| variables.AddRange(_non_trainable_weights); | |||
| } | |||
| return variables; | |||
| return variables.Distinct().ToList(); | |||
| } | |||
| } | |||
| @@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine | |||
| : base(args.Inputs, args.Outputs, name: args.Name) | |||
| { | |||
| this.args = args; | |||
| if (args.Layers == null) | |||
| args.Layers = new List<ILayer>(); | |||
| // SupportsMasking = true; | |||
| _compute_output_and_mask_jointly = true; | |||
| _auto_track_sub_layers = false; | |||
| @@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine | |||
| _created_nodes = new List<INode>(); | |||
| // Add to the model any layers passed to the constructor. | |||
| if (args.Layers != null) | |||
| if (args.Layers is not null) | |||
| { | |||
| foreach (var layer in args.Layers) | |||
| add(layer); | |||
| InitLayers(args.Layers); | |||
| } | |||
| } | |||
/// <summary>
/// Appends the given layers to this Sequential model in order.
/// </summary>
/// <param name="layers">Layers to add, first to last.</param>
public void InitLayers(IEnumerable<ILayer> layers)
{
    foreach (var item in layers)
    {
        add(item);
    }
}
| @@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||
| { | |||
| throw new ValueError("Alpha must be a number greater than 0."); | |||
| } | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| @@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||
| } | |||
| public override void build(Shape input_shape) | |||
| { | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
| { | |||
| @@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers { | |||
| if ( alpha < 0f ) { | |||
| throw new ValueError("Alpha must be a number greater than 0."); | |||
| } | |||
| _buildInputShape = input_shape; | |||
| built = true; | |||
| base.build(input_shape); | |||
| } | |||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
| Tensor output = inputs; | |||
| @@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers | |||
| return outputs; | |||
| } | |||
| public static Dense from_config(LayerArgs args) | |||
| { | |||
| return new Dense(args as DenseArgs); | |||
| } | |||
| } | |||
| } | |||
| @@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers | |||
| name: Name); | |||
| } | |||
| public static InputLayer from_config(LayerArgs args) | |||
| { | |||
| return new InputLayer(args as InputLayerArgs); | |||
| } | |||
| public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | |||
| } | |||
| } | |||
| @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics | |||
| public virtual void reset_states() | |||
| { | |||
| foreach (var v in weights) | |||
| foreach (var v in Weights) | |||
| v.assign(0); | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.IO; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| namespace Tensorflow.Keras.Models | |||
| @@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models | |||
| public Functional from_config(ModelConfig config) | |||
| => Functional.from_config(config); | |||
| public void load_model(string filepath, bool compile = true) | |||
| public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||
| { | |||
| var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); | |||
| var saved_mode = SavedModel.Parser.ParseFrom(bytes); | |||
| var meta_graph_def = saved_mode.MetaGraphs[0]; | |||
| var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
| bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); | |||
| var metadata = SavedMetadata.Parser.ParseFrom(bytes); | |||
| // Recreate layers and metrics using the info stored in the metadata. | |||
| var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
| keras_loader.load_layers(compile: compile); | |||
| return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,12 +1,24 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using System.Text.RegularExpressions; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Layers.Rnn; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.ApiDef.Types; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| @@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| public class KerasObjectLoader | |||
| { | |||
| SavedMetadata _metadata; | |||
| SavedObjectGraph _proto; | |||
| Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
| Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>(); | |||
| List<int> _traversed_nodes_from_config = new List<int>(); | |||
| private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; | |||
| private SavedMetadata _metadata; | |||
| private SavedObjectGraph _proto; | |||
| private Dictionary<int, string> _node_paths = new Dictionary<int, string>(); | |||
| private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>(); | |||
| private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>(); | |||
| private List<int> _traversed_nodes_from_config = new List<int>(); | |||
| private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes; | |||
| private List<int> _models_to_reconstruct; | |||
| public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes; | |||
// Ensures the Keras-specific attribute key is present in PUBLIC_ATTRIBUTES.
// The value is intentionally null: no trackable is associated with it up front.
static KerasObjectLoader()
{
    PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null;
}
/// <summary>
/// Creates a loader over the saved Keras metadata and the SavedModel object graph,
/// indexing each node's serialized path by its node id.
/// </summary>
/// <param name="metadata">Parsed keras_metadata.pb contents.</param>
/// <param name="object_graph_def">The SavedModel object graph.</param>
public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
{
    _metadata = metadata;
    _proto = object_graph_def;
    foreach (var node in metadata.Nodes)
    {
        _node_paths[node.NodeId] = node.NodePath;
    }
    _models_to_reconstruct = new List<int>();
    loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
}
| /// <summary> | |||
| @@ -42,15 +66,255 @@ namespace Tensorflow.Keras.Saving | |||
| continue; | |||
| } | |||
| _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); | |||
| } | |||
| foreach(var node_metadata in metric_list) | |||
| { | |||
| try | |||
| { | |||
| if (node_metadata.Identifier.Equals("_tf_keras_metric")) | |||
| { | |||
| continue; | |||
| } | |||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | |||
| node_metadata.Metadata); | |||
| } | |||
| catch(ValueError e) | |||
| { | |||
| if (compile) | |||
| { | |||
| throw e; | |||
| } | |||
| // TODO: add logging.warning. | |||
| } | |||
| } | |||
| } | |||
/// <summary>
/// Returns the serialized object path recorded for the given node id.
/// </summary>
/// <param name="node_id">Id of the node in the object graph.</param>
public string get_path(int node_id) => _node_paths[node_id];
/// <summary>
/// Finish setting up Keras objects.
///
/// Executed after all objects and functions have been created. Unblocks model
/// reconstruction for every loaded layer, finalizes revived layers, and then
/// rebuilds the graph networks. Subclassed models revived from the SavedModel
/// are treated like layers.
/// </summary>
public void finalize_objects()
{
    var revivedFromConfig = new List<Layer>();
    var revivedFromSavedModel = new List<Layer>();
    foreach (var entry in loaded_nodes)
    {
        // Only layers that are not themselves dependency-tracked models are handled here.
        if (entry.Value.Item1 is not Layer layer || model_layer_ids_dependencies.ContainsKey(entry.Key))
        {
            continue;
        }
        _unblock_model_reconstruction(entry.Key, layer);
        if (layer is InputLayer || layer is Metric)
        {
            continue;
        }
        // TODO: deal with `RevivedLayer` and `RevivedInputLayer`.
        revivedFromConfig.Add(layer);
    }
    _finalize_saved_model_layers(revivedFromSavedModel);
    _finalize_config_layers(revivedFromConfig);
    _reconstruct_all_models();
}
/// <summary>
/// Reconstructs every model whose child layers have all been revived,
/// walking the pending list from newest to oldest.
/// </summary>
private void _reconstruct_all_models()
{
    HashSet<int> all_initialized_models = new();
    // NOTE(review): _reconstruct_model can append to _models_to_reconstruct (via
    // _unblock_model_reconstruction); entries added during this backwards walk are
    // never visited by this loop — confirm that is the intended behavior.
    for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--)
    {
        int model_id = _models_to_reconstruct[i];
        all_initialized_models.Add(model_id);
        var (model, layers) = model_layer_dependencies[model_id];
        _reconstruct_model(model_id, model, layers.ToList());
        // Apply the config-layer finalization steps to the freshly rebuilt model.
        _finalize_config_layers(new List<Layer>() { model });
    }
    // Sanity check: every dependency-tracked model should have been initialized.
    Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys));
}
/// <summary>
/// Rebuilds a revived model from its serialized config and already-revived layers:
/// Sequential models get their layers re-attached and are built from inferred input
/// shapes; Functional models are reconstructed from the full graph config.
/// </summary>
/// <param name="model_id">Node id of the model in the object graph.</param>
/// <param name="model">The (so far skeletal) revived model instance.</param>
/// <param name="layers">The model's revived child layers, in dependency order.</param>
private void _reconstruct_model(int model_id, Model model, List<Layer> layers)
{
    var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"];
    if(model.input is not null && model.input.Length > 0)
    {
        // Inputs already present: nothing to reconstruct for this model.
    }
    else if(model is Sequential s)
    {
        // Ensure the layer stack starts with an InputLayer when the config recorded one.
        if(layers is null || layers.Count == 0 || layers[0] is not InputLayer)
        {
            if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
            {
                layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
            }
            else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
            {
                // TODO(Rinne): implement it
            }
        }
        // `model.__init__(layers, config["name"])`
        s.InitLayers(layers);
        s.Name = config["name"].ToObject<string>();
        if(s.input is null || s.input.Length == 0)
        {
            var first_layer = _get_child_layer_node_ids(model_id)[0];
            var input_specs = _infer_inputs(first_layer);
            var input_shapes = _infer_inputs(first_layer, true);
            // `model._set_inputs(input_specs)`
            // skip the check of input_specs is Dictionary
            if (!s.Built)
            {
                s.build(input_shapes);
            }
        }
    }
    else
    {
        // Functional model: rebuild the graph from config, reusing revived layers by name.
        // skip the parameter `created_layers`.
        var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config),
            layers.ToDictionary(x => x.Name, x => x as ILayer));
        // skip the `model.__init__`
        (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
        (model as Functional).connect_ancillary_layers(created_layers);
    }
    _set_network_attributes_from_metadata(model);
    // Mark this model as revived so parents waiting on it can be reconstructed.
    _unblock_model_reconstruction(model_id, model);
}
/// <summary>
/// Restores network-level attributes on the revived model from its metadata.
/// Not yet implemented; currently a no-op.
/// </summary>
/// <param name="revived_object">The model to update.</param>
private void _set_network_attributes_from_metadata(Model revived_object)
{
    // TODO: implement it.
}
/// <summary>
/// Runs the final steps of loading Keras Layers revived from config:
/// re-attaches losses and metrics to each layer.
/// </summary>
/// <param name="layers">Layers revived from their serialized config.</param>
private void _finalize_config_layers(List<Layer> layers)
{
    foreach (var revived in layers)
    {
        // Unconditional losses are only restored for graph networks.
        if (_is_graph_network(revived))
        {
            _restore_layer_unconditional_losses(revived);
        }
        _restore_layer_activation_loss(revived);
        _restore_layer_metrics(revived);
        // TODO(Rinne): deal with RNN.
    }
}
/// <summary>
/// Runs the final steps of loading Keras Layers revived directly from the SavedModel:
/// re-attaches losses and metrics to each layer.
/// </summary>
/// <param name="layers">Layers revived from SavedModel functions.</param>
private void _finalize_saved_model_layers(List<Layer> layers)
{
    foreach (var revived in layers)
    {
        // TODO(Rinne): deal with `RevivedNetwork`.
        _restore_layer_unconditional_losses(revived);
        _restore_layer_activation_loss(revived);
        _restore_layer_metrics(revived);
    }
}
/// <summary>Restores the layer's saved unconditional losses. Not yet implemented; no-op.</summary>
private void _restore_layer_unconditional_losses(Layer layer)
{
    // TODO(Rinne): implement it.
}
/// <summary>Restores the layer's saved activation loss. Not yet implemented; no-op.</summary>
private void _restore_layer_activation_loss(Layer layer)
{
    // TODO(Rinne): implement it.
}
/// <summary>Restores the layer's saved metrics. Not yet implemented; no-op.</summary>
private void _restore_layer_metrics(Layer layer)
{
    // TODO(Rinne): implement it.
}
/// <summary>
/// Removes a revived layer from blocking model reconstruction: records it in the slot
/// of every model that depends on it, and queues any model whose slots are all filled.
/// </summary>
/// <param name="layer_id">Node id of the revived layer.</param>
/// <param name="layer">The revived layer instance.</param>
private void _unblock_model_reconstruction(int layer_id, Layer layer)
{
    foreach (var dependency in model_layer_ids_dependencies)
    {
        var (model, childIds) = dependency.Value;
        // Always ensure a slot array exists for this model, even when the layer is
        // unrelated (preserves the original side effect of touching the map first).
        var slots = model_layer_dependencies.SetDefault(dependency.Key,
            (model, new Layer[childIds.Length])).Item2;
        if (!childIds.Contains(layer_id))
        {
            continue;
        }
        slots[Array.IndexOf(childIds, layer_id)] = layer;
        if (slots.All(x => x is not null))
        {
            _models_to_reconstruct.Add(dependency.Key);
        }
    }
}
| void _load_layer(int node_id, string identifier, string metadata_json) | |||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | |||
| { | |||
| metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||
| _revive_from_config(identifier, metadata, node_id); | |||
| if (loaded_nodes.ContainsKey(node_id)) | |||
| { | |||
| var (node, setter) = loaded_nodes[node_id]; | |||
| _maybe_add_serialized_attributes(node as Layer, metadata); | |||
| var config = metadata.Config; | |||
| if(_is_graph_network(node as Layer) && generic_utils.validate_config(config)) | |||
| { | |||
| Debug.Assert(node is Model); | |||
| var child_nodes = _get_child_layer_node_ids(node_id); | |||
| model_layer_ids_dependencies[node_id] = (node as Model, child_nodes); | |||
| if(child_nodes is null || child_nodes.Length == 0) | |||
| { | |||
| _models_to_reconstruct.Add(node_id); | |||
| } | |||
| } | |||
| return (node, setter); | |||
| } | |||
| else | |||
| { | |||
| var (obj, setter) = _revive_from_config(identifier, metadata, node_id); | |||
| if (obj is null) | |||
| { | |||
| (obj, setter) = _revive_custom_object(identifier, metadata); | |||
| } | |||
| Debug.Assert(obj is Layer); | |||
| _maybe_add_serialized_attributes(obj as Layer, metadata); | |||
| return (obj, setter); | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -59,11 +323,34 @@ namespace Tensorflow.Keras.Saving | |||
| /// <param name="identifier"></param> | |||
| /// <param name="metadata"></param> | |||
| /// <param name="node_id"></param> | |||
| void _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||
| private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) | |||
| { | |||
| var obj = _revive_graph_network(identifier, metadata, node_id); | |||
| obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||
| Trackable obj; | |||
| if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| return (null, null); | |||
| //throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| } | |||
| else | |||
| { | |||
| obj = _revive_graph_network(identifier, metadata, node_id); | |||
| obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id); | |||
| } | |||
| if(obj is null) | |||
| { | |||
| return (null, null); | |||
| } | |||
| var setter = _config_node_setter(_revive_setter); | |||
| _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id); | |||
| return (obj, setter); | |||
| } | |||
/// <summary>
/// Revives an object that cannot be rebuilt from a plain Keras config
/// (custom objects). Not yet implemented.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
{
    // TODO(Rinne): implement it.
    throw new NotImplementedException();
}
| Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id) | |||
| @@ -71,6 +358,12 @@ namespace Tensorflow.Keras.Saving | |||
| var config = metadata.Config; | |||
| var class_name = metadata.ClassName; | |||
| Model model = null; | |||
| if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional") | |||
| { | |||
| return null; | |||
| } | |||
| if (class_name == "Sequential") | |||
| { | |||
| model = new Sequential(new SequentialArgs | |||
| @@ -78,34 +371,82 @@ namespace Tensorflow.Keras.Saving | |||
| Name = config.GetValue("name").ToString() | |||
| }); | |||
| } | |||
| else if (class_name == "Functional") | |||
| else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| model = new Sequential(new SequentialArgs | |||
| { | |||
| Name = class_name | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>()); | |||
| } | |||
| if (!metadata.IsGraphNetwork) | |||
| return null; | |||
| // Record this model and its layers. This will later be used to reconstruct | |||
| // the model. | |||
| var layers = _get_child_layer_node_ids(node_id); | |||
| model_layer_dependencies[node_id] = (model, layers); | |||
| model_layer_ids_dependencies[node_id] = (model, layers); | |||
| if(layers is null || layers.Length == 0) | |||
| { | |||
| _models_to_reconstruct.Add(node_id); | |||
| } | |||
| return model; | |||
| } | |||
| Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||
| Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id) | |||
| { | |||
| var config = metadata.Config; | |||
| var class_name = metadata.ClassName; | |||
| var shared_object_id = metadata.SharedObjectId; | |||
| var must_restore_from_config = metadata.MustRestoreFromConfig; | |||
| var obj = class_name switch | |||
| { | |||
| "Resizing" => Resizing.from_config(config), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| var obj = generic_utils.deserialize_keras_object(class_name, config); | |||
| obj.Name = metadata.Name; | |||
| // TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec` | |||
| var built = _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||
| return null; | |||
| if (!built) | |||
| { | |||
| return null; | |||
| } | |||
| return obj; | |||
| } | |||
| private void _revive_setter(object layer, object name, object value) | |||
| { | |||
| Debug.Assert(name is string); | |||
| Debug.Assert(layer is Layer); | |||
| if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) | |||
| { | |||
| if(value is Trackable) | |||
| { | |||
| (layer as Layer)._track_trackable(value as Trackable, name as string); | |||
| } | |||
| if((layer as Layer).SerializedAttributes is null) | |||
| { | |||
| (layer as Layer).SerializedAttributes = new JObject(); | |||
| } | |||
| (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); | |||
| } | |||
| else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) | |||
| { | |||
| (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); | |||
| } | |||
| else | |||
| { | |||
| var properties = layer.GetType().GetProperties(); | |||
| foreach(var p in properties) | |||
| { | |||
| if(p.Name == name as string && p.GetValue(layer) is not null) | |||
| { | |||
| return; | |||
| } | |||
| } | |||
| Loader.setattr(layer, name, value); | |||
| } | |||
| } | |||
| /// <summary> | |||
| @@ -143,34 +484,186 @@ namespace Tensorflow.Keras.Saving | |||
| /// <param name="obj"></param> | |||
| /// <param name="proto"></param> | |||
| /// <param name="node_id"></param> | |||
| void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id) | |||
| void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id) | |||
| { | |||
| if (_traversed_nodes_from_config.Contains(node_id)) | |||
| return; | |||
| var parent_path = _node_paths[node_id]; | |||
| _traversed_nodes_from_config.Add(node_id); | |||
| if (!obj.Built) | |||
| obj._maybe_initialize_trackable(); | |||
| if(obj is Layer layer && !layer.Built) | |||
| { | |||
| var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||
| _try_build_layer(obj, node_id, metadata.BuildInputShape); | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata); | |||
| _try_build_layer(layer, node_id, metadata.BuildInputShape); | |||
| } | |||
| List<(Trackable, int, string)> children = new(); | |||
| foreach(var refer in proto.Children) | |||
| { | |||
| var obj_child = obj._lookup_dependency(refer.LocalName); | |||
| children.Add((obj_child, refer.NodeId, refer.LocalName)); | |||
| } | |||
| var metric_list_node_id = _search_for_child_node(node_id, new string[] { | |||
| Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" | |||
| }); | |||
| if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) | |||
| { | |||
| var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x); | |||
| foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children) | |||
| { | |||
| if (obj_metrics.TryGetValue(refer.LocalName, out var metric)) | |||
| { | |||
| var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}"; | |||
| children.Add((metric as Metric, refer.NodeId, metric_path)); | |||
| } | |||
| } | |||
| } | |||
| foreach(var (obj_child, child_id, child_name) in children) | |||
| { | |||
| if(obj_child is null) | |||
| { | |||
| continue; | |||
| } | |||
| var child_proto = _proto.Nodes[child_id]; | |||
| // skip the check for registered identifier | |||
| Action<object, object, object> setter; | |||
| if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) | |||
| { | |||
| setter = _revive_setter; | |||
| } | |||
| else | |||
| { | |||
| setter = Loader.setattr; | |||
| } | |||
| if (loaded_nodes.ContainsKey(child_id)) | |||
| { | |||
| // skip the logging.warning | |||
| continue; | |||
| } | |||
| if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name)) | |||
| { | |||
| (obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0"; | |||
| } | |||
| if(obj_child is TrackableDataStructure) | |||
| { | |||
| setter = (x, y, z) => { }; | |||
| } | |||
| var child_path = $"{parent_path}.{child_name}"; | |||
| _node_paths[child_id] = child_path; | |||
| _add_children_recreated_from_config(obj_child, child_proto, child_id); | |||
| loaded_nodes[child_id] = (obj_child, setter); | |||
| } | |||
| } | |||
| bool _try_build_layer(Model obj, int node_id, Shape build_input_shape) | |||
| private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||
| { | |||
| if (obj.Built) | |||
| return true; | |||
| if(build_input_shape is null) | |||
| { | |||
| build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); | |||
| } | |||
| if(build_input_shape is not null) | |||
| { | |||
| obj.build(build_input_shape); | |||
| // In tf python here is a `base_layer.Layer.build(obj, build_input_shape)`. | |||
| // On the one hand, C# does not support call a method from specified parent class. | |||
| // On the other hand, currently All class derived from Layer call `Layer.Build` or | |||
| // move the implementation of `Layer.build` to its own `build` method. | |||
| // Therefore we do not call it here. | |||
| // However, it's still quite risky once in the future a certain class derived from | |||
| // `Layer` does not call `Layer.build`. | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape) | |||
| /// <summary> | |||
| /// Infers input shape of layer from SavedModel functions. | |||
| /// </summary> | |||
| /// <param name="layer_node_id"></param> | |||
| /// <param name="convert_to_shapes"></param> | |||
| /// <returns></returns> | |||
| private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) | |||
| { | |||
| if (obj.Built) | |||
| return true; | |||
| var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); | |||
| if(call_fn_id is null) | |||
| { | |||
| return null; | |||
| } | |||
| var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions; | |||
| if(concrete_functions is null) | |||
| { | |||
| return null; | |||
| } | |||
| var call_fn_name = concrete_functions[0]; | |||
| var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| } | |||
| private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child) | |||
| { | |||
| if(path_to_child is null || path_to_child.Count() == 0) | |||
| { | |||
| return parent_id; | |||
| } | |||
| foreach(var child in _proto.Nodes[parent_id].Children) | |||
| { | |||
| if(child.LocalName == path_to_child.First()) | |||
| { | |||
| return _search_for_child_node(child.NodeId, path_to_child.Skip(1)); | |||
| } | |||
| } | |||
| return null; | |||
| } | |||
| private bool _is_graph_network(Layer layer) | |||
| { | |||
| // TODO: deal with `RevivedLayer` | |||
| if(layer is Functional) | |||
| { | |||
| return (layer as Functional).IsGraphNetwork || layer is Sequential; | |||
| } | |||
| return false; | |||
| } | |||
| private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) | |||
| { | |||
| // TODO: deal with `RevivedLayer` | |||
| } | |||
| /// <summary> | |||
| /// Creates edges for nodes that are recreated from config. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| private Action<object, object, object> _config_node_setter(Action<object, object, object> setter) | |||
| { | |||
| void setattr_wrapper(object obj, object name, object value) | |||
| { | |||
| Debug.Assert(obj is Trackable); | |||
| Debug.Assert(name is string); | |||
| if((obj as Trackable)._lookup_dependency(name as string) is null) | |||
| { | |||
| setter(obj, name, value); | |||
| } | |||
| } | |||
| return setattr_wrapper; | |||
| } | |||
| } | |||
| } | |||
| @@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public partial class KerasSavedModelUtils | |||
| { | |||
| public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||
| public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, | |||
| SaveOptions? options, bool save_traces = true) | |||
| { | |||
| if (!overwrite && File.Exists(filepath)) | |||
| @@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils | |||
| BadConsumers = { } | |||
| }, | |||
| Identifier = layer.ObjectIdentifier, | |||
| Metadata = layer.TrackingMetadata | |||
| Metadata = layer.GetTrackingMetadata() | |||
| }; | |||
| metadata.Nodes.Add(saved_object); | |||
| @@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| @@ -0,0 +1,96 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| public class KerasLoadModelUtils | |||
| { | |||
| /// <summary> | |||
| /// Corresponding to keras/saving/save.py/load_model | |||
| /// </summary> | |||
| /// <param name="filepath"></param> | |||
| /// <param name="custom_objects"></param> | |||
| /// <param name="compile"></param> | |||
| /// <param name="options"></param> | |||
| /// <returns></returns> | |||
| public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null, | |||
| bool compile = true, LoadOptions? options = null) | |||
| { | |||
| using (SharedObjectSavingScope.Enter()) | |||
| { | |||
| using (LoadContext.load_context(options)) | |||
| { | |||
| if (!File.Exists(filepath) && !Directory.Exists(filepath)) | |||
| { | |||
| throw new IOException($"No file or directory found at {filepath}."); | |||
| } | |||
| if (Directory.Exists(filepath)) | |||
| { | |||
| return load(filepath, compile, options); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed."); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||
| { | |||
| SavedMetadata metadata = new SavedMetadata(); | |||
| var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||
| var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
| string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH); | |||
| if (File.Exists(path_to_metadata_pb)) | |||
| { | |||
| metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read)); | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| } | |||
| if (metadata.Nodes is null || metadata.Nodes.Count == 0) | |||
| { | |||
| return Loader.load(path, options: options) as Model; | |||
| } | |||
| var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
| keras_loader.load_layers(compile: compile); | |||
| Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new(); | |||
| nodes_to_load["root"] = (null, null); | |||
| foreach(var item in keras_loader.LoadedNodes) | |||
| { | |||
| nodes_to_load[keras_loader.get_path(item.Key)] = item.Value; | |||
| } | |||
| var loaded = Loader.load_partial(path, nodes_to_load, options); | |||
| keras_loader.finalize_objects(); | |||
| // keras_loader.del_tracking(); | |||
| var model = loaded["root"]; | |||
| if(model is Model && compile) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| return model; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,69 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Threading; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow.Keras.Saving.SavedModel | |||
| { | |||
| // TODO: remove this class to common project. | |||
| public class ContextHandler: IDisposable | |||
| { | |||
| public Action<bool> DisposeCallBack { get; set; } | |||
| public void Dispose() | |||
| { | |||
| DisposeCallBack.Invoke(true); | |||
| } | |||
| } | |||
| public class LoadContext | |||
| { | |||
| private bool _entered_load_context; | |||
| private LoadOptions? _load_options; | |||
| private static ThreadLocal<LoadContext> _load_context = new(); | |||
| private LoadContext() | |||
| { | |||
| _entered_load_context = false; | |||
| _load_options = null; | |||
| } | |||
| public void set_load_options(LoadOptions load_options) | |||
| { | |||
| _load_options = load_options; | |||
| _entered_load_context = true; | |||
| } | |||
| private void clear_load_options() | |||
| { | |||
| _load_options = null; | |||
| _entered_load_context = false; | |||
| } | |||
| private LoadOptions? load_options() | |||
| { | |||
| return _load_options; | |||
| } | |||
| public static ContextHandler load_context(LoadOptions? load_options) | |||
| { | |||
| if(_load_context.Value is null) | |||
| { | |||
| _load_context.Value = new LoadContext(); | |||
| } | |||
| _load_context.Value.set_load_options(load_options); | |||
| return new ContextHandler() | |||
| { | |||
| DisposeCallBack = _ => _load_context.Value.clear_load_options() | |||
| }; | |||
| } | |||
| public static LoadOptions? get_load_option() | |||
| { | |||
| return _load_context.Value.load_options(); | |||
| } | |||
| public static bool in_load_context() | |||
| { | |||
| return _load_context.Value._entered_load_context; | |||
| } | |||
| } | |||
| } | |||
| @@ -19,15 +19,21 @@ using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Data; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Reflection; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow.Keras.Utils | |||
| { | |||
| public class generic_utils | |||
| { | |||
| private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"; | |||
| /// <summary> | |||
| /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | |||
| /// </summary> | |||
| @@ -51,6 +57,58 @@ namespace Tensorflow.Keras.Utils | |||
| return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | |||
| } | |||
| public static Layer deserialize_keras_object(string class_name, JToken config) | |||
| { | |||
| var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); | |||
| var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) | |||
| .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); | |||
| var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); | |||
| var args = deserializationGenericMethod.Invoke(config, null); | |||
| var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | |||
| Debug.Assert(layer is Layer); | |||
| return layer as Layer; | |||
| } | |||
| public static Layer deserialize_keras_object(string class_name, LayerArgs args) | |||
| { | |||
| var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); | |||
| Debug.Assert(layer is Layer); | |||
| return layer as Layer; | |||
| } | |||
| public static LayerArgs deserialize_layer_args(string class_name, JToken config) | |||
| { | |||
| var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); | |||
| var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) | |||
| .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); | |||
| var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); | |||
| var args = deserializationGenericMethod.Invoke(config, null); | |||
| Debug.Assert(args is LayerArgs); | |||
| return args as LayerArgs; | |||
| } | |||
| public static ModelConfig deserialize_model_config(JToken json) | |||
| { | |||
| ModelConfig config = new ModelConfig(); | |||
| config.Name = json["name"].ToObject<string>(); | |||
| config.Layers = new List<LayerConfig>(); | |||
| var layersToken = json["layers"]; | |||
| foreach (var token in layersToken) | |||
| { | |||
| var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); | |||
| config.Layers.Add(new LayerConfig() | |||
| { | |||
| Config = args, | |||
| Name = token["name"].ToObject<string>(), | |||
| ClassName = token["class_name"].ToObject<string>(), | |||
| InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>() | |||
| }); | |||
| } | |||
| config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); | |||
| config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>(); | |||
| return config; | |||
| } | |||
| public static string to_snake_case(string name) | |||
| { | |||
| return string.Concat(name.Select((x, i) => | |||
| @@ -60,5 +118,15 @@ namespace Tensorflow.Keras.Utils | |||
| x.ToString(); | |||
| })).ToLower(); | |||
| } | |||
| /// <summary> | |||
| /// Determines whether config appears to be a valid layer config. | |||
| /// </summary> | |||
| /// <param name="config"></param> | |||
| /// <returns></returns> | |||
| public static bool validate_config(JObject config) | |||
| { | |||
| return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY); | |||
| } | |||
| } | |||
| } | |||
| @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils | |||
| } | |||
| var trainable_count = count_params(model, model.TrainableVariables); | |||
| var non_trainable_count = count_params(model, model.non_trainable_variables); | |||
| var non_trainable_count = count_params(model, model.NonTrainableVariables); | |||
| print($"Total params: {trainable_count + non_trainable_count}"); | |||
| print($"Trainable params: {trainable_count}"); | |||
| @@ -0,0 +1,9 @@ | |||
| ´$root"_tf_keras_network*’${"name": "model", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": []}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]]}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]]}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}, "shared_object_id": 9, "input_spec": 
[{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "ndim": 4, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 28, 28, 1]}, "float32", "input_1"]}, "keras_version": "2.11.0", "backend": "tensorflow", "model_config": {"class_name": "Functional", "config": {"name": "model", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "name": "input_1", "inbound_nodes": [], "shared_object_id": 0}, {"class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "name": "flatten", "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense", "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, 
"kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "name": "dense_1", "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7}, {"class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "name": "softmax", "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8}], "input_layers": [["input_1", 0, 0]], "output_layers": [["softmax", 0, 0]]}}}2 | |||
| †root.layer-0"_tf_keras_input_layer*Ö{"class_name": "InputLayer", "name": "input_1", "dtype": "float32", "sparse": false, "ragged": false, "batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 28, 28, 1]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}2 | |||
| Íroot.layer-1"_tf_keras_layer*£{"name": "flatten", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Flatten", "config": {"name": "flatten", "trainable": true, "dtype": "float32", "data_format": "channels_last"}, "inbound_nodes": [[["input_1", 0, 0, {}]]], "shared_object_id": 1, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 1, "axes": {}}, "shared_object_id": 14}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 28, 28, 1]}}2 | |||
| ¯root.layer_with_weights-0"_tf_keras_layer*ø{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten", 0, 0, {}]]], "shared_object_id": 4, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 784}}, "shared_object_id": 15}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 784]}}2 | |||
| ²root.layer_with_weights-1"_tf_keras_layer*û{"name": "dense_1", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 10, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 5}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 6}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dense", 0, 0, {}]]], "shared_object_id": 7, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 100}}, "shared_object_id": 16}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 100]}}2 | |||
| Šroot.layer-4"_tf_keras_layer*à{"name": "softmax", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Softmax", "config": {"name": "softmax", "trainable": true, "dtype": "float32", "axis": -1}, "inbound_nodes": [[["dense_1", 0, 0, {}]]], "shared_object_id": 8, "build_input_shape": {"class_name": "TensorShape", "items": [null, 10]}}2 | |||
| ¹Troot.keras_api.metrics.0"_tf_keras_metric*‚{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 17}2 | |||
| ™Uroot.keras_api.metrics.1"_tf_keras_metric*â{"class_name": "MeanMetricWrapper", "name": "sparse_categorical_accuracy", "dtype": "float32", "config": {"name": "sparse_categorical_accuracy", "dtype": "float32", "fn": "sparse_categorical_accuracy"}, "shared_object_id": 18}2 | |||
| @@ -0,0 +1,68 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.NumPy; | |||
| using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; | |||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
| [TestClass] | |||
| public class SequentialModelLoad | |||
| { | |||
| [TestMethod] | |||
| public void SimpleModelFromAutoCompile() | |||
| { | |||
| var model = keras.models.load_model(@"Assets/simple_model_from_auto_compile"); | |||
| model.summary(); | |||
| model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| // check the weights | |||
| var kernel1 = np.load(@"Assets/simple_model_from_auto_compile/kernel1.npy"); | |||
| var bias0 = np.load(@"Assets/simple_model_from_auto_compile/bias0.npy"); | |||
| Assert.IsTrue(kernel1.Zip(model.TrainableWeights[2].numpy()).All(x => x.First == x.Second)); | |||
| Assert.IsTrue(bias0.Zip(model.TrainableWeights[1].numpy()).All(x => x.First == x.Second)); | |||
| var data_loader = new MnistModelLoader(); | |||
| var num_epochs = 1; | |||
| var batch_size = 8; | |||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||
| { | |||
| TrainDir = "mnist", | |||
| OneHot = false, | |||
| ValidationSize = 50000, | |||
| }).Result; | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| } | |||
| [TestMethod] | |||
| public void AlexnetFromSequential() | |||
| { | |||
| new SequentialModelSave().AlexnetFromSequential(); | |||
| var model = keras.models.load_model(@"./alexnet_from_sequential"); | |||
| model.summary(); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||
| var num_epochs = 1; | |||
| var batch_size = 8; | |||
| var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); | |||
| model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||
| } | |||
| } | |||
| @@ -1,27 +1,21 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow.NumPy; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using System.Diagnostics; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Operations; | |||
| using System.Diagnostics; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
| [TestClass] | |||
| public class SequentialModelTest | |||
| public class SequentialModelSave | |||
| { | |||
| [TestMethod] | |||
| public void SimpleModelFromAutoCompile() | |||
| @@ -63,6 +57,8 @@ public class SequentialModelTest | |||
| keras.layers.Softmax(1) | |||
| }); | |||
| model.summary(); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| var data_loader = new MnistModelLoader(); | |||
| @@ -82,7 +78,7 @@ public class SequentialModelTest | |||
| } | |||
| [TestMethod] | |||
| public void AlexModelFromSequential() | |||
| public void AlexnetFromSequential() | |||
| { | |||
| Model model = KerasApi.keras.Sequential(new List<ILayer>() | |||
| { | |||
| @@ -116,7 +112,7 @@ public class SequentialModelTest | |||
| keras.layers.Softmax(1) | |||
| }); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||
| var num_epochs = 1; | |||
| var batch_size = 8; | |||
| @@ -125,7 +121,7 @@ public class SequentialModelTest | |||
| model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||
| model.save("./pb_alex_sequential", save_format: "tf"); | |||
| model.save("./alexnet_from_sequential", save_format: "tf"); | |||
| // The saved model can be test with the following python code: | |||
| #region alexnet_python_code | |||
| @@ -27,4 +27,28 @@ | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Keras\Tensorflow.Keras.csproj" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <None Update="Assets\simple_model_from_auto_compile\fingerprint.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\simple_model_from_auto_compile\keras_metadata.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\simple_model_from_auto_compile\saved_model.pb"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\simple_model_from_auto_compile\variables\variables.data-00000-of-00001"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\simple_model_from_auto_compile\variables\variables.index"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\simple_model_from_auto_compile\kernel1.npy"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| <None Update="Assets\simple_model_from_auto_compile\bias0.npy"> | |||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
| </None> | |||
| </ItemGroup> | |||
| </Project> | |||