| @@ -24,7 +24,7 @@ public class TrackableView | |||
| Dictionary<string, Trackable> children = new(); | |||
| // Note: in python the return type of `Trackable._trackable_children` is not fixed. | |||
| // Therefore it uses `convert_to_trackable` to have an extra process. | |||
| foreach(var pair in obj._trackable_children(save_type)) | |||
| foreach (var pair in obj._trackable_children(save_type)) | |||
| { | |||
| children[pair.Key] = pair.Value; | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| using System; | |||
| using System.Diagnostics.CodeAnalysis; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -90,4 +91,71 @@ namespace Tensorflow | |||
| Dispose(false); | |||
| } | |||
| } | |||
| public abstract class DisposableTrackableObject: Trackable, IDisposable | |||
| { | |||
| protected IntPtr _handle; | |||
| protected bool _disposed; | |||
| protected DisposableTrackableObject() | |||
| { } | |||
| protected DisposableTrackableObject(IntPtr handle) | |||
| => _handle = handle; | |||
| private void Dispose(bool disposing) | |||
| { | |||
| if (_disposed) | |||
| return; | |||
| //first handle managed, they might use the unmanaged resources. | |||
| if (disposing) | |||
| { | |||
| // dispose managed state (managed objects). | |||
| DisposeManagedResources(); | |||
| } | |||
| // free unmanaged memory | |||
| if (_handle != IntPtr.Zero) | |||
| { | |||
| // Call the appropriate methods to clean up | |||
| // unmanaged resources here. | |||
| // If disposing is false, | |||
| // only the following code is executed. | |||
| DisposeUnmanagedResources(_handle); | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| // Note disposing has been done. | |||
| _disposed = true; | |||
| } | |||
| /// <summary> | |||
| /// Dispose any managed resources. | |||
| /// </summary> | |||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||
| protected virtual void DisposeManagedResources() | |||
| { } | |||
| /// <summary> | |||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
| /// </summary> | |||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||
| public void Dispose() | |||
| { | |||
| Dispose(true); | |||
| // This object will be cleaned up by the Dispose method. | |||
| // Therefore, you should call GC.SupressFinalize to | |||
| // take this object off the finalization queue | |||
| // and prevent finalization code for this object | |||
| // from executing a second time. | |||
| GC.SuppressFinalize(this); | |||
| } | |||
| ~DisposableTrackableObject() | |||
| { | |||
| Dispose(false); | |||
| } | |||
| } | |||
| } | |||
| @@ -23,10 +23,10 @@ public class AugmentedGraphView: ObjectGraphView | |||
| list_children(Root); | |||
| } | |||
| public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||
| public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) | |||
| { | |||
| Dictionary<string, Trackable> children = new(); | |||
| foreach (var pair in base.list_children(obj, save_type)) | |||
| foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) | |||
| { | |||
| var name = pair.Name; | |||
| var child = pair.Refer; | |||
| @@ -142,21 +142,26 @@ namespace Tensorflow.Training | |||
| return value; | |||
| } | |||
| protected static Trackable wrap_or_unwrap(NoDependency value) | |||
| public static Trackable wrap_or_unwrap(NoDependency value) | |||
| { | |||
| return value.Value; | |||
| } | |||
| protected static Trackable wrap_or_unwrap(Trackable value) | |||
| public static Trackable wrap_or_unwrap(Trackable value) | |||
| { | |||
| return value; | |||
| } | |||
| protected static Trackable wrap_or_unwrap(IList<Trackable> value) | |||
| public static Trackable wrap_or_unwrap(IList<Trackable> value) | |||
| { | |||
| return new ListWrapper(value); | |||
| } | |||
| public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value) | |||
| { | |||
| return new ListWrapper(value.ToList()); | |||
| } | |||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | |||
| { | |||
| value = wrap_or_unwrap(value); | |||
| @@ -7,7 +7,7 @@ using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class BaseResourceVariable : DisposableObject | |||
| public class BaseResourceVariable : DisposableTrackableObject | |||
| { | |||
| protected string _name; | |||
| public virtual string Name => _handle_name; | |||
| @@ -20,11 +20,12 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow | |||
| { | |||
| [Obsolete] | |||
| public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||
| public partial class RefVariable: Trackable, IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||
| { | |||
| protected string _name; | |||
| public string UniqueId => _name; | |||
| @@ -288,6 +288,8 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| public List<IVariableV1> Variables => weights; | |||
| public virtual LayerArgs get_config() | |||
| => args; | |||
| } | |||
| @@ -1,5 +1,8 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| @@ -10,10 +13,54 @@ public partial class KerasSavedModelUtils | |||
| return false; | |||
| } | |||
| public static IDictionary<string, KerasObjectWrapper> wrap_layer_objects(Layer layer, object serialization_cache) | |||
| /// <summary> | |||
| /// Returns extra trackable objects to attach to the serialized layer. | |||
| /// </summary> | |||
| /// <param name="layer"></param> | |||
| /// <param name="serialization_cache"></param> | |||
| /// <returns></returns> | |||
| public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, object> serialization_cache) | |||
| { | |||
| // TODO: process the loss | |||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||
| return null; | |||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||
| var variables = layer.Variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| }); | |||
| var trainable_variables = layer.TrainableVariables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| }); | |||
| var non_trainable_variables = layer.non_trainable_variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| }); | |||
| Dictionary<string, Trackable> res = new(); | |||
| res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); | |||
| res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); | |||
| res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); | |||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
| return res; | |||
| } | |||
| /// <summary> | |||
| /// Returns dict of wrapped layer call function and losses in tf.functions. | |||
| /// </summary> | |||
| /// <param name="layer"></param> | |||
| /// <param name="serialization_cache"></param> | |||
| /// <returns></returns> | |||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, object> serialization_cache) | |||
| { | |||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | |||
| // skip the process because of lack of APIs of `Layer`. | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| } | |||
| @@ -17,10 +17,10 @@ public abstract class SavedModelSaver | |||
| public abstract string ObjectIdentifier { get; } | |||
| public abstract string TrackingMetadata { get; } | |||
| public abstract IDictionary<string, CheckpointableBase> objects_to_serialize( | |||
| public abstract IDictionary<string, Trackable> objects_to_serialize( | |||
| IDictionary<string, object> serialization_cache); | |||
| public abstract IDictionary<string, Function> functions_to_serialize( | |||
| public abstract IDictionary<string, Trackable> functions_to_serialize( | |||
| IDictionary<string, object> serialization_cache); | |||
| public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | |||
| @@ -32,8 +32,7 @@ public abstract class SavedModelSaver | |||
| var children = objects_to_serialize(serialization_cache); | |||
| return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) | |||
| .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||
| return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||
| .ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| @@ -19,26 +19,51 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
| get => Constants.LAYER_IDENTIFIER; | |||
| } | |||
| public override IDictionary<string, CheckpointableBase> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||
| public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||
| { | |||
| throw new System.NotImplementedException(); | |||
| return get_serialized_attributes(serialization_cache).ObjectsToSerialize; | |||
| } | |||
| public override IDictionary<string, Function> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||
| public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||
| { | |||
| throw new System.NotImplementedException(); | |||
| return get_serialized_attributes(serialization_cache).FunctionsToSerialize; | |||
| } | |||
| /// <summary> | |||
| /// Generates or retrieves serialized attributes from cache. | |||
| /// </summary> | |||
| /// <param name="serialization_cache"></param> | |||
| protected void get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||
| protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||
| { | |||
| // TODO: deal with cache. | |||
| Layer a; | |||
| var serialized_attr = SerializedAttributes.Create(_obj); | |||
| // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. | |||
| if (KerasSavedModelUtils.should_skip_serialization(_obj)) | |||
| { | |||
| return serialized_attr; | |||
| } | |||
| var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache); | |||
| serialized_attr.set_and_validate_objects(object_dict); | |||
| serialized_attr.set_and_validate_functions(function_dict); | |||
| return serialized_attr; | |||
| } | |||
| /// <summary> | |||
| /// Returns dictionary of serialized attributes. | |||
| /// </summary> | |||
| /// <param name="serialization_cache"></param> | |||
| private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache) | |||
| { | |||
| var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); | |||
| var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); | |||
| functions["_default_save_signature"] = null; | |||
| return (objects, functions); | |||
| } | |||
| public override string TrackingMetadata | |||
| @@ -17,15 +17,15 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| public abstract class SerializedAttributes | |||
| { | |||
| protected IDictionary<string, Trackable?> _object_dict; | |||
| protected IDictionary<string, Function?> _function_dict; | |||
| protected IDictionary<string, Trackable?> _function_dict; | |||
| protected AutoTrackable _keras_trackable; | |||
| protected HashSet<string> _all_functions; | |||
| protected HashSet<string> _all_checkpointable_objects; | |||
| protected SerializedAttributes() | |||
| private SerializedAttributes() | |||
| { | |||
| _object_dict= new Dictionary<string, Trackable?>(); | |||
| _function_dict= new Dictionary<string, Function?>(); | |||
| _function_dict= new Dictionary<string, Trackable?>(); | |||
| _keras_trackable= new AutoTrackable(); | |||
| _all_functions= new HashSet<string>(); | |||
| _all_checkpointable_objects= new HashSet<string>(); | |||
| @@ -34,25 +34,35 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| protected SerializedAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) | |||
| { | |||
| _object_dict = new Dictionary<string, Trackable?>(); | |||
| _function_dict = new Dictionary<string, Function?>(); | |||
| _function_dict = new Dictionary<string, Trackable?>(); | |||
| _keras_trackable = new AutoTrackable(); | |||
| _all_checkpointable_objects = new HashSet<string>(checkpointable_objects); | |||
| _all_functions = new HashSet<string>(functions); | |||
| } | |||
| public IDictionary<string, Function> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| protected SerializedAttributes((IEnumerable<string>, IEnumerable<string>) objects_and_functions) | |||
| { | |||
| _object_dict = new Dictionary<string, Trackable?>(); | |||
| _function_dict = new Dictionary<string, Trackable?>(); | |||
| _keras_trackable = new AutoTrackable(); | |||
| _all_checkpointable_objects = new HashSet<string>(objects_and_functions.Item1); | |||
| _all_functions = new HashSet<string>(objects_and_functions.Item2); | |||
| } | |||
| public IDictionary<string, Trackable> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||
| /// <summary> | |||
| /// Returns functions to attach to the root object during serialization. | |||
| /// </summary> | |||
| public IDictionary<string, Function> FunctionsToSerialize | |||
| public IDictionary<string, Trackable> FunctionsToSerialize | |||
| { | |||
| get | |||
| { | |||
| Dictionary<string, Function> functions = new(); | |||
| Dictionary<string, Trackable> functions = new(); | |||
| foreach(var pair in Functions) | |||
| { | |||
| if (_all_functions.Contains(pair.Key)) | |||
| @@ -82,7 +92,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| /// Saves function dictionary, and validates dictionary values. | |||
| /// </summary> | |||
| /// <param name="function_dict"></param> | |||
| public IDictionary<string, Function> set_and_validate_functions(IDictionary<string, Function> function_dict) | |||
| public IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict) | |||
| { | |||
| foreach(var key in _all_functions) | |||
| { | |||
| @@ -186,94 +196,87 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| // However, currently it's just a normal class. | |||
| public class CommonEndPoints: SerializedAttributes | |||
| { | |||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||
| public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||
| //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||
| // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) | |||
| base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), | |||
| functions.Concat(new string[] { })) | |||
| { | |||
| if(checkpointable_objects is null) | |||
| { | |||
| checkpointable_objects = new List<string>(); | |||
| } | |||
| if(functions is null) | |||
| { | |||
| functions = new List<string>(); | |||
| } | |||
| return base.get_objects_and_functions_recursively( | |||
| checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||
| // TODO: remove the `__call__`. | |||
| functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||
| ); | |||
| } | |||
| public CommonEndPoints() : | |||
| //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, | |||
| // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||
| base(new string[] { "variables", "trainable_variables"}, | |||
| new string[] {}) | |||
| { | |||
| } | |||
| } | |||
| public class LayerAttributes: CommonEndPoints | |||
| { | |||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||
| public LayerAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||
| //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), | |||
| // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||
| base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), | |||
| functions.Concat(new string[] { })) | |||
| { | |||
| if (checkpointable_objects is null) | |||
| { | |||
| checkpointable_objects = new List<string>(); | |||
| } | |||
| if (functions is null) | |||
| { | |||
| functions = new List<string>(); | |||
| } | |||
| return base.get_objects_and_functions_recursively( | |||
| checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), | |||
| functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||
| ); | |||
| } | |||
| public LayerAttributes() : | |||
| //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }, | |||
| // new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||
| base(new string[] { "non_trainable_variables", "layers" }, | |||
| new string[] { }) | |||
| { | |||
| } | |||
| } | |||
| public class ModelAttributes: LayerAttributes | |||
| { | |||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||
| public ModelAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions): | |||
| base(checkpointable_objects, functions) | |||
| { | |||
| if (checkpointable_objects is null) | |||
| { | |||
| checkpointable_objects = new List<string>(); | |||
| } | |||
| if (functions is null) | |||
| { | |||
| functions = new List<string>(); | |||
| } | |||
| return base.get_objects_and_functions_recursively(checkpointable_objects,functions); | |||
| } | |||
| public ModelAttributes(): base() | |||
| { | |||
| } | |||
| } | |||
| public class MetricAttributes : SerializedAttributes | |||
| { | |||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||
| public MetricAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||
| base(checkpointable_objects.Concat(new string[] { "variables" }), functions) | |||
| { | |||
| if (checkpointable_objects is null) | |||
| { | |||
| checkpointable_objects = new List<string>(); | |||
| } | |||
| if (functions is null) | |||
| { | |||
| functions = new List<string>(); | |||
| } | |||
| return base.get_objects_and_functions_recursively( | |||
| checkpointable_objects.Concat(new string[] { "variables" }), | |||
| functions | |||
| ); | |||
| } | |||
| public MetricAttributes() : | |||
| base(new string[] { "variables" }, new string[] {}) | |||
| { | |||
| } | |||
| } | |||
| public class RNNAttributes: LayerAttributes | |||
| { | |||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||
| public RNNAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||
| base(checkpointable_objects, functions.Concat(new string[] {"states"})) | |||
| { | |||
| if (checkpointable_objects is null) | |||
| { | |||
| checkpointable_objects = new List<string>(); | |||
| } | |||
| if (functions is null) | |||
| { | |||
| functions = new List<string>(); | |||
| } | |||
| return base.get_objects_and_functions_recursively( | |||
| checkpointable_objects.Concat(new string[] { "states" }), | |||
| functions | |||
| ); | |||
| } | |||
| public RNNAttributes() : | |||
| base(new string[] { }, new string[] { "states" }) | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| @@ -12,6 +14,18 @@ public partial class KerasSavedModelUtils | |||
| ShouldHaveTraces = save_traces; | |||
| return res; | |||
| } | |||
| public static IEnumerable<ILayer> list_all_layers(Layer layer) | |||
| { | |||
| if(layer is Model) | |||
| { | |||
| return (layer as Model).Layers; | |||
| } | |||
| else | |||
| { | |||
| return new List<ILayer>(layer._flatten_layers(false, false)); | |||
| } | |||
| } | |||
| } | |||
| /// <summary> | |||