diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 6f10fd2e..5dd9784f 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -6,7 +6,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class AutoTrackable : Trackable + public class AutoTrackable : Trackable { public void _delete_tracking(string name) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 6a163fec..ff3c7875 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -1,14 +1,279 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers.Rnn; +using Tensorflow.Keras.Metrics; +using Tensorflow.Train; namespace Tensorflow.Keras.Saving.SavedModel { + // TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#, + // Using the name "Attributes" may be quite confusing. /// /// Class that tracks and validates all serialization attributes. /// - public class SerializedAttributes + public abstract class SerializedAttributes { + protected IDictionary _object_dict; + protected IDictionary _function_dict; + protected AutoTrackable _keras_trackable; + protected HashSet _all_functions; + protected HashSet _all_checkpointable_objects; + protected SerializedAttributes() + { + _object_dict= new Dictionary(); + _function_dict= new Dictionary(); + _keras_trackable= new AutoTrackable(); + _all_functions= new HashSet(); + _all_checkpointable_objects= new HashSet(); + } + + protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions) + { + _object_dict = new Dictionary(); + _function_dict = new Dictionary(); + _keras_trackable = new AutoTrackable(); + + _all_checkpointable_objects = new HashSet(checkpointable_objects); + _all_functions = new HashSet(functions); + } + + public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); + + /// + /// Returns functions to attach to the root object during serialization. + /// + public IDictionary FunctionsToSerialize + { + get + { + Dictionary functions = new(); + foreach(var pair in Functions) + { + if (_all_functions.Contains(pair.Key)) + { + // TODO: deal with `LayerCall`. + functions[pair.Key] = pair.Value; + } + } + return functions; + } + } + + /// + /// Returns objects to attach to the root object during serialization. + /// + public IDictionary ObjectsToSerialize + { + get + { + var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value); + objects[Constants.KERAS_ATTR] = _keras_trackable; + return objects; + } + } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + public IDictionary set_and_validate_functions(IDictionary function_dict) + { + foreach(var key in _all_functions) + { + if (function_dict.ContainsKey(key)) + { + // TODO: deal with type `LayerCall`. + var fn = function_dict[key]; + if (fn is not null && (fn is not Function)) + { + throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key})."); + } + _function_dict[key] = fn; + + var tf_fn = fn; // TODO: deal with type `LayerCall`. + + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if(property.Name == key) + { + property.SetValue(_keras_trackable, tf_fn); + break; + } + } + } + else + { + throw new ValueError($"Function {key} missing from serialized function dict."); + } + } + return Functions; + } + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + public IDictionary set_and_validate_objects(IDictionary object_dict) + { + foreach(var key in _all_checkpointable_objects) + { + if (object_dict.ContainsKey(key)) + { + _object_dict[key] = object_dict[key]; + // Warning: this implmentation should be considered again. + var properties = _keras_trackable.GetType().GetProperties(); + foreach (var property in properties) + { + if (property.Name == key) + { + property.SetValue(_keras_trackable, object_dict[key]); + break; + } + } + } + else + { + throw new ValueError($"Object {key} missing from serialized object dict."); + } + } + return CheckpointableObjects; + } + + /// + /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python). + /// + /// + public static SerializedAttributes Create(Trackable obj) + { + if(obj is Model) + { + return new ModelAttributes(); + } + else if(obj is Metric) + { + return new MetricAttributes(); + } + else if(obj is RNN) + { + return new RNNAttributes(); + } + else if(obj is Layer) + { + return new LayerAttributes(); + } + else + { + throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}"); + } + } + + protected virtual (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + return (checkpointable_objects ?? (new List()), functions ?? (new List())); + } + } + + // Note that the current implementation still has some potential risks. + // The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras". + // However, currently it's just a normal class. + public class CommonEndPoints: SerializedAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if(checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if(functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + // TODO: remove the `__call__`. + functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) + ); + } + } + + public class LayerAttributes: CommonEndPoints + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), + functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) + ); + } + } + + public class ModelAttributes: LayerAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively(checkpointable_objects,functions); + } + } + + public class MetricAttributes : SerializedAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "variables" }), + functions + ); + } + } + + public class RNNAttributes: LayerAttributes + { + protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions) + { + if (checkpointable_objects is null) + { + checkpointable_objects = new List(); + } + if (functions is null) + { + functions = new List(); + } + return base.get_objects_and_functions_recursively( + checkpointable_objects.Concat(new string[] { "states" }), + functions + ); + } } }