From b92b08d6290477150c403711b98778e8cae55425 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Wed, 25 Jan 2023 10:14:15 +0800 Subject: [PATCH] Implement layer serializations. --- .../Checkpoint/TrackableView.cs | 2 +- src/TensorFlowNET.Core/DisposableObject.cs | 68 ++++++++ .../Saving/SavedModel/AugmentedGraphView.cs | 4 +- .../Training/data_structures.cs | 11 +- .../Variables/BaseResourceVariable.cs | 2 +- .../Variables/RefVariable.cs | 3 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 2 + .../Saving/SavedModel/SaveImpl.cs | 53 ++++++- .../Saving/SavedModel/base_serialization.cs | 7 +- .../Saving/SavedModel/layer_serialization.cs | 39 ++++- .../SavedModel/serialized_attributes.cs | 145 +++++++++--------- .../Saving/SavedModel/utils.cs | 14 ++ 12 files changed, 257 insertions(+), 93 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index 6d81d2c9..69bf76fd 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -24,7 +24,7 @@ public class TrackableView Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. 
- foreach(var pair in obj._trackable_children(save_type)) + foreach (var pair in obj._trackable_children(save_type)) { children[pair.Key] = pair.Value; } diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 3c70739b..7fac3d0f 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -17,6 +17,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Tensorflow.Train; namespace Tensorflow { @@ -90,4 +91,71 @@ namespace Tensorflow Dispose(false); } } + + public abstract class DisposableTrackableObject: Trackable, IDisposable + { + protected IntPtr _handle; + protected bool _disposed; + + protected DisposableTrackableObject() + { } + + protected DisposableTrackableObject(IntPtr handle) + => _handle = handle; + + private void Dispose(bool disposing) + { + if (_disposed) + return; + + //first handle managed, they might use the unmanaged resources. + if (disposing) + { + // dispose managed state (managed objects). + DisposeManagedResources(); + } + + // free unmanaged memory + if (_handle != IntPtr.Zero) + { + // Call the appropriate methods to clean up + // unmanaged resources here. + // If disposing is false, + // only the following code is executed. + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; + } + + // Note disposing has been done. + _disposed = true; + } + + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + + public void Dispose() + { + Dispose(true); + // This object will be cleaned up by the Dispose method. 
+ // Therefore, you should call GC.SuppressFinalize to + // take this object off the finalization queue + // and prevent finalization code for this object + // from executing a second time. + GC.SuppressFinalize(this); + } + + ~DisposableTrackableObject() + { + Dispose(false); + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 6723206c..82da2ee9 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -23,10 +23,10 @@ public class AugmentedGraphView: ObjectGraphView list_children(Root); } - public override List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) { Dictionary children = new(); - foreach (var pair in base.list_children(obj, save_type)) + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) { var name = pair.Name; var child = pair.Refer; diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index 4cb78181..d4e9c401 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -142,21 +142,26 @@ namespace Tensorflow.Training return value; } - protected static Trackable wrap_or_unwrap(NoDependency value) + public static Trackable wrap_or_unwrap(NoDependency value) { return value.Value; } - protected static Trackable wrap_or_unwrap(Trackable value) + public static Trackable wrap_or_unwrap(Trackable value) { return value; } - protected static Trackable wrap_or_unwrap(IList value) + public static Trackable wrap_or_unwrap(IList value) { return new ListWrapper(value); } + public static Trackable wrap_or_unwrap(IEnumerable value) + { + return new 
ListWrapper(value.ToList()); + } + protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) { value = wrap_or_unwrap(value); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 4526730f..f217a052 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -7,7 +7,7 @@ using static Tensorflow.Binding; namespace Tensorflow { - public class BaseResourceVariable : DisposableObject + public class BaseResourceVariable : DisposableTrackableObject { protected string _name; public virtual string Name => _handle_name; diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 38b5b734..7b08f3ea 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -20,11 +20,12 @@ using System; using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; +using Tensorflow.Train; namespace Tensorflow { [Obsolete] - public partial class RefVariable : IVariableV1, IProtoBuf + public partial class RefVariable: Trackable, IVariableV1, IProtoBuf { protected string _name; public string UniqueId => _name; diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index e95e55d6..b9b01dae 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -288,6 +288,8 @@ namespace Tensorflow.Keras.Engine } } + public List Variables => weights; + public virtual LayerArgs get_config() => args; } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index ba0bcc66..7168e25b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -1,5 +1,8 @@ 
using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.Engine; +using Tensorflow.Train; +using Tensorflow.Training; namespace Tensorflow.Keras.Saving.SavedModel; @@ -10,10 +13,54 @@ public partial class KerasSavedModelUtils return false; } - public static IDictionary wrap_layer_objects(Layer layer, object serialization_cache) + /// + /// Returns extra trackable objects to attach to the serialized layer. + /// + /// + /// + /// + public static IDictionary wrap_layer_objects(Layer layer, IDictionary serialization_cache) { - // TODO: process the loss + // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. - return null; + // TODO: change the inheritance of `Variable` and revise the implementation. + var variables = layer.Variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + var trainable_variables = layer.TrainableVariables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + var non_trainable_variables = layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + }); + + Dictionary res = new(); + res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); + res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); + res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); + res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + + return res; + } + + /// + /// Returns dict of wrapped layer call function and losses in tf.functions. 
+ /// + /// + /// + /// + public static IDictionary wrap_layer_functions(Layer layer, IDictionary serialization_cache) + { + // TODO: deal with type `RevivedLayer` and `Sequential`. + + // skip the process because of lack of APIs of `Layer`. + + return new Dictionary(); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 36111a18..a399eaf1 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -17,10 +17,10 @@ public abstract class SavedModelSaver public abstract string ObjectIdentifier { get; } public abstract string TrackingMetadata { get; } - public abstract IDictionary objects_to_serialize( + public abstract IDictionary objects_to_serialize( IDictionary serialization_cache); - public abstract IDictionary functions_to_serialize( + public abstract IDictionary functions_to_serialize( IDictionary serialization_cache); public IDictionary trackable_children(IDictionary? 
serialization_cache) @@ -32,8 +32,7 @@ public abstract class SavedModelSaver var children = objects_to_serialize(serialization_cache); - return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) - .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) + return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index f0ad7450..7a0ddd21 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -19,26 +19,51 @@ public class LayerSavedModelSaver: SavedModelSaver get => Constants.LAYER_IDENTIFIER; } - public override IDictionary objects_to_serialize(IDictionary serialization_cache) + public override IDictionary objects_to_serialize(IDictionary serialization_cache) { - throw new System.NotImplementedException(); + return get_serialized_attributes(serialization_cache).ObjectsToSerialize; } - public override IDictionary functions_to_serialize(IDictionary serialization_cache) + public override IDictionary functions_to_serialize(IDictionary serialization_cache) { - throw new System.NotImplementedException(); + return get_serialized_attributes(serialization_cache).FunctionsToSerialize; } /// /// Generates or retrieves serialized attributes from cache. /// /// - protected void get_serialized_attributes(IDictionary serialization_cache) + protected SerializedAttributes get_serialized_attributes(IDictionary serialization_cache) { // TODO: deal with cache. - Layer a; - + var serialized_attr = SerializedAttributes.Create(_obj); + + // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. 
+ if (KerasSavedModelUtils.should_skip_serialization(_obj)) + { + return serialized_attr; + } + + var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache); + + serialized_attr.set_and_validate_objects(object_dict); + serialized_attr.set_and_validate_functions(function_dict); + return serialized_attr; + } + + /// + /// Returns dictionary of serialized attributes. + /// + /// + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary serialization_cache) + { + var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); + + functions["_default_save_signature"] = null; + + return (objects, functions); } public override string TrackingMetadata diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index ff3c7875..804ea1a9 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -17,15 +17,15 @@ namespace Tensorflow.Keras.Saving.SavedModel public abstract class SerializedAttributes { protected IDictionary _object_dict; - protected IDictionary _function_dict; + protected IDictionary _function_dict; protected AutoTrackable _keras_trackable; protected HashSet _all_functions; protected HashSet _all_checkpointable_objects; - protected SerializedAttributes() + private SerializedAttributes() { _object_dict= new Dictionary(); - _function_dict= new Dictionary(); + _function_dict= new Dictionary(); _keras_trackable= new AutoTrackable(); _all_functions= new HashSet(); _all_checkpointable_objects= new HashSet(); @@ -34,25 +34,35 @@ namespace Tensorflow.Keras.Saving.SavedModel protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) { _object_dict = new Dictionary(); - _function_dict = new 
Dictionary(); + _function_dict = new Dictionary(); _keras_trackable = new AutoTrackable(); _all_checkpointable_objects = new HashSet(checkpointable_objects); _all_functions = new HashSet(functions); } - public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + protected SerializedAttributes((IEnumerable, IEnumerable) objects_and_functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(objects_and_functions.Item1); + _all_functions = new HashSet(objects_and_functions.Item2); + } + + public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); /// /// Returns functions to attach to the root object during serialization. /// - public IDictionary FunctionsToSerialize + public IDictionary FunctionsToSerialize { get { - Dictionary functions = new(); + Dictionary functions = new(); foreach(var pair in Functions) { if (_all_functions.Contains(pair.Key)) @@ -82,7 +92,7 @@ namespace Tensorflow.Keras.Saving.SavedModel /// Saves function dictionary, and validates dictionary values. /// /// - public IDictionary set_and_validate_functions(IDictionary function_dict) + public IDictionary set_and_validate_functions(IDictionary function_dict) { foreach(var key in _all_functions) { @@ -186,94 +196,87 @@ namespace Tensorflow.Keras.Saving.SavedModel // However, currently it's just a normal class. public class CommonEndPoints: SerializedAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public CommonEndPoints(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) + base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), + functions.Concat(new string[] { })) { - if(checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if(functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), - // TODO: remove the `__call__`. - functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) - ); + + } + + public CommonEndPoints() : + //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, + // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + base(new string[] { "variables", "trainable_variables"}, + new string[] {}) + { + } } public class LayerAttributes: CommonEndPoints { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public LayerAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), + functions.Concat(new string[] { })) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), - functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) - ); + + } + + public LayerAttributes() : + //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }, + // new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + base(new string[] { "non_trainable_variables", "layers" }, + new string[] { }) + { + } } public class ModelAttributes: LayerAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public ModelAttributes(IEnumerable checkpointable_objects, IEnumerable functions): + base(checkpointable_objects, functions) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively(checkpointable_objects,functions); + + } + + public ModelAttributes(): base() + { + } } public class MetricAttributes : SerializedAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + public MetricAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects.Concat(new string[] { "variables" }), functions) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "variables" }), - functions - ); + + } + + public MetricAttributes() : + base(new string[] { "variables" }, new string[] {}) + { + } } public class RNNAttributes: LayerAttributes { - protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? 
functions) + public RNNAttributes(IEnumerable checkpointable_objects, IEnumerable functions) : + base(checkpointable_objects, functions.Concat(new string[] {"states"})) { - if (checkpointable_objects is null) - { - checkpointable_objects = new List(); - } - if (functions is null) - { - functions = new List(); - } - return base.get_objects_and_functions_recursively( - checkpointable_objects.Concat(new string[] { "states" }), - functions - ); + + } + + public RNNAttributes() : + base(new string[] { }, new string[] { "states" }) + { + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs index a5d84d67..3054271a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving.SavedModel; @@ -12,6 +14,18 @@ public partial class KerasSavedModelUtils ShouldHaveTraces = save_traces; return res; } + + public static IEnumerable list_all_layers(Layer layer) + { + if(layer is Model) + { + return (layer as Model).Layers; + } + else + { + return new List(layer._flatten_layers(false, false)); + } + } } ///