Browse Source

Add ListWrapper and ITrackable, and revise implementations.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
ddd06ab9b6
14 changed files with 513 additions and 6 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  3. +12
    -0
      src/TensorFlowNET.Core/Training/ITrackable.cs
  4. +9
    -0
      src/TensorFlowNET.Core/Training/LayerUtils.cs
  5. +49
    -1
      src/TensorFlowNET.Core/Training/Trackable.cs
  6. +364
    -0
      src/TensorFlowNET.Core/Training/data_structures.cs
  7. +2
    -2
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  8. +1
    -0
      src/TensorFlowNET.Core/Variables/IVariableV1.cs
  9. +1
    -0
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  10. +31
    -0
      src/TensorFlowNET.Keras/Engine/Functional.cs
  11. +11
    -0
      src/TensorFlowNET.Keras/Engine/Model.cs
  12. +12
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  13. +14
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  14. +2
    -2
      src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs

+ 2
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,10 +1,11 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Training;

namespace Tensorflow.Keras
{
public interface ILayer
public interface ILayer: ITrackable
{
string Name { get; }
bool Trainable { get; }


+ 3
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
using static Tensorflow.Binding;

@@ -147,5 +148,7 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

/// <summary>
/// Placeholder for retrieving the <see cref="Trackable"/> of this cell; not implemented.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Trackable GetTrackable() => throw new NotImplementedException();
}
}

+ 12
- 0
src/TensorFlowNET.Core/Training/ITrackable.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Training
{
/// <summary>
/// Contract for objects that can supply a <see cref="Trackable"/>.
/// </summary>
public interface ITrackable
{
/// <summary>
/// Gets the <see cref="Trackable"/> associated with this object.
/// </summary>
/// <returns>A <see cref="Trackable"/> instance.</returns>
Trackable GetTrackable();
}
}

+ 9
- 0
src/TensorFlowNET.Core/Training/LayerUtils.cs View File

@@ -0,0 +1,9 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Training
{

}

+ 49
- 1
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -18,11 +18,12 @@ using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.ModelSaving;
using Tensorflow.Training;
using static Tensorflow.Binding;

namespace Tensorflow.Train
{
public abstract class Trackable
public abstract class Trackable: ITrackable
{
/// <summary>
/// Corresponding to tensorflow/python/trackable/constants.py
@@ -40,6 +41,7 @@ namespace Tensorflow.Train

protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
new Dictionary<string, ResourceVariable>();
private bool _manual_tracking = true;

private static Trackable _none = new Function();
/// <summary>
@@ -54,6 +56,10 @@ namespace Tensorflow.Train
return _none;
}
}
/// <summary>
/// Returns this instance as its own <see cref="Trackable"/>.
/// </summary>
public Trackable GetTrackable() => this;
public virtual string ObjectIdentifier
{
get => "_generic_user_object";
@@ -128,6 +134,48 @@ namespace Tensorflow.Train
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
}

/// <summary>
/// Registers `trackable` under `name` in this object's dependency bookkeeping.
/// </summary>
/// <param name="trackable">The object to register.</param>
/// <param name="name">The name the object is registered under.</param>
/// <param name="overwrite">
/// NOTE(review): this flag is not read anywhere in the body. When `name` is already
/// registered, only the name-to-object map is updated; the existing entry in the
/// checkpoint dependency list is left untouched. Confirm whether that is intended.
/// </param>
/// <returns>The `trackable` argument, unchanged.</returns>
public virtual Trackable _track_trackable(Trackable trackable, string name, bool overwrite = false)
{
    _maybe_initialize_trackable();
    if (_manual_tracking)
    {
        var reference = new TrackableReference(name, trackable);
        var existing = _lookupup_dependency(name);
        if (existing is null)
        {
            // First registration under this name: record the reference and give
            // deferred restores a chance to run.
            _unconditional_checkpoint_dependencies.Add(reference);
            _handle_deferred_dependencies(name, trackable);
        }
        _unconditional_dependency_names[name] = trackable;
    }
    return trackable;
}

/// <summary>
/// Intended behavior (not implemented yet, the body is currently empty):
/// pop and load any deferred checkpoint restores into `trackable`.
/// This method does not add a new dependency on `trackable`, but it does check if any outstanding/deferred dependencies have been queued waiting for
/// this dependency to be added (matched based on `name`). If so, `trackable` and its dependencies are restored. The restorations are
/// considered fulfilled and so are deleted.
/// `_track_trackable` is more appropriate for adding a normal/unconditional dependency, and includes handling for deferred restorations.
/// This method allows objects such as `Optimizer` to use the same restoration logic while managing conditional dependencies themselves,
/// by overriding `_checkpoint_dependencies` and `_lookup_dependency` to change the object's dependencies based on the context
/// it is saved/restored in (a single optimizer instance can have state associated with multiple graphs).
/// </summary>
/// <param name="name">Name under which the dependency is registered; deferred restores are meant to be matched on it.</param>
/// <param name="trackable">The object that deferred restores are meant to be loaded into.</param>
public virtual void _handle_deferred_dependencies(string name, Trackable trackable)
{
//_maybe_initialize_trackable();
//trackable._maybe_initialize_trackable();
// TODO: complete the implementation.
}

/// <summary>
/// Looks up the object registered under `name` in the name-to-object map.
/// </summary>
/// <param name="name">The dependency name to look up.</param>
/// <returns>The registered object, or null when no entry exists for `name`.</returns>
public virtual Trackable? _lookupup_dependency(string name)
{
    return _unconditional_dependency_names.TryGetValue(name, out var found) ? found : null;
}

public static Trackable convert_to_trackable(object obj, object? parent = null)
{
if (obj is Trackable)


+ 364
- 0
src/TensorFlowNET.Core/Training/data_structures.cs View File

@@ -0,0 +1,364 @@
using Google.Protobuf;
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO.Compression;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Keras;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using static Tensorflow.ApiDef.Types;

namespace Tensorflow.Training
{
/// <summary>
/// Simple holder around a single <see cref="Trackable"/> exposed through <see cref="Value"/>.
/// </summary>
public class NoDependency
{
    /// <summary>
    /// Creates a holder for `value`.
    /// </summary>
    /// <param name="value">The object to hold.</param>
    public NoDependency(Trackable value) => Value = value;

    /// <summary>
    /// The held object.
    /// </summary>
    public Trackable Value { get; set; }
}

/// <summary>
/// Base class for trackable containers. It aggregates layers and variables from the
/// elements exposed by <see cref="Values"/> and from variables registered through
/// <see cref="_track_value"/>.
/// </summary>
public abstract class TrackableDataStructure : Trackable
{
    private bool _self_trainable;

    // Variables that were passed directly to `_track_value`.
    private List<IVariableV1> _self_extra_variables;

    public TrackableDataStructure()
    {
        _self_trainable = true;
        _self_extra_variables = new List<IVariableV1>();
    }

    /// <summary>
    /// The elements held by this data structure.
    /// </summary>
    public abstract IEnumerable<Trackable> Values { get; }

    /// <summary>
    /// Whether this structure reports its variables as trainable.
    /// </summary>
    public bool Trainable { get => _self_trainable; set => _self_trainable = value; }

    /// <summary>
    /// All layers found in <see cref="Values"/>, descending into nested data structures.
    /// </summary>
    public IEnumerable<ILayer> Layers
    {
        get
        {
            List<ILayer> collected = new();
            foreach (var obj in Values)
            {
                if (obj is ILayer layer)
                {
                    collected.Add(layer);
                }
                else if (obj is TrackableDataStructure nested)
                {
                    collected.AddRange(nested.Layers);
                }
            }
            return collected;
        }
    }

    /// <summary>
    /// Trainable variables of nested data structures followed by the trainable extra
    /// variables. Empty when <see cref="Trainable"/> is false.
    /// </summary>
    public IEnumerable<IVariableV1> TrainableWeights
    {
        get
        {
            if (!_self_trainable)
            {
                return new List<IVariableV1>();
            }
            var trainable_variables = collect_nested_variables(trainable: true);
            trainable_variables.AddRange(_self_extra_variables.Where(v => v.Trainable));
            return trainable_variables;
        }
    }

    /// <summary>
    /// Non-trainable variables of nested data structures followed by the non-trainable
    /// extra variables. When <see cref="Trainable"/> is false, every variable is reported
    /// here: all trainable ones first, then all non-trainable ones.
    /// </summary>
    public IEnumerable<IVariableV1> NonTrainableWeights
    {
        get
        {
            // `Where` filters the whole list. `TakeWhile` would stop at the first
            // non-matching element and silently drop variables depending on their order.
            var trainable_extra_variables = _self_extra_variables.Where(x => x.Trainable).ToList();
            var non_trainable_extra_variables = _self_extra_variables.Where(x => !x.Trainable).ToList();
            var non_trainable_variables = collect_nested_variables(trainable: false);

            if (!_self_trainable)
            {
                // Return order is all trainable vars, then all non-trainable vars.
                var trainable_variables = collect_nested_variables(trainable: true);
                return trainable_variables.concat(trainable_extra_variables).concat(non_trainable_variables).concat(non_trainable_extra_variables);
            }
            else
            {
                return non_trainable_variables.concat(non_trainable_extra_variables);
            }
        }
    }

    public IEnumerable<IVariableV1> Weights => TrainableWeights.Concat(NonTrainableWeights);
    public IEnumerable<IVariableV1> TrainableVariables => TrainableWeights;
    public IEnumerable<IVariableV1> NonTrainableVariables => NonTrainableWeights;
    public IEnumerable<IVariableV1> Variables => Weights;

    // TODO: `losses` property.

    /// <summary>
    /// Gathers the trainable or non-trainable variables of every nested
    /// <see cref="TrackableDataStructure"/> in <see cref="Values"/>.
    /// </summary>
    private List<IVariableV1> collect_nested_variables(bool trainable)
    {
        List<IVariableV1> collected = new();
        foreach (var obj in Values)
        {
            // skip the process of `module.Module`.
            if (obj is TrackableDataStructure nested)
            {
                collected.AddRange(trainable ? nested.TrainableVariables : nested.NonTrainableVariables);
            }
        }
        return collected;
    }

    /// <summary>
    /// Add a dependency on `value`.
    /// </summary>
    /// <param name="value">The object to register.</param>
    /// <param name="name">The name the object is registered under.</param>
    /// <returns>The value returned by `sticky_attribute_assignment`.</returns>
    protected virtual Trackable _track_value(Trackable value, string name)
    {
        value = sticky_attribute_assignment(this, name, value);
        if (value is IVariableV1 variable)
        {
            _self_extra_variables.Add(variable);
        }
        // skip the left process (need to be done in the future).
        return value;
    }

    /// <summary>
    /// Unwraps a <see cref="NoDependency"/> to the object it holds.
    /// </summary>
    protected static Trackable wrap_or_unwrap(NoDependency value)
    {
        return value.Value;
    }

    /// <summary>
    /// Returns `value` unchanged.
    /// </summary>
    protected static Trackable wrap_or_unwrap(Trackable value)
    {
        return value;
    }

    /// <summary>
    /// Wraps a list in a new <see cref="ListWrapper"/>.
    /// </summary>
    protected static Trackable wrap_or_unwrap(IList<Trackable> value)
    {
        return new ListWrapper(value);
    }

    /// <summary>
    /// Passes `value` through `wrap_or_unwrap` and registers the result on `trackable`
    /// under `name` (with `overwrite` set to true).
    /// </summary>
    protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value)
    {
        value = wrap_or_unwrap(value);
        trackable._track_trackable(value, name, true);
        return value;
    }

    /// <summary>
    /// Unwraps `value` and registers the held object on `trackable` under `name`.
    /// </summary>
    protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value)
    {
        var wrapped_value = wrap_or_unwrap(value);
        trackable._track_trackable(wrapped_value, name, true);
        return wrapped_value;
    }

    /// <summary>
    /// Wraps `value` in a <see cref="ListWrapper"/> and registers the wrapper on
    /// `trackable` under `name`.
    /// </summary>
    protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList<Trackable> value)
    {
        var wrapped_value = wrap_or_unwrap(value);
        trackable._track_trackable(wrapped_value, name, true);
        return wrapped_value;
    }
}

/// <summary>
/// A trackable list that records whether it has been changed in ways other than
/// appending. <see cref="_trackable_children"/> refuses to report children once such a
/// change, or a change made directly on the wrapped list, has been detected.
/// </summary>
/// <remarks>
/// NOTE(review): only the indexer setter passes values through `_track_value`; the
/// constructor, <see cref="Add"/> and <see cref="Insert"/> store items without registering
/// them. Confirm whether those paths should register items as well.
/// </remarks>
public class ListWrapper : TrackableDataStructure, IList<Trackable>, ICloneable
{
    private IList<Trackable> _storage;
    private bool _non_append_mutation_value;
    private bool _external_modification_value;
    private IList<Trackable> _last_wrapped_list_snapshot;

    /// <summary>
    /// Wraps `wrapped_list`. The list is stored by reference and a shallow copy is kept as a snapshot.
    /// </summary>
    /// <param name="wrapped_list">The initial value of the data structure. A shallow copy may be maintained for error checking. `wrapped_list` itself should not be
    /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save.</param>
    public ListWrapper(IList<Trackable> wrapped_list)
    {
        _storage = wrapped_list;
        _non_append_mutation_value = _external_modification_value = false;
        _last_wrapped_list_snapshot = new List<Trackable>(_storage);
    }

    // The misspelled name is kept as-is because derived classes may reference it.
    protected bool NonAppendMuation
    {
        get => _non_append_mutation_value;
        set
        {
            // TODO: deal with `attribute_sentinel`.
            _non_append_mutation_value = value;
        }
    }

    protected bool ExternalModification
    {
        get => _external_modification_value;
        set
        {
            // TODO: deal with `attribute_sentinel`.
            _external_modification_value = value;
        }
    }

    public override IEnumerable<Trackable> Values => this;
    public bool IsReadOnly => _storage.IsReadOnly;

    /// <summary>
    /// Checks for any changes to the wrapped list not through the wrapper.
    /// </summary>
    private void check_external_modification()
    {
        if (_external_modification_value || _non_append_mutation_value)
        {
            return;
        }
        if (!_storage.SequenceEqual(_last_wrapped_list_snapshot))
        {
            _external_modification_value = true;
        }
    }

    /// <summary>
    /// Refreshes the snapshot, unless a mutation flag is already set.
    /// </summary>
    private void update_snapshot()
    {
        // TODO: deal with `attribute_sentinel`.
        if (_external_modification_value || _non_append_mutation_value)
        {
            return;
        }
        _last_wrapped_list_snapshot = new List<Trackable>(_storage);
    }

    /// <summary>
    /// Returns the children reported by the base class. For SavedModel saves, the
    /// function elements of the list are added, keyed by their index in the list.
    /// </summary>
    /// <exception cref="ValueError">
    /// Thrown when a non-append mutation or an external modification has been detected.
    /// </exception>
    public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
    {
        check_external_modification();
        if (_non_append_mutation_value)
        {
            throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). A list element was replaced" +
                $", deleted or moved (sort). In order to support restoration on object creation, tracking is exclusively for append-only data structures." +
                $"\n\nIf you don't need this list checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.");
        }
        if (_external_modification_value)
        {
            // Join the elements so the message shows list contents rather than the list's type name.
            var final_value = string.Join(", ", _storage);
            var snapshot_value = string.Join(", ", _last_wrapped_list_snapshot);
            throw new ValueError($"Unable to save the object {this} (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified " +
                $"outside the wrapper (its final value was [{final_value}], its value when a checkpoint dependency was added was [{snapshot_value}]), which breaks " +
                $"restoration on object creation.\n\nIf you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored.");
        }
        var children = base._trackable_children(save_type, cache);

        if (save_type == SaveType.SAVEDMODEL)
        {
            // Filter every element (not just a leading run) and key each function by its
            // real position in the list. Assigning through the indexer avoids a
            // duplicate-key exception if the base class already uses the same key.
            var merged = new Dictionary<string, Trackable>(children);
            for (int index = 0; index < _storage.Count; index++)
            {
                var element = _storage[index];
                if (element is Function or ConcreteFunction)
                {
                    merged[index.ToString()] = element;
                }
            }
            children = merged;
        }

        return children;
    }

    /// <summary>
    /// True when a non-append mutation was already recorded or the list holds at least
    /// one non-null element.
    /// </summary>
    private bool has_mutation_or_trackable()
    {
        return _non_append_mutation_value || _storage.Any(x => x is not null);
    }

    /// <summary>
    /// Allows storage of non-trackable objects.
    /// </summary>
    /// <param name="value">The object to register.</param>
    /// <param name="name">The name the object is registered under.</param>
    /// <returns>The value produced by the registration.</returns>
    protected override Trackable _track_value(Trackable value, string name)
    {
        try
        {
            // Keep the result of the base call instead of discarding it.
            return base._track_value(value, name);
        }
        catch (ValueError)
        {
            return sticky_attribute_assignment(this, name, value);
        }
    }

    /// <summary>
    /// Creates a new wrapper around a shallow copy of the list, carrying over both mutation flags.
    /// </summary>
    public object Clone()
    {
        var res = new ListWrapper(_storage.ToList());
        res.NonAppendMuation = _non_append_mutation_value;
        res.ExternalModification = _external_modification_value;
        return res;
    }

    public Trackable this[int index]
    {
        get => _storage[index];
        set
        {
            // skip the process of `Slice`, maybe support it in the future.
            check_external_modification();
            _non_append_mutation_value = true;
            _storage[index] = _track_value(value, _name_element(index));

            update_snapshot();
        }
    }

    public int IndexOf(Trackable item) => _storage.IndexOf(item);

    public void Insert(int index, Trackable item)
    {
        check_external_modification();
        _non_append_mutation_value = true;
        _storage.Insert(index, item);
        update_snapshot();
    }

    public void RemoveAt(int index)
    {
        check_external_modification();
        var removed = _storage[index];
        _storage.RemoveAt(index);
        if (removed is not null)
        {
            // Deleting an element breaks the append-only contract.
            _non_append_mutation_value = true;
        }
        update_snapshot();
    }

    public int Count => _storage.Count;

    public void Add(Trackable item)
    {
        check_external_modification();
        _storage.Add(item);
        update_snapshot();
    }

    public void Clear()
    {
        check_external_modification();
        if (has_mutation_or_trackable())
        {
            // Clearing a list that holds elements deletes them all.
            _non_append_mutation_value = true;
        }
        _storage.Clear();
        update_snapshot();
    }

    public bool Contains(Trackable item) => _storage.Contains(item);

    public void CopyTo(Trackable[] array, int arrayIndex) => _storage.CopyTo(array, arrayIndex);

    public bool Remove(Trackable item)
    {
        check_external_modification();
        var res = _storage.Remove(item);
        if (res && item is not null)
        {
            // Only an actual removal counts as a non-append mutation.
            _non_append_mutation_value = true;
        }
        update_snapshot();
        return res;
    }

    public IEnumerator<Trackable> GetEnumerator() => _storage.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator();

    /// <summary>
    /// Name used for the element at `index`: the index formatted as a string.
    /// </summary>
    protected string _name_element(int index) => $"{index}";
}
}

+ 2
- 2
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -22,7 +22,7 @@ namespace Tensorflow
protected bool _in_graph_mode;

protected bool _trainable;
public bool trainable => _trainable;
public bool Trainable => _trainable;

protected Tensor _initial_value;

@@ -166,7 +166,7 @@ namespace Tensorflow
/// </summary>
void variable_accessed(BaseResourceVariable variable)
{
if (variable.trainable)
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
tape.VariableAccessed(variable as ResourceVariable);


+ 1
- 0
src/TensorFlowNET.Core/Variables/IVariableV1.cs View File

@@ -46,6 +46,7 @@ namespace Tensorflow
Graph Graph { get; }
TF_DataType dtype { get; }
Shape shape { get; }
bool Trainable { get; }
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null);


+ 1
- 0
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -56,6 +56,7 @@ namespace Tensorflow
public string Name => _variable.name;

public Tensor eval() => _variable;
public bool Trainable => _trainable;

public RefVariable(object initial_value = null,
bool trainable = true,


+ 31
- 0
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
@@ -20,6 +21,30 @@ namespace Tensorflow.Keras.Engine

Dictionary<long, int> tensor_usage_count;

/// <summary>
/// Builds the layer dependencies to include in the checkpoint. Every layer is stored under
/// `layer-{i}`; layers that have any weights are additionally stored under
/// `layer_with_weights-{k}`, where `k` counts only the layers with weights.
/// </summary>
public IDictionary<string, ILayer> LayerCheckpointDependencies
{
    get
    {
        var result = new Dictionary<string, ILayer>();
        var weighted_seen = 0;
        for (var position = 0; position < Layers.Count; position++)
        {
            var current = Layers[position];
            var has_weights = current.TrainableWeights.concat(current.NonTrainableWeights).Any();
            if (has_weights)
            {
                result[$"layer_with_weights-{weighted_seen}"] = current;
                weighted_seen += 1;
            }
            result[$"layer-{position}"] = current;
        }
        return result;
    }
}

public Functional(Tensors inputs, Tensors outputs, string name = null)
: base(new ModelArgs
{
@@ -325,5 +350,11 @@ namespace Tensorflow.Keras.Engine

return output_tensors;
}

/// <summary>
/// Combines the layer checkpoint dependencies with the children reported by the base class.
/// </summary>
/// <param name="save_type">Passed through to the base implementation.</param>
/// <param name="cache">Passed through to the base implementation.</param>
/// <returns>
/// A new dictionary holding the layer entries first, then the base entries. When both
/// sides use the same key, the base entry wins.
/// </returns>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
{
    // Assign through the indexer: Concat(...).ToDictionary(...) would throw an
    // ArgumentException if a layer key also appears among the base children.
    var children = new Dictionary<string, Trackable>();
    foreach (var entry in LayerCheckpointDependencies)
    {
        children[entry.Key] = entry.Value.GetTrackable();
    }
    foreach (var entry in base._trackable_children(save_type, cache))
    {
        children[entry.Key] = entry.Value;
    }
    return children;
}
}
}

+ 11
- 0
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Train;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -108,5 +109,15 @@ namespace Tensorflow.Keras.Engine
return variables;
}
}

/// <summary>
/// Returns the children reported by the base class; nothing is added here yet.
/// </summary>
/// <param name="save_type">Passed through to the base implementation.</param>
/// <param name="cache">Passed through to the base implementation.</param>
public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
{
    // For SaveType.SAVEDMODEL:
    //TODO: deal with `train_function`, `test_function`, `predict_function`, `train_tf_function`.
    return base._trackable_children(save_type, cache);
}
}
}

+ 12
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -29,6 +29,18 @@ public class LayerSavedModelSaver: SavedModelSaver
throw new System.NotImplementedException();
}

/// <summary>
/// Generates or retrieves serialized attributes from cache.
/// Not implemented yet: the body is currently empty.
/// </summary>
/// <param name="serialization_cache">Cache of serialized attributes; not read yet.</param>
protected void get_serialized_attributes(IDictionary<string, object> serialization_cache)
{
    // TODO: deal with cache.
}

public override string TrackingMetadata
{
get


+ 14
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving.SavedModel
{
/// <summary>
/// Class that tracks and validates all serialization attributes.
/// </summary>
/// <remarks>
/// Currently an empty placeholder: no members are defined yet.
/// </remarks>
public class SerializedAttributes
{

}
}

+ 2
- 2
src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs View File

@@ -4,7 +4,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static bool ShouldHaveTraces { get; internal set; }
public static bool ShouldHaveTraces { get; internal set; } = true;

public static SaveOptionsContext keras_option_scope(bool save_traces)
{
@@ -23,7 +23,7 @@ public class SaveOptionsContext: IDisposable
public bool _old_value;
public SaveOptionsContext(bool old_value)
{
_old_value = true;
_old_value = old_value;
}

public void Dispose()


Loading…
Cancel
Save