| @@ -24,7 +24,7 @@ public class TrackableView | |||||
| Dictionary<string, Trackable> children = new(); | Dictionary<string, Trackable> children = new(); | ||||
| // Note: in python the return type of `Trackable._trackable_children` is not fixed. | // Note: in python the return type of `Trackable._trackable_children` is not fixed. | ||||
| // Therefore it uses `convert_to_trackable` to have an extra process. | // Therefore it uses `convert_to_trackable` to have an extra process. | ||||
| foreach(var pair in obj._trackable_children(save_type)) | |||||
| foreach (var pair in obj._trackable_children(save_type)) | |||||
| { | { | ||||
| children[pair.Key] = pair.Value; | children[pair.Key] = pair.Value; | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ | |||||
| using System; | using System; | ||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -90,4 +91,71 @@ namespace Tensorflow | |||||
| Dispose(false); | Dispose(false); | ||||
| } | } | ||||
| } | } | ||||
| public abstract class DisposableTrackableObject: Trackable, IDisposable | |||||
| { | |||||
| protected IntPtr _handle; | |||||
| protected bool _disposed; | |||||
| protected DisposableTrackableObject() | |||||
| { } | |||||
| protected DisposableTrackableObject(IntPtr handle) | |||||
| => _handle = handle; | |||||
| private void Dispose(bool disposing) | |||||
| { | |||||
| if (_disposed) | |||||
| return; | |||||
| //first handle managed, they might use the unmanaged resources. | |||||
| if (disposing) | |||||
| { | |||||
| // dispose managed state (managed objects). | |||||
| DisposeManagedResources(); | |||||
| } | |||||
| // free unmanaged memory | |||||
| if (_handle != IntPtr.Zero) | |||||
| { | |||||
| // Call the appropriate methods to clean up | |||||
| // unmanaged resources here. | |||||
| // If disposing is false, | |||||
| // only the following code is executed. | |||||
| DisposeUnmanagedResources(_handle); | |||||
| _handle = IntPtr.Zero; | |||||
| } | |||||
| // Note disposing has been done. | |||||
| _disposed = true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Dispose any managed resources. | |||||
| /// </summary> | |||||
| /// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||||
| protected virtual void DisposeManagedResources() | |||||
| { } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||||
| public void Dispose() | |||||
| { | |||||
| Dispose(true); | |||||
| // This object will be cleaned up by the Dispose method. | |||||
| // Therefore, you should call GC.SupressFinalize to | |||||
| // take this object off the finalization queue | |||||
| // and prevent finalization code for this object | |||||
| // from executing a second time. | |||||
| GC.SuppressFinalize(this); | |||||
| } | |||||
| ~DisposableTrackableObject() | |||||
| { | |||||
| Dispose(false); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| @@ -23,10 +23,10 @@ public class AugmentedGraphView: ObjectGraphView | |||||
| list_children(Root); | list_children(Root); | ||||
| } | } | ||||
| public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) | |||||
| public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) | |||||
| { | { | ||||
| Dictionary<string, Trackable> children = new(); | Dictionary<string, Trackable> children = new(); | ||||
| foreach (var pair in base.list_children(obj, save_type)) | |||||
| foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) | |||||
| { | { | ||||
| var name = pair.Name; | var name = pair.Name; | ||||
| var child = pair.Refer; | var child = pair.Refer; | ||||
| @@ -142,21 +142,26 @@ namespace Tensorflow.Training | |||||
| return value; | return value; | ||||
| } | } | ||||
| protected static Trackable wrap_or_unwrap(NoDependency value) | |||||
| public static Trackable wrap_or_unwrap(NoDependency value) | |||||
| { | { | ||||
| return value.Value; | return value.Value; | ||||
| } | } | ||||
| protected static Trackable wrap_or_unwrap(Trackable value) | |||||
| public static Trackable wrap_or_unwrap(Trackable value) | |||||
| { | { | ||||
| return value; | return value; | ||||
| } | } | ||||
| protected static Trackable wrap_or_unwrap(IList<Trackable> value) | |||||
| public static Trackable wrap_or_unwrap(IList<Trackable> value) | |||||
| { | { | ||||
| return new ListWrapper(value); | return new ListWrapper(value); | ||||
| } | } | ||||
| public static Trackable wrap_or_unwrap(IEnumerable<Trackable> value) | |||||
| { | |||||
| return new ListWrapper(value.ToList()); | |||||
| } | |||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | ||||
| { | { | ||||
| value = wrap_or_unwrap(value); | value = wrap_or_unwrap(value); | ||||
| @@ -7,7 +7,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class BaseResourceVariable : DisposableObject | |||||
| public class BaseResourceVariable : DisposableTrackableObject | |||||
| { | { | ||||
| protected string _name; | protected string _name; | ||||
| public virtual string Name => _handle_name; | public virtual string Name => _handle_name; | ||||
| @@ -20,11 +20,12 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| [Obsolete] | [Obsolete] | ||||
| public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
| public partial class RefVariable: Trackable, IVariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
| { | { | ||||
| protected string _name; | protected string _name; | ||||
| public string UniqueId => _name; | public string UniqueId => _name; | ||||
| @@ -288,6 +288,8 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| } | } | ||||
| public List<IVariableV1> Variables => weights; | |||||
| public virtual LayerArgs get_config() | public virtual LayerArgs get_config() | ||||
| => args; | => args; | ||||
| } | } | ||||
| @@ -1,5 +1,8 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | namespace Tensorflow.Keras.Saving.SavedModel; | ||||
| @@ -10,10 +13,54 @@ public partial class KerasSavedModelUtils | |||||
| return false; | return false; | ||||
| } | } | ||||
| public static IDictionary<string, KerasObjectWrapper> wrap_layer_objects(Layer layer, object serialization_cache) | |||||
| /// <summary> | |||||
| /// Returns extra trackable objects to attach to the serialized layer. | |||||
| /// </summary> | |||||
| /// <param name="layer"></param> | |||||
| /// <param name="serialization_cache"></param> | |||||
| /// <returns></returns> | |||||
| public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, object> serialization_cache) | |||||
| { | { | ||||
| // TODO: process the loss | |||||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||||
| return null; | |||||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||||
| var variables = layer.Variables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| }); | |||||
| var trainable_variables = layer.TrainableVariables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| }); | |||||
| var non_trainable_variables = layer.non_trainable_variables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| }); | |||||
| Dictionary<string, Trackable> res = new(); | |||||
| res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); | |||||
| res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); | |||||
| res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); | |||||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||||
| return res; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns dict of wrapped layer call function and losses in tf.functions. | |||||
| /// </summary> | |||||
| /// <param name="layer"></param> | |||||
| /// <param name="serialization_cache"></param> | |||||
| /// <returns></returns> | |||||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, object> serialization_cache) | |||||
| { | |||||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | |||||
| // skip the process because of lack of APIs of `Layer`. | |||||
| return new Dictionary<string, Trackable>(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,10 +17,10 @@ public abstract class SavedModelSaver | |||||
| public abstract string ObjectIdentifier { get; } | public abstract string ObjectIdentifier { get; } | ||||
| public abstract string TrackingMetadata { get; } | public abstract string TrackingMetadata { get; } | ||||
| public abstract IDictionary<string, CheckpointableBase> objects_to_serialize( | |||||
| public abstract IDictionary<string, Trackable> objects_to_serialize( | |||||
| IDictionary<string, object> serialization_cache); | IDictionary<string, object> serialization_cache); | ||||
| public abstract IDictionary<string, Function> functions_to_serialize( | |||||
| public abstract IDictionary<string, Trackable> functions_to_serialize( | |||||
| IDictionary<string, object> serialization_cache); | IDictionary<string, object> serialization_cache); | ||||
| public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache) | ||||
| @@ -32,8 +32,7 @@ public abstract class SavedModelSaver | |||||
| var children = objects_to_serialize(serialization_cache); | var children = objects_to_serialize(serialization_cache); | ||||
| return children.ToDictionary(x => x.Key, x => (Trackable)x.Value) | |||||
| .Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||||
| return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||||
| .ToDictionary(x => x.Key, x => x.Value); | .ToDictionary(x => x.Key, x => x.Value); | ||||
| } | } | ||||
| @@ -19,26 +19,51 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
| get => Constants.LAYER_IDENTIFIER; | get => Constants.LAYER_IDENTIFIER; | ||||
| } | } | ||||
| public override IDictionary<string, CheckpointableBase> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||||
| public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache) | |||||
| { | { | ||||
| throw new System.NotImplementedException(); | |||||
| return get_serialized_attributes(serialization_cache).ObjectsToSerialize; | |||||
| } | } | ||||
| public override IDictionary<string, Function> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||||
| public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache) | |||||
| { | { | ||||
| throw new System.NotImplementedException(); | |||||
| return get_serialized_attributes(serialization_cache).FunctionsToSerialize; | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Generates or retrieves serialized attributes from cache. | /// Generates or retrieves serialized attributes from cache. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="serialization_cache"></param> | /// <param name="serialization_cache"></param> | ||||
| protected void get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||||
| protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||||
| { | { | ||||
| // TODO: deal with cache. | // TODO: deal with cache. | ||||
| Layer a; | |||||
| var serialized_attr = SerializedAttributes.Create(_obj); | |||||
| // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. | |||||
| if (KerasSavedModelUtils.should_skip_serialization(_obj)) | |||||
| { | |||||
| return serialized_attr; | |||||
| } | |||||
| var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache); | |||||
| serialized_attr.set_and_validate_objects(object_dict); | |||||
| serialized_attr.set_and_validate_functions(function_dict); | |||||
| return serialized_attr; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns dictionary of serialized attributes. | |||||
| /// </summary> | |||||
| /// <param name="serialization_cache"></param> | |||||
| private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache) | |||||
| { | |||||
| var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); | |||||
| var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); | |||||
| functions["_default_save_signature"] = null; | |||||
| return (objects, functions); | |||||
| } | } | ||||
| public override string TrackingMetadata | public override string TrackingMetadata | ||||
| @@ -17,15 +17,15 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| public abstract class SerializedAttributes | public abstract class SerializedAttributes | ||||
| { | { | ||||
| protected IDictionary<string, Trackable?> _object_dict; | protected IDictionary<string, Trackable?> _object_dict; | ||||
| protected IDictionary<string, Function?> _function_dict; | |||||
| protected IDictionary<string, Trackable?> _function_dict; | |||||
| protected AutoTrackable _keras_trackable; | protected AutoTrackable _keras_trackable; | ||||
| protected HashSet<string> _all_functions; | protected HashSet<string> _all_functions; | ||||
| protected HashSet<string> _all_checkpointable_objects; | protected HashSet<string> _all_checkpointable_objects; | ||||
| protected SerializedAttributes() | |||||
| private SerializedAttributes() | |||||
| { | { | ||||
| _object_dict= new Dictionary<string, Trackable?>(); | _object_dict= new Dictionary<string, Trackable?>(); | ||||
| _function_dict= new Dictionary<string, Function?>(); | |||||
| _function_dict= new Dictionary<string, Trackable?>(); | |||||
| _keras_trackable= new AutoTrackable(); | _keras_trackable= new AutoTrackable(); | ||||
| _all_functions= new HashSet<string>(); | _all_functions= new HashSet<string>(); | ||||
| _all_checkpointable_objects= new HashSet<string>(); | _all_checkpointable_objects= new HashSet<string>(); | ||||
| @@ -34,25 +34,35 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| protected SerializedAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) | protected SerializedAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) | ||||
| { | { | ||||
| _object_dict = new Dictionary<string, Trackable?>(); | _object_dict = new Dictionary<string, Trackable?>(); | ||||
| _function_dict = new Dictionary<string, Function?>(); | |||||
| _function_dict = new Dictionary<string, Trackable?>(); | |||||
| _keras_trackable = new AutoTrackable(); | _keras_trackable = new AutoTrackable(); | ||||
| _all_checkpointable_objects = new HashSet<string>(checkpointable_objects); | _all_checkpointable_objects = new HashSet<string>(checkpointable_objects); | ||||
| _all_functions = new HashSet<string>(functions); | _all_functions = new HashSet<string>(functions); | ||||
| } | } | ||||
| public IDictionary<string, Function> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||||
| protected SerializedAttributes((IEnumerable<string>, IEnumerable<string>) objects_and_functions) | |||||
| { | |||||
| _object_dict = new Dictionary<string, Trackable?>(); | |||||
| _function_dict = new Dictionary<string, Trackable?>(); | |||||
| _keras_trackable = new AutoTrackable(); | |||||
| _all_checkpointable_objects = new HashSet<string>(objects_and_functions.Item1); | |||||
| _all_functions = new HashSet<string>(objects_and_functions.Item2); | |||||
| } | |||||
| public IDictionary<string, Trackable> Functions => _function_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | |||||
| public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | public IDictionary<string, Trackable> CheckpointableObjects => _object_dict.TakeWhile(x => x.Value is not null).ToDictionary(x => x.Key, x => x.Value!); | ||||
| /// <summary> | /// <summary> | ||||
| /// Returns functions to attach to the root object during serialization. | /// Returns functions to attach to the root object during serialization. | ||||
| /// </summary> | /// </summary> | ||||
| public IDictionary<string, Function> FunctionsToSerialize | |||||
| public IDictionary<string, Trackable> FunctionsToSerialize | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| Dictionary<string, Function> functions = new(); | |||||
| Dictionary<string, Trackable> functions = new(); | |||||
| foreach(var pair in Functions) | foreach(var pair in Functions) | ||||
| { | { | ||||
| if (_all_functions.Contains(pair.Key)) | if (_all_functions.Contains(pair.Key)) | ||||
| @@ -82,7 +92,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| /// Saves function dictionary, and validates dictionary values. | /// Saves function dictionary, and validates dictionary values. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="function_dict"></param> | /// <param name="function_dict"></param> | ||||
| public IDictionary<string, Function> set_and_validate_functions(IDictionary<string, Function> function_dict) | |||||
| public IDictionary<string, Trackable> set_and_validate_functions(IDictionary<string, Trackable> function_dict) | |||||
| { | { | ||||
| foreach(var key in _all_functions) | foreach(var key in _all_functions) | ||||
| { | { | ||||
| @@ -186,94 +196,87 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| // However, currently it's just a normal class. | // However, currently it's just a normal class. | ||||
| public class CommonEndPoints: SerializedAttributes | public class CommonEndPoints: SerializedAttributes | ||||
| { | { | ||||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||||
| public CommonEndPoints(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||||
| //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||||
| // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) | |||||
| base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), | |||||
| functions.Concat(new string[] { })) | |||||
| { | { | ||||
| if(checkpointable_objects is null) | |||||
| { | |||||
| checkpointable_objects = new List<string>(); | |||||
| } | |||||
| if(functions is null) | |||||
| { | |||||
| functions = new List<string>(); | |||||
| } | |||||
| return base.get_objects_and_functions_recursively( | |||||
| checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), | |||||
| // TODO: remove the `__call__`. | |||||
| functions.Concat(new string[] {"__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||||
| ); | |||||
| } | |||||
| public CommonEndPoints() : | |||||
| //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, | |||||
| // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) | |||||
| base(new string[] { "variables", "trainable_variables"}, | |||||
| new string[] {}) | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| public class LayerAttributes: CommonEndPoints | public class LayerAttributes: CommonEndPoints | ||||
| { | { | ||||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||||
| public LayerAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||||
| //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), | |||||
| // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||||
| base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), | |||||
| functions.Concat(new string[] { })) | |||||
| { | { | ||||
| if (checkpointable_objects is null) | |||||
| { | |||||
| checkpointable_objects = new List<string>(); | |||||
| } | |||||
| if (functions is null) | |||||
| { | |||||
| functions = new List<string>(); | |||||
| } | |||||
| return base.get_objects_and_functions_recursively( | |||||
| checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), | |||||
| functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||||
| ); | |||||
| } | |||||
| public LayerAttributes() : | |||||
| //base(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }, | |||||
| // new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) | |||||
| base(new string[] { "non_trainable_variables", "layers" }, | |||||
| new string[] { }) | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| public class ModelAttributes: LayerAttributes | public class ModelAttributes: LayerAttributes | ||||
| { | { | ||||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||||
| public ModelAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions): | |||||
| base(checkpointable_objects, functions) | |||||
| { | { | ||||
| if (checkpointable_objects is null) | |||||
| { | |||||
| checkpointable_objects = new List<string>(); | |||||
| } | |||||
| if (functions is null) | |||||
| { | |||||
| functions = new List<string>(); | |||||
| } | |||||
| return base.get_objects_and_functions_recursively(checkpointable_objects,functions); | |||||
| } | |||||
| public ModelAttributes(): base() | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| public class MetricAttributes : SerializedAttributes | public class MetricAttributes : SerializedAttributes | ||||
| { | { | ||||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||||
| public MetricAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||||
| base(checkpointable_objects.Concat(new string[] { "variables" }), functions) | |||||
| { | { | ||||
| if (checkpointable_objects is null) | |||||
| { | |||||
| checkpointable_objects = new List<string>(); | |||||
| } | |||||
| if (functions is null) | |||||
| { | |||||
| functions = new List<string>(); | |||||
| } | |||||
| return base.get_objects_and_functions_recursively( | |||||
| checkpointable_objects.Concat(new string[] { "variables" }), | |||||
| functions | |||||
| ); | |||||
| } | |||||
| public MetricAttributes() : | |||||
| base(new string[] { "variables" }, new string[] {}) | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| public class RNNAttributes: LayerAttributes | public class RNNAttributes: LayerAttributes | ||||
| { | { | ||||
| protected override (IEnumerable<string>, IEnumerable<string>) get_objects_and_functions_recursively(IEnumerable<string>? checkpointable_objects, IEnumerable<string>? functions) | |||||
| public RNNAttributes(IEnumerable<string> checkpointable_objects, IEnumerable<string> functions) : | |||||
| base(checkpointable_objects, functions.Concat(new string[] {"states"})) | |||||
| { | { | ||||
| if (checkpointable_objects is null) | |||||
| { | |||||
| checkpointable_objects = new List<string>(); | |||||
| } | |||||
| if (functions is null) | |||||
| { | |||||
| functions = new List<string>(); | |||||
| } | |||||
| return base.get_objects_and_functions_recursively( | |||||
| checkpointable_objects.Concat(new string[] { "states" }), | |||||
| functions | |||||
| ); | |||||
| } | |||||
| public RNNAttributes() : | |||||
| base(new string[] { }, new string[] { "states" }) | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | namespace Tensorflow.Keras.Saving.SavedModel; | ||||
| @@ -12,6 +14,18 @@ public partial class KerasSavedModelUtils | |||||
| ShouldHaveTraces = save_traces; | ShouldHaveTraces = save_traces; | ||||
| return res; | return res; | ||||
| } | } | ||||
| public static IEnumerable<ILayer> list_all_layers(Layer layer) | |||||
| { | |||||
| if(layer is Model) | |||||
| { | |||||
| return (layer as Model).Layers; | |||||
| } | |||||
| else | |||||
| { | |||||
| return new List<ILayer>(layer._flatten_layers(false, false)); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||