diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs
index 6f10fd2e..5dd9784f 100644
--- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs
+++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs
@@ -6,7 +6,7 @@ using static Tensorflow.Binding;
namespace Tensorflow.Train
{
- public abstract class AutoTrackable : Trackable
+ public class AutoTrackable : Trackable
{
public void _delete_tracking(string name)
{
diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
index 6a163fec..ff3c7875 100644
--- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
+++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
@@ -1,14 +1,279 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Rnn;
+using Tensorflow.Keras.Metrics;
+using Tensorflow.Train;
namespace Tensorflow.Keras.Saving.SavedModel
{
+ // TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#,
+ // Using the name "Attributes" may be quite confusing.
///
/// Class that tracks and validates all serialization attributes.
///
- public class SerializedAttributes
+ public abstract class SerializedAttributes
{
+ protected IDictionary _object_dict;
+ protected IDictionary _function_dict;
+ protected AutoTrackable _keras_trackable;
+ protected HashSet _all_functions;
+ protected HashSet _all_checkpointable_objects;
+ protected SerializedAttributes()
+ {
+ _object_dict= new Dictionary();
+ _function_dict= new Dictionary();
+ _keras_trackable= new AutoTrackable();
+ _all_functions= new HashSet();
+ _all_checkpointable_objects= new HashSet();
+ }
+
+ protected SerializedAttributes(IEnumerable checkpointable_objects, IEnumerable functions)
+ {
+ _object_dict = new Dictionary();
+ _function_dict = new Dictionary();
+ _keras_trackable = new AutoTrackable();
+
+ _all_checkpointable_objects = new HashSet(checkpointable_objects);
+ _all_functions = new HashSet(functions);
+ }
+
+ public IDictionary Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);
+
+ public IDictionary CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);
+
+ ///
+ /// Returns functions to attach to the root object during serialization.
+ ///
+ public IDictionary FunctionsToSerialize
+ {
+ get
+ {
+ Dictionary functions = new();
+ foreach(var pair in Functions)
+ {
+ if (_all_functions.Contains(pair.Key))
+ {
+ // TODO: deal with `LayerCall`.
+ functions[pair.Key] = pair.Value;
+ }
+ }
+ return functions;
+ }
+ }
+
+ ///
+ /// Returns objects to attach to the root object during serialization.
+ ///
+ public IDictionary ObjectsToSerialize
+ {
+ get
+ {
+ var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value);
+ objects[Constants.KERAS_ATTR] = _keras_trackable;
+ return objects;
+ }
+ }
+
+ ///
+ /// Saves function dictionary, and validates dictionary values.
+ ///
+ ///
+ public IDictionary set_and_validate_functions(IDictionary function_dict)
+ {
+ foreach(var key in _all_functions)
+ {
+ if (function_dict.ContainsKey(key))
+ {
+ // TODO: deal with type `LayerCall`.
+ var fn = function_dict[key];
+ if (fn is not null && (fn is not Function))
+ {
+ throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key}).");
+ }
+ _function_dict[key] = fn;
+
+ var tf_fn = fn; // TODO: deal with type `LayerCall`.
+
+ // Warning: this implmentation should be considered again.
+ var properties = _keras_trackable.GetType().GetProperties();
+ foreach (var property in properties)
+ {
+ if(property.Name == key)
+ {
+ property.SetValue(_keras_trackable, tf_fn);
+ break;
+ }
+ }
+ }
+ else
+ {
+ throw new ValueError($"Function {key} missing from serialized function dict.");
+ }
+ }
+ return Functions;
+ }
+
+ ///
+ /// Saves objects to a dictionary, and validates the values.
+ ///
+ ///
+ public IDictionary set_and_validate_objects(IDictionary object_dict)
+ {
+ foreach(var key in _all_checkpointable_objects)
+ {
+ if (object_dict.ContainsKey(key))
+ {
+ _object_dict[key] = object_dict[key];
+ // Warning: this implmentation should be considered again.
+ var properties = _keras_trackable.GetType().GetProperties();
+ foreach (var property in properties)
+ {
+ if (property.Name == key)
+ {
+ property.SetValue(_keras_trackable, object_dict[key]);
+ break;
+ }
+ }
+ }
+ else
+ {
+ throw new ValueError($"Object {key} missing from serialized object dict.");
+ }
+ }
+ return CheckpointableObjects;
+ }
+
+ ///
+ /// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python).
+ ///
+ ///
+ public static SerializedAttributes Create(Trackable obj)
+ {
+ if(obj is Model)
+ {
+ return new ModelAttributes();
+ }
+ else if(obj is Metric)
+ {
+ return new MetricAttributes();
+ }
+ else if(obj is RNN)
+ {
+ return new RNNAttributes();
+ }
+ else if(obj is Layer)
+ {
+ return new LayerAttributes();
+ }
+ else
+ {
+ throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}");
+ }
+ }
+
+ protected virtual (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+ {
+ return (checkpointable_objects ?? (new List()), functions ?? (new List()));
+ }
+ }
+
+ // Note that the current implementation still has some potential risks.
+ // The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras".
+ // However, currently it's just a normal class.
+ public class CommonEndPoints: SerializedAttributes
+ {
+ protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+ {
+ if(checkpointable_objects is null)
+ {
+ checkpointable_objects = new List();
+ }
+ if(functions is null)
+ {
+ functions = new List();
+ }
+ return base.get_objects_and_functions_recursively(
+ checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }),
+ // TODO: remove the `__call__`.
+ functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })
+ );
+ }
+ }
+
+ public class LayerAttributes: CommonEndPoints
+ {
+ protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+ {
+ if (checkpointable_objects is null)
+ {
+ checkpointable_objects = new List();
+ }
+ if (functions is null)
+ {
+ functions = new List();
+ }
+ return base.get_objects_and_functions_recursively(
+ checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
+ functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
+ );
+ }
+ }
+
+ public class ModelAttributes: LayerAttributes
+ {
+ protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+ {
+ if (checkpointable_objects is null)
+ {
+ checkpointable_objects = new List();
+ }
+ if (functions is null)
+ {
+ functions = new List();
+ }
+ return base.get_objects_and_functions_recursively(checkpointable_objects,functions);
+ }
+ }
+
+ public class MetricAttributes : SerializedAttributes
+ {
+ protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+ {
+ if (checkpointable_objects is null)
+ {
+ checkpointable_objects = new List();
+ }
+ if (functions is null)
+ {
+ functions = new List();
+ }
+ return base.get_objects_and_functions_recursively(
+ checkpointable_objects.Concat(new string[] { "variables" }),
+ functions
+ );
+ }
+ }
+
+ public class RNNAttributes: LayerAttributes
+ {
+ protected override (IEnumerable, IEnumerable) get_objects_and_functions_recursively(IEnumerable? checkpointable_objects, IEnumerable? functions)
+ {
+ if (checkpointable_objects is null)
+ {
+ checkpointable_objects = new List();
+ }
+ if (functions is null)
+ {
+ functions = new List();
+ }
+ return base.get_objects_and_functions_recursively(
+ checkpointable_objects.Concat(new string[] { "states" }),
+ functions
+ );
+ }
}
}