Browse Source

Add serialized attributes.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
bdca3b5e3d
2 changed files with 267 additions and 2 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  2. +266
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs

+ 1
- 1
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -6,7 +6,7 @@ using static Tensorflow.Binding;


namespace Tensorflow.Train namespace Tensorflow.Train
{ {
public abstract class AutoTrackable : Trackable
public class AutoTrackable : Trackable
{ {
public void _delete_tracking(string name) public void _delete_tracking(string name)
{ {


+ 266
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -1,14 +1,279 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Metrics;
using Tensorflow.Train;


namespace Tensorflow.Keras.Saving.SavedModel namespace Tensorflow.Keras.Saving.SavedModel
{ {
// TODO: revise the name of these "Attributes". Since "Attribute" is a significant feature of C#,
// Using the name "Attributes" may be quite confusing.
/// <summary> /// <summary>
/// Class that tracks and validates all serialization attributes. /// Class that tracks and validates all serialization attributes.
/// </summary> /// </summary>
public class SerializedAttributes
public abstract class SerializedAttributes
{ {
protected IDictionary<string, Trackable?> _object_dict;
protected IDictionary<string, Function?> _function_dict;
protected AutoTrackable _keras_trackable;
protected HashSet<string> _all_functions;
protected HashSet<string> _all_checkpointable_objects;


protected SerializedAttributes()
{
_object_dict= new Dictionary<string, Trackable?>();
_function_dict= new Dictionary<string, Function?>();
_keras_trackable= new AutoTrackable();
_all_functions= new HashSet<string>();
_all_checkpointable_objects= new HashSet<string>();
}

protected SerializedAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions)
{
_object_dict = new Dictionary<string, Trackable?>();
_function_dict = new Dictionary<string, Function?>();
_keras_trackable = new AutoTrackable();

_all_checkpointable_objects = new HashSet<string>(checkpointable_objects);
_all_functions = new HashSet<string>(functions);
}

public IDictionary<string, Function> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!);

/// <summary>
/// Returns functions to attach to the root object during serialization.
/// </summary>
public IDictionary<string, Function> FunctionsToSerialize
{
get
{
Dictionary<string, Function> functions = new();
foreach(var pair in Functions)
{
if (_all_functions.Contains(pair.Key))
{
// TODO: deal with `LayerCall`.
functions[pair.Key] = pair.Value;
}
}
return functions;
}
}

/// <summary>
/// Returns objects to attach to the root object during serialization.
/// </summary>
public IDictionary<string, Trackable> ObjectsToSerialize
{
get
{
var objects = CheckpointableObjects.TakeWhile( x=> _all_checkpointable_objects.Contains(x.Key)).ToDictionary(x => x.Key, x => x.Value);
objects[Constants.KERAS_ATTR] = _keras_trackable;
return objects;
}
}

/// <summary>
/// Saves function dictionary, and validates dictionary values.
/// </summary>
/// <param name="function_dict"></param>
public IDictionary<string, Function> set_and_validate_functions(IDictionary<string, Function> function_dict)
{
foreach(var key in _all_functions)
{
if (function_dict.ContainsKey(key))
{
// TODO: deal with type `LayerCall`.
var fn = function_dict[key];
if (fn is not null && (fn is not Function))
{
throw new ValueError($"Function dictionary contained a non-function object: {function_dict[key]} (for key {key}).");
}
_function_dict[key] = fn;

var tf_fn = fn; // TODO: deal with type `LayerCall`.

// Warning: this implmentation should be considered again.
var properties = _keras_trackable.GetType().GetProperties();
foreach (var property in properties)
{
if(property.Name == key)
{
property.SetValue(_keras_trackable, tf_fn);
break;
}
}
}
else
{
throw new ValueError($"Function {key} missing from serialized function dict.");
}
}
return Functions;
}

/// <summary>
/// Saves objects to a dictionary, and validates the values.
/// </summary>
/// <param name="object_dict"></param>
public IDictionary<string, Trackable> set_and_validate_objects(IDictionary<string, Trackable> object_dict)
{
foreach(var key in _all_checkpointable_objects)
{
if (object_dict.ContainsKey(key))
{
_object_dict[key] = object_dict[key];
// Warning: this implmentation should be considered again.
var properties = _keras_trackable.GetType().GetProperties();
foreach (var property in properties)
{
if (property.Name == key)
{
property.SetValue(_keras_trackable, object_dict[key]);
break;
}
}
}
else
{
throw new ValueError($"Object {key} missing from serialized object dict.");
}
}
return CheckpointableObjects;
}

/// <summary>
/// Returns a new SerializedAttribute object (corresponding to `new` of tensorflow python).
/// </summary>
/// <returns></returns>
public static SerializedAttributes Create(Trackable obj)
{
if(obj is Model)
{
return new ModelAttributes();
}
else if(obj is Metric)
{
return new MetricAttributes();
}
else if(obj is RNN)
{
return new RNNAttributes();
}
else if(obj is Layer)
{
return new LayerAttributes();
}
else
{
throw new TypeError($"Internal error during serialization: Expected Keras Layer object, got {obj} of type {obj.GetType()}");
}
}

protected virtual (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
{
return (checkpointable_objects ?? (new List<string>()), functions ?? (new List<string>()));
}
}

// Note that the current implementation still has some potential risks.
// The tensorflow python says that this class is "Common endpoints shared by all models loadable by Keras".
// However, currently it's just a normal class.
public class CommonEndPoints: SerializedAttributes
{
protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
{
if(checkpointable_objects is null)
{
checkpointable_objects = new List<string>();
}
if(functions is null)
{
functions = new List<string>();
}
return base.get_objects_and_functions_recursively(
checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }),
// TODO: remove the `__call__`.
functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })
);
}
}

public class LayerAttributes: CommonEndPoints
{
protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
{
if (checkpointable_objects is null)
{
checkpointable_objects = new List<string>();
}
if (functions is null)
{
functions = new List<string>();
}
return base.get_objects_and_functions_recursively(
checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
);
}
}

public class ModelAttributes: LayerAttributes
{
protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
{
if (checkpointable_objects is null)
{
checkpointable_objects = new List<string>();
}
if (functions is null)
{
functions = new List<string>();
}
return base.get_objects_and_functions_recursively(checkpointable_objects,functions);
}
}

public class MetricAttributes : SerializedAttributes
{
protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
{
if (checkpointable_objects is null)
{
checkpointable_objects = new List<string>();
}
if (functions is null)
{
functions = new List<string>();
}
return base.get_objects_and_functions_recursively(
checkpointable_objects.Concat(new string[] { "variables" }),
functions
);
}
}

public class RNNAttributes: LayerAttributes
{
protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions)
{
if (checkpointable_objects is null)
{
checkpointable_objects = new List<string>();
}
if (functions is null)
{
functions = new List<string>();
}
return base.get_objects_and_functions_recursively(
checkpointable_objects.Concat(new string[] { "states" }),
functions
);
}
} }
} }

Loading…
Cancel
Save