From 3a6a59e18cbfd3a948f3e4c16259e3ee07c73878 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sun, 5 Feb 2023 01:08:37 +0800 Subject: [PATCH] Check and refine the code. --- .../Checkpoint/CheckPointUtils.cs | 6 +- .../Checkpoint/CheckpointOptions.cs | 2 +- .../Checkpoint/ObjectGraphView.cs | 4 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 4 +- .../Checkpoint/SaveUtilV1.cs | 19 ++- .../Checkpoint/TrackableView.cs | 2 +- .../Checkpoint/functional_saver.cs | 161 ++++++++---------- src/TensorFlowNET.Core/DisposableObject.cs | 2 +- .../Exceptions/AssertionError.cs | 2 +- .../Training/AutoTrackable.cs | 6 +- .../Training/Saving/SaveableObject.cs | 4 +- .../Training/Saving/SavedModel/AssetInfo.cs | 2 +- .../Saving/SavedModel/AugmentedGraphView.cs | 4 +- .../Training/Saving/SavedModel/Constants.cs | 2 +- .../Saving/SavedModel/RevivedTypes.cs | 2 +- .../Training/Saving/SavedModel/SaveType.cs | 2 +- .../Saving/SavedModel/SaveableView.cs | 16 +- .../Saving/SavedModel/TagConstants.cs | 2 +- .../Training/Saving/SavedModel/builder.cs | 2 +- .../Training/Saving/SavedModel/save.cs | 6 +- .../SavedModel/signature_serialization.cs | 2 +- .../Training/Saving/SavedModel/utils.cs | 2 +- .../Saving/saveable_object_util.py.cs | 21 ++- .../Saving/SavedModel/Constants.cs | 2 +- .../Saving/SavedModel/KerasObjectWrapper.cs | 11 -- .../Saving/SavedModel/Save.cs | 66 ++++++- .../Saving/SavedModel/SaveImpl.cs | 66 ------- .../Saving/SavedModel/base_serialization.cs | 3 +- .../Saving/SavedModel/layer_serialization.cs | 2 +- .../Saving/SavedModel/utils.cs | 2 +- .../SaveModel/SequentialModelTest.cs | 8 +- 31 files changed, 194 insertions(+), 241 deletions(-) delete mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs delete mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index cd37703b..8ae2dae8 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -12,9 +12,9 @@ namespace Tensorflow.Checkpoint; public static class CheckPointUtils { private static string _ESCAPE_CHAR = "."; - public static (List, Dictionary>, Dictionary, + public static (IList, IDictionary>, IDictionary, IDictionary>, - Dictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) + IDictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); Dictionary object_names = new(); @@ -149,4 +149,4 @@ public static class CheckPointUtils // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); // } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs index f14b5ce7..75b392af 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -2,4 +2,4 @@ public record class CheckpointOptions( string? experimental_io_device = null, - bool experimental_enable_async_checkpoint = false); \ No newline at end of file + bool experimental_enable_async_checkpoint = false); diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs index cb01b539..f435dd88 100644 --- a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -45,7 +45,7 @@ public class ObjectGraphView: TrackableView, ICloneable get => _attached_dependencies; } - public virtual (List, Dictionary>) breadth_first_traversal() + public virtual (IList, IDictionary>) breadth_first_traversal() { return base._descendants_with_paths(); } @@ -61,4 +61,4 @@ public class ObjectGraphView: TrackableView, ICloneable { throw new NotImplementedException(); } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index 84e0ca4e..c54cc93f 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -58,7 +58,7 @@ namespace Tensorflow.Checkpoint return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); } - private static (List, Dictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? object_map) + private static (IList, IDictionary) gather_trackable_data(ObjectGraphView graph_view, IDictionary? object_map) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); Dictionary object_names = new(); @@ -173,7 +173,7 @@ namespace Tensorflow.Checkpoint tensor_dict[checkpoint_key] = maybe_tensor; - if(maybe_tensor.GetValueA() is SaveSpec) + if(maybe_tensor.IsTypeOrDeriveFrom()) { throw new NotImplementedException(); //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 4f1d04d2..3267ae12 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Checkpoint; public static class SaveUtilV1 { - public static (Dictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, + public static (IDictionary>, object?) get_checkpoint_factories_and_keys(IDictionary object_names, IDictionary? object_map = null) { // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, @@ -44,7 +44,7 @@ public static class SaveUtilV1 return (checkpoint_factory_map, null); } - public static (List, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, + public static (IList, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? saveables_cache = null) { @@ -73,7 +73,7 @@ public static class SaveUtilV1 } } - public static (List, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, + public static (IList, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); @@ -129,7 +129,8 @@ public static class SaveUtilV1 return object_graph_proto; } - private static (List, object?, IDictionary>?) add_attributes_to_object_graph(IList trackable_objects, + private static (IList, object?, IDictionary>?) add_attributes_to_object_graph( + IList trackable_objects, TrackableObjectGraph object_graph_proto, IDictionary node_ids, IDictionary object_names, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -150,7 +151,7 @@ public static class SaveUtilV1 return (named_saveable_objects, feed_additions, null); } - public static (List, object?) generate_saveable_objects( + public static (IList, object?) generate_saveable_objects( IDictionary> checkpoint_factory_map, TrackableObjectGraph? object_graph_proto, IDictionary? node_ids, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -178,13 +179,13 @@ public static class SaveUtilV1 // TODO: oneflow python has a process with callable `saveable_factory`. List saveables = new(); - if (maybe_saveable.DataType == typeof(MySaveableObject)) + if (maybe_saveable.TryGet(out var s)) { - saveables.Add(maybe_saveable.GetValueB()); + saveables.Add(s); } else { - saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); + saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue() as Trackable, key)); } foreach (var saveable in saveables) @@ -219,4 +220,4 @@ public record class CheckpointFactoryData Maybe factory, string name, string checkpoint_key -); \ No newline at end of file +); diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index f89dc10d..dab6d5d9 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -52,7 +52,7 @@ public class TrackableView /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths /// - protected (List, Dictionary>) _descendants_with_paths() + protected (IList, IDictionary>) _descendants_with_paths() { List bfs_sorted = new(); Queue to_visit = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 90bbccf0..09904d68 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -14,112 +14,91 @@ using Tensorflow.Training; using Tensorflow.Graphs; using System.Xml.Linq; using System.Diagnostics; +using RestoreFunc = System.Func; namespace Tensorflow.Checkpoint { - /// - /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. - /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. - /// - public interface IFunctionHolder - { - int ArgCount { get; } - object DynamicInvoke(params object[] args); - } - internal record class FunctionHolder(Func Func): IFunctionHolder - { - public int ArgCount => 0; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - public TR Invoke() - { - return Func.Invoke(); - } - } - internal record class FunctionHolder(Func Func) : IFunctionHolder - { - public int ArgCount => 1; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - } - internal record class FunctionHolder(Func Func) : IFunctionHolder - { - public int ArgCount => 2; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - } - internal record class FunctionHolder(Func Func) : IFunctionHolder - { - public int ArgCount => 3; - public object DynamicInvoke(params object[] args) - { - return Func.DynamicInvoke(args); - } - } public class Maybe { private TA? _valueA = default(TA); private TB? _valueB = default(TB); private Type _type; - private bool _assigned = false; + private bool _assignedTA; public Maybe(TA value) { _valueA = value; _type= typeof(TA); - _assigned = true; + _assignedTA = true; } public Maybe(TB value) { _valueB = value; _type = typeof(TB); - _assigned = true; + _assignedTA = false; } public Type DataType => _type; - public TA GetValueA() + /// + /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. + /// It returns + /// + /// + /// + /// + public bool TryGet(out T? res) { - if(!_assigned || DataType != typeof(TA)) + if(_valueA is T && _valueB is not T) { - throw new TypeError("Cannot get the data because of wrong specified type."); + res = (T)(object)_valueA; + return _assignedTA; } - return _valueA; - } - public TB GetValueB() - { - if (!_assigned || DataType != typeof(TB)) + else if(_valueA is not T && _valueB is T) { - throw new TypeError("Cannot get the data because of wrong specified type."); + res = (T)(object)_valueB; + return !_assignedTA; } - return _valueB; + res = default(T); + return false; } - public object GetValue() + + public bool IsTypeOrDeriveFrom() { - if (!_assigned) + if (_valueA is T && _valueB is not T) { - throw new TypeError("Cannot get the data because of wrong specified type."); + return _assignedTA; } - if(DataType == typeof(TA) && _valueA is not null) + else if (_valueA is not T && _valueB is T) { - return _valueA; + return !_assignedTA; } - else if(DataType == typeof(TB) && _valueB is not null) + else if (_valueA is T && _valueB is T) { - return _valueB; + return true; } - else if(DataType == typeof(TA)) + else { - return _valueA; + return false; + } + } + + public T GetValue() + { + if (_valueA is T && _valueB is not T) + { + return (T)(object)_valueA; + } + else if (_valueA is not T && _valueB is T) + { + return (T)(object)_valueB; + } + else if (_valueA is T && _valueB is T) + { + throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); } else { - return _valueB; + throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); } } @@ -170,9 +149,8 @@ namespace Tensorflow.Checkpoint { var slice_spec = slice.Key; var maybe_tensor = slice.Value; - if(maybe_tensor.DataType == typeof(SaveSpec)) + if(maybe_tensor.TryGet(out var spec)) { - var spec = maybe_tensor.GetValueB(); var tensor_value = spec.tensor; if (tensor_value is not null) { @@ -183,7 +161,7 @@ namespace Tensorflow.Checkpoint } else { - var tensor = maybe_tensor.GetValueA(); + var tensor = maybe_tensor.GetValue(); tensor_names.Add(checkpoint_key); tensors.Add(tensor); slice_specs.Add(slice_spec); @@ -215,16 +193,15 @@ namespace Tensorflow.Checkpoint var slice_spec = slice.Key; var maybe_tensor = slice.Value; // TODO: deal with other types. Currently only `SaveSpec` is allowed. - if(maybe_tensor.DataType == typeof(SaveSpec)) + if(maybe_tensor.TryGet(out var spec)) { - var spec = maybe_tensor.GetValueB(); tensor_dtypes.Add(spec.dtype); slice_specs.Add(spec.slice_spec); tensor_names.Add(spec.name); } else { - var tensor = maybe_tensor.GetValueA(); + var tensor = maybe_tensor.GetValue(); tensor_dtypes.Add(tensor.dtype); slice_specs.Add(slice_spec); tensor_names.Add(checkpoint_key); @@ -268,9 +245,9 @@ namespace Tensorflow.Checkpoint public class MultiDeviceSaver { private Dictionary _single_device_savers; - private IDictionary _registered_savers; - private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; - private Dictionary> _restore_fn_to_keys; + private IDictionary _registered_savers; + private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; /// /// /// @@ -280,24 +257,28 @@ namespace Tensorflow.Checkpoint public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) { - _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); - _restore_fn_to_keys = new Dictionary>(); + _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); + _restore_fn_to_keys = new Dictionary>(); Dictionary>> tensors_by_device= new(); foreach(var pair in serialized_tensors) { var obj = pair.Key; var tensor_dict = pair.Value; - IFunctionHolder restore_fn; + RestoreFunc restore_fn; if(obj == Trackable.None) { - restore_fn = new FunctionHolder(() => null); + restore_fn = new RestoreFunc(x => null); } else { - restore_fn = new FunctionHolder>>, IDictionary>(x => + restore_fn = new RestoreFunc(x => { - return obj._restore_from_tensors(x); + if(x is IDictionary>>) + { + return obj._restore_from_tensors(x as IDictionary>>); + } + throw new TypeError($"Expected `IDictionary>>` as input, got{x.GetType()}."); }); } @@ -305,14 +286,14 @@ namespace Tensorflow.Checkpoint { var checkpoint_key = item.Key; IDictionary spec_to_tensor; - if(item.Value.DataType != typeof(IDictionary)) + if(item.Value.TryGet(out var t)) { spec_to_tensor = new Dictionary(); - spec_to_tensor[""] = item.Value.GetValueA(); + spec_to_tensor[""] = t; } else { - spec_to_tensor = item.Value.GetValueB(); + spec_to_tensor = item.Value.GetValue>(); } foreach(var spec in spec_to_tensor) @@ -342,7 +323,7 @@ namespace Tensorflow.Checkpoint _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); - _registered_savers = new Dictionary(); + _registered_savers = new Dictionary(); if(registered_savers is not null && registered_savers.Count > 0) { // TODO: complete the implementation. @@ -418,8 +399,8 @@ namespace Tensorflow.Checkpoint IDictionary restore_func() { - Dictionary>>> restore_fn_inputs = new(); - Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); Dictionary restore_ops = new(); foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) @@ -449,7 +430,7 @@ namespace Tensorflow.Checkpoint } else { - internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; + internal_dict[checkpoint_key].GetValue>()[slice_spec] = tensor; } } else diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 7fac3d0f..c3c677ff 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -158,4 +158,4 @@ namespace Tensorflow Dispose(false); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs index 84ec24cb..977fe234 100644 --- a/src/TensorFlowNET.Core/Exceptions/AssertionError.cs +++ b/src/TensorFlowNET.Core/Exceptions/AssertionError.cs @@ -11,4 +11,4 @@ public class AssertionError : TensorflowException { } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 4d5a664e..4ba3e407 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -37,10 +37,10 @@ namespace Tensorflow.Train var properties = this.GetType().GetProperties(); foreach ( var property in properties ) { - string name = property.Name; - object value = property.GetValue(this, null); - if(value is Function || value is ConcreteFunction) + if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction)) { + string name = property.Name; + object value = property.GetValue(this, null); functions[name] = (Trackable)value; } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 43d36dba..1309a617 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -25,9 +25,9 @@ namespace Tensorflow { get { - if(_op.DataType == typeof(Tensor)) + if(_op.TryGet(out var tensor)) { - return _op.GetValueA(); + return tensor; } else { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs index 24c8f2f0..d1025782 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs @@ -8,4 +8,4 @@ public record class AssetInfo Dictionary asset_initializers_by_resource, Dictionary asset_filename_map, Dictionary asset_index -); \ No newline at end of file +); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 97162651..a9193335 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -86,7 +86,7 @@ public class AugmentedGraphView: ObjectGraphView return concrete_function; } - public override (List, Dictionary>) breadth_first_traversal() + public override (IList, IDictionary>) breadth_first_traversal() { Trackable get_merged_trackable(Trackable x) { @@ -130,4 +130,4 @@ public class AugmentedGraphView: ObjectGraphView { return _children_cache[obj][name]; } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs index cb7abada..726f6cfd 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs @@ -30,4 +30,4 @@ public static class Constants public static readonly string VARIABLES_DIRECTORY = "variables"; public static readonly string VARIABLES_FILENAME = "variables"; -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs index fa9d6e50..fe0403c3 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -14,4 +14,4 @@ public class RevivedTypes // TODO: complete the implementation. return null; } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs index b973fd41..8dd4f008 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs @@ -6,4 +6,4 @@ public enum SaveType { SAVEDMODEL, CHECKPOINT -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6132e025..1be54287 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -18,13 +18,13 @@ public class SaveableView { private AugmentedGraphView _augmented_graph_view; private SaveOptions _options; - private List _trackable_objects; + private IList _trackable_objects; private List _nodes; - private Dictionary> _node_paths; - private Dictionary _node_ids; + private IDictionary> _node_paths; + private IDictionary _node_ids; private IDictionary> _slot_variables; - private Dictionary _object_names; + private IDictionary _object_names; private List _gradient_functions; // to be completed private List _gradient_defs; // to be completed private List _concrete_functions; @@ -45,7 +45,7 @@ public class SaveableView { get => _nodes; } - public Dictionary NodeIds + public IDictionary NodeIds { get => _node_ids; } @@ -53,7 +53,7 @@ public class SaveableView { get => _gradient_defs; } - public Dictionary> NodePaths + public IDictionary> NodePaths { get => _node_paths; } @@ -84,7 +84,7 @@ public class SaveableView private void initialize_nodes_and_concrete_functions() { - _nodes = _trackable_objects.ConvertAll(x => x); // deep copy + _nodes = _trackable_objects.ToList().ConvertAll(x => x); // deep copy _gradient_functions = new(); _gradient_defs = new(); @@ -296,4 +296,4 @@ public class SaveableView proto.Nodes.Add(object_proto); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs index 9a066eed..6aa1fbde 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs @@ -7,4 +7,4 @@ public static class TagConstants public static readonly string EVAL = "eval"; public static readonly string GPU = "gpu"; public static readonly string TPU = "tpu"; -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs index bcd3ae05..dbbab91d 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs @@ -19,4 +19,4 @@ public class BuilderUtils throw new NotImplementedException(); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index d82d49d8..94760e3d 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -81,8 +81,8 @@ public static partial class SavedModelUtils return (saved_nodes, node_paths); } - private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, - Dictionary>) _build_meta_graph(Trackable obj, + private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList, + IDictionary>) _build_meta_graph(Trackable obj, ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) { using (SaveContext.save_context(options)) @@ -266,4 +266,4 @@ public static partial class SavedModelUtils } } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 0d34907f..4a0d3b00 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -104,4 +104,4 @@ public class SignatureMap: Trackable return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs index 2deff027..b0e6411c 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -54,4 +54,4 @@ public static partial class SavedModelUtils { return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 582e2431..a6e21e3e 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -136,9 +136,8 @@ namespace Tensorflow { full_name = name + "_" + attr; } - if(factory.DataType == typeof(ResourceVariable)) + if(factory.TryGet(out var variable)) { - var variable = factory.GetValueA(); foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) { yield return op; @@ -146,8 +145,8 @@ namespace Tensorflow } else { - var variable = factory.GetValueB(); - foreach (var op in saveable_objects_for_op(variable, variable.name)) + var saveable = factory.GetValue(); + foreach (var op in saveable_objects_for_op(saveable, saveable.name)) { yield return op; } @@ -236,14 +235,14 @@ namespace Tensorflow string spec_name = name + TrackableUtils.escape_local_name(tensor_name); IDictionary internal_dict; - if(maybe_tensor.DataType == typeof(Tensor)) + if(maybe_tensor.TryGet(out var tensor)) { internal_dict= new Dictionary(); - internal_dict[""] = maybe_tensor.GetValueA(); + internal_dict[""] = tensor; } else { - internal_dict = maybe_tensor.GetValueB(); + internal_dict = maybe_tensor.GetValue>(); } foreach(var item in internal_dict) @@ -287,7 +286,7 @@ namespace Tensorflow var slice_spec = convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { - tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; + tensor_dict.SetDefault(name, new Dictionary()).GetValue>()[slice_spec] = spec.tensor; } else { @@ -318,14 +317,14 @@ namespace Tensorflow var maybe_tensor = restored_tensors[name]; IDictionary dict; - if(maybe_tensor.DataType == typeof(Tensor)) + if(maybe_tensor.TryGet(out var tensor)) { dict = new Dictionary(); - dict[""] = maybe_tensor.GetValueA(); + dict[""] = tensor; } else { - dict = maybe_tensor.GetValueB(); + dict = maybe_tensor.GetValue>(); } saveable_restored_tensors.Add(dict[slice_spec]); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs index ea6853fd..3ea4f067 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs @@ -38,4 +38,4 @@ public static class Constants RNN_LAYER_IDENTIFIER, SEQUENTIAL_IDENTIFIER }; -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs deleted file mode 100644 index a5f315bb..00000000 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace Tensorflow.Keras.Saving.SavedModel; - -public class KerasObjectWrapper -{ - -} - -public class KerasObjectWrapper -{ - public T Item { get; set; } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 9d1c9609..c7b7e52f 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -3,19 +3,15 @@ using System.Collections.Generic; using System.IO; using System.Linq; using Google.Protobuf; -using ICSharpCode.SharpZipLib.Zip; -using Tensorflow.Checkpoint; -using Tensorflow.Contexts; using Tensorflow.Functions; using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Utils; using Tensorflow.ModelSaving; using Tensorflow.Train; -using Tensorflow.Exceptions; -using Tensorflow.IO; using Tensorflow.Keras.Optimizers; using ThirdParty.Tensorflow.Python.Keras.Protobuf; using static Tensorflow.Binding; +using Tensorflow.Training; + namespace Tensorflow.Keras.Saving.SavedModel; @@ -108,5 +104,59 @@ public partial class KerasSavedModelUtils return metadata; } - -} \ No newline at end of file + public static bool should_skip_serialization(object layer) + { + return false; + } + + /// + /// Returns extra trackable objects to attach to the serialized layer. + /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. + + // TODO: change the inherits of `Variable` and revise the implmentation. + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); + + Dictionary res = new(); + res["variables"] = variables; + res["trainable_variables"] = trainable_variables; + res["non_trainable_variables"] = non_trainable_variables; + res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. + /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) + { + // TODO: deal with type `RevivedLayer` and `Sequential`. + + // skip the process because of lack of APIs of `Layer`. + + return new Dictionary(); + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs deleted file mode 100644 index f7e1bf45..00000000 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ /dev/null @@ -1,66 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Keras.Engine; -using Tensorflow.Train; -using Tensorflow.Training; - -namespace Tensorflow.Keras.Saving.SavedModel; - -public partial class KerasSavedModelUtils -{ - public static bool should_skip_serialization(object layer) - { - return false; - } - - /// - /// Returns extra trackable objects to attach to the serialized layer. - /// - /// - /// - /// - public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) - { - // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. - - // TODO: change the inherits of `Variable` and revise the implmentation. - var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - - Dictionary res = new(); - res["variables"] = variables; - res["trainable_variables"] = trainable_variables; - res["non_trainable_variables"] = non_trainable_variables; - res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); - - return res; - } - - /// - /// Returns dict of wrapped layer call function and losses in tf.functions. - /// - /// - /// - /// - public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) - { - // TODO: deal with type `RevivedLayer` and `Sequential`. - - // skip the process because of lack of APIs of `Layer`. - - return new Dictionary(); - } -} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 60c4ee5b..eb88c895 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -34,5 +34,4 @@ public abstract class SavedModelSaver return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } - -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 8675ea65..03693cb5 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -162,4 +162,4 @@ public class InputLayerSavedModelSaver: SavedModelSaver return JsonConvert.SerializeObject(info); } } -} \ No newline at end of file +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index 3054271a..51f8d2c9 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -44,4 +44,4 @@ public class SaveOptionsContext: IDisposable { KerasSavedModelUtils.ShouldHaveTraces = _old_value; } -} \ No newline at end of file +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs index c3145344..269b9c05 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs @@ -73,7 +73,7 @@ public class SequentialModelTest { TrainDir = "mnist", OneHot = false, - ValidationSize = 10000, + ValidationSize = 50000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); @@ -119,13 +119,13 @@ public class SequentialModelTest model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); var num_epochs = 1; - var batch_size = 16; + var batch_size = 8; var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); - model.save("./pb_elex_sequential", save_format: "tf"); + model.save("./pb_alex_sequential", save_format: "tf"); // The saved model can be test with the following python code: #region alexnet_python_code @@ -136,7 +136,7 @@ public class SequentialModelTest // return -a //if __name__ == '__main__': - // model = tf.keras.models.load_model("./pb_elex_sequential") + // model = tf.keras.models.load_model("./pb_alex_sequential") // model.summary() // num_classes = 5