| @@ -1,10 +1,11 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| public interface ILayer | |||||
| public interface ILayer: ITrackable | |||||
| { | { | ||||
| string Name { get; } | string Name { get; } | ||||
| bool Trainable { get; } | bool Trainable { get; } | ||||
| @@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -147,5 +148,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public Trackable GetTrackable() { throw new NotImplementedException(); } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,12 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| public interface ITrackable | |||||
| { | |||||
| Trackable GetTrackable(); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,9 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Train; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| } | |||||
| @@ -18,11 +18,12 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.ModelSaving; | using Tensorflow.ModelSaving; | ||||
| using Tensorflow.Training; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
| { | { | ||||
| public abstract class Trackable | |||||
| public abstract class Trackable: ITrackable | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Corresponding to tensorflow/python/trackable/constants.py | /// Corresponding to tensorflow/python/trackable/constants.py | ||||
| @@ -40,6 +41,7 @@ namespace Tensorflow.Train | |||||
| protected IDictionary<string, ResourceVariable> _self_saveable_object_factories = | protected IDictionary<string, ResourceVariable> _self_saveable_object_factories = | ||||
| new Dictionary<string, ResourceVariable>(); | new Dictionary<string, ResourceVariable>(); | ||||
| private bool _manual_tracking = true; | |||||
| private static Trackable _none = new Function(); | private static Trackable _none = new Function(); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -54,6 +56,10 @@ namespace Tensorflow.Train | |||||
| return _none; | return _none; | ||||
| } | } | ||||
| } | } | ||||
| public Trackable GetTrackable() | |||||
| { | |||||
| return this; | |||||
| } | |||||
| public virtual string ObjectIdentifier | public virtual string ObjectIdentifier | ||||
| { | { | ||||
| get => "_generic_user_object"; | get => "_generic_user_object"; | ||||
| @@ -128,6 +134,48 @@ namespace Tensorflow.Train | |||||
| return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); | ||||
| } | } | ||||
| public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false) | |||||
| { | |||||
| _maybe_initialize_trackable(); | |||||
| if (!_manual_tracking) return trackable; | |||||
| var new_reference = new TrackableReference(name, trackable); | |||||
| var current_object = _lookupup_dependency(name); | |||||
| if(current_object is null) | |||||
| { | |||||
| _unconditional_checkpoint_dependencies.Add(new_reference); | |||||
| _handle_deferred_dependencies(name, trackable); | |||||
| } | |||||
| _unconditional_dependency_names[name] = trackable; | |||||
| return trackable; | |||||
| } | |||||
| /// <summary> | |||||
| /// Pop and load any deferred checkpoint restores into `trackable`. | |||||
| /// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for | |||||
| /// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are | |||||
| /// considered fulfilled and so are deleted. | |||||
| /// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations. | |||||
| /// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves, | |||||
| /// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context | |||||
| /// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs). | |||||
| /// </summary> | |||||
| /// <param name="name"></param> | |||||
| /// <param name="trackable"></param> | |||||
| public virtual void _handle_deferred_dependencies(string name, Trackable trackable) | |||||
| { | |||||
| //_maybe_initialize_trackable(); | |||||
| //trackable._maybe_initialize_trackable(); | |||||
| // TODO: complete the implementation. | |||||
| } | |||||
| public virtual Trackable? _lookupup_dependency(string name) | |||||
| { | |||||
| if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; | |||||
| else return null; | |||||
| } | |||||
| public static Trackable convert_to_trackable(object obj, object? parent = null) | public static Trackable convert_to_trackable(object obj, object? parent = null) | ||||
| { | { | ||||
| if (obj is Trackable) | if (obj is Trackable) | ||||
| @@ -0,0 +1,364 @@ | |||||
| using Google.Protobuf; | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.IO.Compression; | |||||
| using System.Linq; | |||||
| using System.Linq.Expressions; | |||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | |||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Keras; | |||||
| using Tensorflow.Operations.Activation; | |||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.ApiDef.Types; | |||||
| namespace Tensorflow.Training | |||||
| { | |||||
| public class NoDependency | |||||
| { | |||||
| public Trackable Value { get; set; } | |||||
| public NoDependency(Trackable value) | |||||
| { | |||||
| Value = value; | |||||
| } | |||||
| } | |||||
| public abstract class TrackableDataStructure : Trackable | |||||
| { | |||||
| private bool _self_trainable; | |||||
| private List<IVariableV1> _self_extra_variables; | |||||
| public TrackableDataStructure() | |||||
| { | |||||
| _self_trainable = true; | |||||
| _self_extra_variables = new List<IVariableV1>(); | |||||
| } | |||||
| public abstract IEnumerable<Trackable> Values { get; } | |||||
| public bool Trainable { get => _self_trainable; set => _self_trainable = value; } | |||||
| public IEnumerable<ILayer> Layers | |||||
| { | |||||
| get | |||||
| { | |||||
| List<ILayer> collected = new(); | |||||
| foreach(var obj in Values) | |||||
| { | |||||
| if(obj is ILayer) | |||||
| { | |||||
| collected.Add((ILayer)obj); | |||||
| } | |||||
| else if(obj is TrackableDataStructure) | |||||
| { | |||||
| collected.AddRange((obj as TrackableDataStructure).Layers); | |||||
| } | |||||
| } | |||||
| return collected; | |||||
| } | |||||
| } | |||||
| public IEnumerable<IVariableV1> TrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!_self_trainable) | |||||
| { | |||||
| return new List<IVariableV1>(); | |||||
| } | |||||
| List<IVariableV1> trainable_variables = new(); | |||||
| foreach (var obj in Values) | |||||
| { | |||||
| // skip the process of `module.Module`. | |||||
| if (obj is TrackableDataStructure) | |||||
| { | |||||
| trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); | |||||
| } | |||||
| } | |||||
| foreach(var v in _self_extra_variables) | |||||
| { | |||||
| if (v.Trainable) | |||||
| { | |||||
| trainable_variables.Add(v); | |||||
| } | |||||
| } | |||||
| return trainable_variables; | |||||
| } | |||||
| } | |||||
| public IEnumerable<IVariableV1> NonTrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| var trainable_extra_variables = _self_extra_variables.TakeWhile(x => x.Trainable).ToList(); | |||||
| var non_trainable_extra_variables = _self_extra_variables.TakeWhile(x => !x.Trainable).ToList(); | |||||
| List<IVariableV1> non_trainable_variables = new(); | |||||
| foreach(var obj in Values) | |||||
| { | |||||
| // skip the process of `module.Module`. | |||||
| if (obj is TrackableDataStructure) | |||||
| { | |||||
| non_trainable_variables.AddRange((obj as TrackableDataStructure).NonTrainableVariables); | |||||
| } | |||||
| } | |||||
| if (!_self_trainable) | |||||
| { | |||||
| // Return order is all trainable vars, then all non-trainable vars. | |||||
| List<IVariableV1> trainable_variables = new(); | |||||
| foreach(var obj in Values) | |||||
| { | |||||
| // skip the process of `module.Module`. | |||||
| if (obj is TrackableDataStructure) | |||||
| { | |||||
| trainable_variables.AddRange((obj as TrackableDataStructure).TrainableVariables); | |||||
| } | |||||
| } | |||||
| return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables); | |||||
| } | |||||
| else | |||||
| { | |||||
| return non_trainable_variables.concat(non_trainable_extra_variables); | |||||
| } | |||||
| } | |||||
| } | |||||
| public IEnumerable<IVariableV1> Weights => TrainableWeights.Concat(NonTrainableWeights); | |||||
| public IEnumerable<IVariableV1> TrainableVariables => TrainableWeights; | |||||
| public IEnumerable<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||||
| public IEnumerable<IVariableV1> Variables => Weights; | |||||
| // TODO: `losses` property. | |||||
| /// <summary> | |||||
| /// Add a dependency on `value`. | |||||
| /// </summary> | |||||
| /// <param name="value"></param> | |||||
| /// <param name="name"></param> | |||||
| protected virtual Trackable _track_value(Trackable value, string name) | |||||
| { | |||||
| value = sticky_attribute_assignment(this, name, value); | |||||
| if(value is IVariableV1) | |||||
| { | |||||
| _self_extra_variables.Add(value as IVariableV1); | |||||
| } | |||||
| // skip the left process (need to be done in the future). | |||||
| return value; | |||||
| } | |||||
| protected static Trackable wrap_or_unwrap(NoDependency value) | |||||
| { | |||||
| return value.Value; | |||||
| } | |||||
| protected static Trackable wrap_or_unwrap(Trackable value) | |||||
| { | |||||
| return value; | |||||
| } | |||||
| protected static Trackable wrap_or_unwrap(IList<Trackable> value) | |||||
| { | |||||
| return new ListWrapper(value); | |||||
| } | |||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) | |||||
| { | |||||
| value = wrap_or_unwrap(value); | |||||
| trackable._track_trackable(value, name, true); | |||||
| return value; | |||||
| } | |||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) | |||||
| { | |||||
| var wrapped_value = wrap_or_unwrap(value); | |||||
| trackable._track_trackable(wrapped_value, name, true); | |||||
| return wrapped_value; | |||||
| } | |||||
| protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value) | |||||
| { | |||||
| var wrapped_value = wrap_or_unwrap(value); | |||||
| trackable._track_trackable(wrapped_value, name, true); | |||||
| return wrapped_value; | |||||
| } | |||||
| } | |||||
| public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable | |||||
| { | |||||
| private IList<Trackable> _storage; | |||||
| private bool _non_append_mutation_value; | |||||
| private bool _external_modification_value; | |||||
| private IList<Trackable> _last_wrapped_list_snapshot; | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="wrapped_list">The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be | |||||
| /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param> | |||||
| public ListWrapper(IList<Trackable> wrapped_list) | |||||
| { | |||||
| _storage = wrapped_list; | |||||
| _non_append_mutation_value = _external_modification_value = false; | |||||
| _last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||||
| } | |||||
| protected bool NonAppendMuation { | |||||
| get => _non_append_mutation_value; | |||||
| set | |||||
| { | |||||
| // TODO: deal with `attribute_sentinel`. | |||||
| _non_append_mutation_value = value; | |||||
| } | |||||
| } | |||||
| protected bool ExternalModification | |||||
| { | |||||
| get => _external_modification_value; | |||||
| set | |||||
| { | |||||
| // TODO: deal with `attribute_sentinel`. | |||||
| _external_modification_value = value; | |||||
| } | |||||
| } | |||||
| public override IEnumerable<Trackable> Values => this; | |||||
| public bool IsReadOnly { get => _storage.IsReadOnly; } | |||||
| /// <summary> | |||||
| /// Checks for any changes to the wrapped list not through the wrapper. | |||||
| /// </summary> | |||||
| private void check_external_modification() | |||||
| { | |||||
| if (_external_modification_value || _non_append_mutation_value) return; | |||||
| if (!_storage.SequenceEqual(_last_wrapped_list_snapshot)) | |||||
| { | |||||
| _external_modification_value = true; | |||||
| } | |||||
| } | |||||
| private void update_snapshot() | |||||
| { | |||||
| // TODO: deal with `attribute_sentinel`. | |||||
| if (_external_modification_value || _non_append_mutation_value) return; | |||||
| _last_wrapped_list_snapshot = new List<Trackable>(_storage); | |||||
| } | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| check_external_modification(); | |||||
| if (_non_append_mutation_value) | |||||
| { | |||||
| throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" + | |||||
| $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." + | |||||
| $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored."); | |||||
| } | |||||
| if (_external_modification_value) | |||||
| { | |||||
| throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified " + | |||||
| $"outside the wrapper (its final value was {_storage}, its value when a checkpoint dependency was added was {_last_wrapped_list_snapshot}), which breaks " + | |||||
| $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored."); | |||||
| } | |||||
| var children = base._trackable_children(save_type, cache); | |||||
| if(save_type == SaveType.SAVEDMODEL) | |||||
| { | |||||
| children = children.Concat(this.TakeWhile(x => x is Function or ConcreteFunction).Select((x, idx) => new KeyValuePair<string, Trackable>(idx.ToString(), x))).ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| return children; | |||||
| } | |||||
| private bool has_mutation_or_trackable() | |||||
| { | |||||
| return _non_append_mutation_value; | |||||
| } | |||||
| /// <summary> | |||||
| /// Allows storage of non-trackable objects. | |||||
| /// </summary> | |||||
| /// <param name="value"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| protected override Trackable _track_value(Trackable value, string name) | |||||
| { | |||||
| try | |||||
| { | |||||
| base._track_value(value, name); | |||||
| } | |||||
| catch(ValueError ex) | |||||
| { | |||||
| value = sticky_attribute_assignment(this, name, value); | |||||
| } | |||||
| return value; | |||||
| } | |||||
| public object Clone() | |||||
| { | |||||
| var res = new ListWrapper(_storage.Select(x => x).ToList()); | |||||
| res.NonAppendMuation= _non_append_mutation_value; | |||||
| res.ExternalModification = _external_modification_value; | |||||
| return res; | |||||
| } | |||||
| public Trackable this[int index] { | |||||
| get => _storage[index]; | |||||
| set | |||||
| { | |||||
| // skip the process of `Slice`, maybe support it in the future. | |||||
| _non_append_mutation_value = true; | |||||
| _storage[index] = _track_value(value, _name_element(index)); | |||||
| update_snapshot(); | |||||
| } | |||||
| } | |||||
| public int IndexOf(Trackable item) => _storage.IndexOf(item); | |||||
| public void Insert(int index, Trackable item) | |||||
| { | |||||
| check_external_modification(); | |||||
| _non_append_mutation_value = true; | |||||
| _storage.Insert(index, item); | |||||
| update_snapshot(); | |||||
| } | |||||
| public void RemoveAt(int index) | |||||
| { | |||||
| check_external_modification(); | |||||
| if (has_mutation_or_trackable()) | |||||
| { | |||||
| _non_append_mutation_value = true; | |||||
| } | |||||
| _storage.RemoveAt(index); | |||||
| update_snapshot(); | |||||
| } | |||||
| public int Count { get => _storage.Count; } | |||||
| public void Add(Trackable item) | |||||
| { | |||||
| check_external_modification(); | |||||
| _storage.Add(item); | |||||
| update_snapshot(); | |||||
| } | |||||
| public void Clear() => _storage.Clear(); | |||||
| public bool Contains(Trackable item) => _storage.Contains(item); | |||||
| public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex); | |||||
| public bool Remove(Trackable item) | |||||
| { | |||||
| check_external_modification(); | |||||
| if (has_mutation_or_trackable()) | |||||
| { | |||||
| _non_append_mutation_value = true; | |||||
| } | |||||
| var res = _storage.Remove(item); | |||||
| update_snapshot(); | |||||
| return res; | |||||
| } | |||||
| public IEnumerator<Trackable> GetEnumerator() => _storage.GetEnumerator(); | |||||
| IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); | |||||
| protected string _name_element(int index) => $"{index}"; | |||||
| } | |||||
| } | |||||
| @@ -22,7 +22,7 @@ namespace Tensorflow | |||||
| protected bool _in_graph_mode; | protected bool _in_graph_mode; | ||||
| protected bool _trainable; | protected bool _trainable; | ||||
| public bool trainable => _trainable; | |||||
| public bool Trainable => _trainable; | |||||
| protected Tensor _initial_value; | protected Tensor _initial_value; | ||||
| @@ -166,7 +166,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| void variable_accessed(BaseResourceVariable variable) | void variable_accessed(BaseResourceVariable variable) | ||||
| { | { | ||||
| if (variable.trainable) | |||||
| if (variable.Trainable) | |||||
| { | { | ||||
| foreach (var tape in tf.GetTapeSet()) | foreach (var tape in tf.GetTapeSet()) | ||||
| tape.VariableAccessed(variable as ResourceVariable); | tape.VariableAccessed(variable as ResourceVariable); | ||||
| @@ -46,6 +46,7 @@ namespace Tensorflow | |||||
| Graph Graph { get; } | Graph Graph { get; } | ||||
| TF_DataType dtype { get; } | TF_DataType dtype { get; } | ||||
| Shape shape { get; } | Shape shape { get; } | ||||
| bool Trainable { get; } | |||||
| Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
| Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
| IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); | IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null); | ||||
| @@ -56,6 +56,7 @@ namespace Tensorflow | |||||
| public string Name => _variable.name; | public string Name => _variable.name; | ||||
| public Tensor eval() => _variable; | public Tensor eval() => _variable; | ||||
| public bool Trainable => _trainable; | |||||
| public RefVariable(object initial_value = null, | public RefVariable(object initial_value = null, | ||||
| bool trainable = true, | bool trainable = true, | ||||
| @@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| @@ -20,6 +21,30 @@ namespace Tensorflow.Keras.Engine | |||||
| Dictionary<long, int> tensor_usage_count; | Dictionary<long, int> tensor_usage_count; | ||||
| /// <summary> | |||||
| /// Dictionary of layer dependencies to be included in the checkpoint. | |||||
| /// </summary> | |||||
| public IDictionary<string, ILayer> LayerCheckpointDependencies | |||||
| { | |||||
| get | |||||
| { | |||||
| int weight_layer_index = 0; | |||||
| Dictionary<string, ILayer> dependencies = new(); | |||||
| for(int i = 0; i < Layers.Count; i++) | |||||
| { | |||||
| var layer = Layers[i]; | |||||
| var weights = layer.TrainableWeights.concat(layer.NonTrainableWeights).ToList(); | |||||
| if(weights.Count > 0) | |||||
| { | |||||
| dependencies[$"layer_with_weights-{weight_layer_index}"] = layer; | |||||
| weight_layer_index++; | |||||
| } | |||||
| dependencies[$"layer-{i}"] = layer; | |||||
| } | |||||
| return dependencies; | |||||
| } | |||||
| } | |||||
| public Functional(Tensors inputs, Tensors outputs, string name = null) | public Functional(Tensors inputs, Tensors outputs, string name = null) | ||||
| : base(new ModelArgs | : base(new ModelArgs | ||||
| { | { | ||||
| @@ -325,5 +350,11 @@ namespace Tensorflow.Keras.Engine | |||||
| return output_tensors; | return output_tensors; | ||||
| } | } | ||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | |||||
| .ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Train; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -108,5 +109,15 @@ namespace Tensorflow.Keras.Engine | |||||
| return variables; | return variables; | ||||
| } | } | ||||
| } | } | ||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null) | |||||
| { | |||||
| if(save_type == SaveType.SAVEDMODEL) | |||||
| { | |||||
| //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`. | |||||
| } | |||||
| var children = base._trackable_children(save_type, cache); | |||||
| return children; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -29,6 +29,18 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
| throw new System.NotImplementedException(); | throw new System.NotImplementedException(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Generates or retrieves serialized attributes from cache. | |||||
| /// </summary> | |||||
| /// <param name="serialization_cache"></param> | |||||
| protected void get_serialized_attributes(IDictionary<string, object> serialization_cache) | |||||
| { | |||||
| // TODO: deal with cache. | |||||
| Layer a; | |||||
| } | |||||
| public override string TrackingMetadata | public override string TrackingMetadata | ||||
| { | { | ||||
| get | get | ||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel | |||||
| { | |||||
| /// <summary> | |||||
| /// Class that tracks and validates all serialization attributes. | |||||
| /// </summary> | |||||
| public class SerializedAttributes | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -4,7 +4,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public partial class KerasSavedModelUtils | public partial class KerasSavedModelUtils | ||||
| { | { | ||||
| public static bool ShouldHaveTraces { get; internal set; } | |||||
| public static bool ShouldHaveTraces { get; internal set; } = true; | |||||
| public static SaveOptionsContext keras_option_scope(bool save_traces) | public static SaveOptionsContext keras_option_scope(bool save_traces) | ||||
| { | { | ||||
| @@ -23,7 +23,7 @@ public class SaveOptionsContext: IDisposable | |||||
| public bool _old_value; | public bool _old_value; | ||||
| public SaveOptionsContext(bool old_value) | public SaveOptionsContext(bool old_value) | ||||
| { | { | ||||
| _old_value = true; | |||||
| _old_value = old_value; | |||||
| } | } | ||||
| public void Dispose() | public void Dispose() | ||||