| @@ -12,9 +12,9 @@ namespace Tensorflow.Checkpoint; | |||
| public static class CheckPointUtils | |||
| { | |||
| private static string _ESCAPE_CHAR = "."; | |||
| public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>, | |||
| public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>, | |||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | |||
| Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | |||
| IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | |||
| { | |||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| @@ -149,4 +149,4 @@ public static class CheckPointUtils | |||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | |||
| // } | |||
| } | |||
| } | |||
| } | |||
| @@ -2,4 +2,4 @@ | |||
| public record class CheckpointOptions( | |||
| string? experimental_io_device = null, | |||
| bool experimental_enable_async_checkpoint = false); | |||
| bool experimental_enable_async_checkpoint = false); | |||
| @@ -45,7 +45,7 @@ public class ObjectGraphView: TrackableView, ICloneable | |||
| get => _attached_dependencies; | |||
| } | |||
| public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
| public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
| { | |||
| return base._descendants_with_paths(); | |||
| } | |||
| @@ -61,4 +61,4 @@ public class ObjectGraphView: TrackableView, ICloneable | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -58,7 +58,7 @@ namespace Tensorflow.Checkpoint | |||
| return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); | |||
| } | |||
| private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map) | |||
| private static (IList<TrackableData>, IDictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map) | |||
| { | |||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||
| Dictionary<Trackable, string> object_names = new(); | |||
| @@ -173,7 +173,7 @@ namespace Tensorflow.Checkpoint | |||
| tensor_dict[checkpoint_key] = maybe_tensor; | |||
| if(maybe_tensor.GetValueA() is SaveSpec) | |||
| if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | |||
| @@ -13,7 +13,7 @@ namespace Tensorflow.Checkpoint; | |||
| public static class SaveUtilV1 | |||
| { | |||
| public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names, | |||
| public static (IDictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names, | |||
| IDictionary<Trackable, Trackable>? object_map = null) | |||
| { | |||
| // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, | |||
| @@ -44,7 +44,7 @@ public static class SaveUtilV1 | |||
| return (checkpoint_factory_map, null); | |||
| } | |||
| public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||
| public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||
| IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, | |||
| object? saveables_cache = null) | |||
| { | |||
| @@ -73,7 +73,7 @@ public static class SaveUtilV1 | |||
| } | |||
| } | |||
| public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view, | |||
| public static (IList<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view, | |||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||
| { | |||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||
| @@ -129,7 +129,8 @@ public static class SaveUtilV1 | |||
| return object_graph_proto; | |||
| } | |||
| private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects, | |||
| private static (IList<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph( | |||
| IList<Trackable> trackable_objects, | |||
| TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, | |||
| IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, | |||
| bool call_with_mapped_captures, object? saveables_cache = null) | |||
| @@ -150,7 +151,7 @@ public static class SaveUtilV1 | |||
| return (named_saveable_objects, feed_additions, null); | |||
| } | |||
| public static (List<MySaveableObject>, object?) generate_saveable_objects( | |||
| public static (IList<MySaveableObject>, object?) generate_saveable_objects( | |||
| IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map, | |||
| TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids, | |||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||
| @@ -178,13 +179,13 @@ public static class SaveUtilV1 | |||
| // TODO: oneflow python has a process with callable `saveable_factory`. | |||
| List<MySaveableObject> saveables = new(); | |||
| if (maybe_saveable.DataType == typeof(MySaveableObject)) | |||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||
| { | |||
| saveables.Add(maybe_saveable.GetValueB()); | |||
| saveables.Add(s); | |||
| } | |||
| else | |||
| { | |||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); | |||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||
| } | |||
| foreach (var saveable in saveables) | |||
| @@ -219,4 +220,4 @@ public record class CheckpointFactoryData | |||
| Maybe<BaseResourceVariable, MySaveableObject> factory, | |||
| string name, | |||
| string checkpoint_key | |||
| ); | |||
| ); | |||
| @@ -52,7 +52,7 @@ public class TrackableView | |||
| /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. | |||
| /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths | |||
| /// </summary> | |||
| protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths() | |||
| protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths() | |||
| { | |||
| List<Trackable> bfs_sorted = new(); | |||
| Queue<Trackable> to_visit = new(); | |||
| @@ -14,112 +14,91 @@ using Tensorflow.Training; | |||
| using Tensorflow.Graphs; | |||
| using System.Xml.Linq; | |||
| using System.Diagnostics; | |||
| using RestoreFunc = System.Func<object, object>; | |||
| namespace Tensorflow.Checkpoint | |||
| { | |||
| /// <summary> | |||
| /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. | |||
| /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. | |||
| /// </summary> | |||
| public interface IFunctionHolder | |||
| { | |||
| int ArgCount { get; } | |||
| object DynamicInvoke(params object[] args); | |||
| } | |||
| internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder | |||
| { | |||
| public int ArgCount => 0; | |||
| public object DynamicInvoke(params object[] args) | |||
| { | |||
| return Func.DynamicInvoke(args); | |||
| } | |||
| public TR Invoke() | |||
| { | |||
| return Func.Invoke(); | |||
| } | |||
| } | |||
| internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder | |||
| { | |||
| public int ArgCount => 1; | |||
| public object DynamicInvoke(params object[] args) | |||
| { | |||
| return Func.DynamicInvoke(args); | |||
| } | |||
| } | |||
| internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder | |||
| { | |||
| public int ArgCount => 2; | |||
| public object DynamicInvoke(params object[] args) | |||
| { | |||
| return Func.DynamicInvoke(args); | |||
| } | |||
| } | |||
| internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder | |||
| { | |||
| public int ArgCount => 3; | |||
| public object DynamicInvoke(params object[] args) | |||
| { | |||
| return Func.DynamicInvoke(args); | |||
| } | |||
| } | |||
| public class Maybe<TA, TB> | |||
| { | |||
| private TA? _valueA = default(TA); | |||
| private TB? _valueB = default(TB); | |||
| private Type _type; | |||
| private bool _assigned = false; | |||
| private bool _assignedTA; | |||
| public Maybe(TA value) | |||
| { | |||
| _valueA = value; | |||
| _type= typeof(TA); | |||
| _assigned = true; | |||
| _assignedTA = true; | |||
| } | |||
| public Maybe(TB value) | |||
| { | |||
| _valueB = value; | |||
| _type = typeof(TB); | |||
| _assigned = true; | |||
| _assignedTA = false; | |||
| } | |||
| public Type DataType => _type; | |||
| public TA GetValueA() | |||
| /// <summary> | |||
| /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||
| /// It returns | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <param name="res"></param> | |||
| /// <returns></returns> | |||
| public bool TryGet<T>(out T? res) | |||
| { | |||
| if(!_assigned || DataType != typeof(TA)) | |||
| if(_valueA is T && _valueB is not T) | |||
| { | |||
| throw new TypeError("Cannot get the data because of wrong specified type."); | |||
| res = (T)(object)_valueA; | |||
| return _assignedTA; | |||
| } | |||
| return _valueA; | |||
| } | |||
| public TB GetValueB() | |||
| { | |||
| if (!_assigned || DataType != typeof(TB)) | |||
| else if(_valueA is not T && _valueB is T) | |||
| { | |||
| throw new TypeError("Cannot get the data because of wrong specified type."); | |||
| res = (T)(object)_valueB; | |||
| return !_assignedTA; | |||
| } | |||
| return _valueB; | |||
| res = default(T); | |||
| return false; | |||
| } | |||
| public object GetValue() | |||
| public bool IsTypeOrDeriveFrom<T>() | |||
| { | |||
| if (!_assigned) | |||
| if (_valueA is T && _valueB is not T) | |||
| { | |||
| throw new TypeError("Cannot get the data because of wrong specified type."); | |||
| return _assignedTA; | |||
| } | |||
| if(DataType == typeof(TA) && _valueA is not null) | |||
| else if (_valueA is not T && _valueB is T) | |||
| { | |||
| return _valueA; | |||
| return !_assignedTA; | |||
| } | |||
| else if(DataType == typeof(TB) && _valueB is not null) | |||
| else if (_valueA is T && _valueB is T) | |||
| { | |||
| return _valueB; | |||
| return true; | |||
| } | |||
| else if(DataType == typeof(TA)) | |||
| else | |||
| { | |||
| return _valueA; | |||
| return false; | |||
| } | |||
| } | |||
| public T GetValue<T>() | |||
| { | |||
| if (_valueA is T && _valueB is not T) | |||
| { | |||
| return (T)(object)_valueA; | |||
| } | |||
| else if (_valueA is not T && _valueB is T) | |||
| { | |||
| return (T)(object)_valueB; | |||
| } | |||
| else if (_valueA is T && _valueB is T) | |||
| { | |||
| throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||
| } | |||
| else | |||
| { | |||
| return _valueB; | |||
| throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||
| } | |||
| } | |||
| @@ -170,9 +149,8 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| if(maybe_tensor.DataType == typeof(SaveSpec)) | |||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
| { | |||
| var spec = maybe_tensor.GetValueB(); | |||
| var tensor_value = spec.tensor; | |||
| if (tensor_value is not null) | |||
| { | |||
| @@ -183,7 +161,7 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| else | |||
| { | |||
| var tensor = maybe_tensor.GetValueA(); | |||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||
| tensor_names.Add(checkpoint_key); | |||
| tensors.Add(tensor); | |||
| slice_specs.Add(slice_spec); | |||
| @@ -215,16 +193,15 @@ namespace Tensorflow.Checkpoint | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| // TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
| if(maybe_tensor.DataType == typeof(SaveSpec)) | |||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||
| { | |||
| var spec = maybe_tensor.GetValueB(); | |||
| tensor_dtypes.Add(spec.dtype); | |||
| slice_specs.Add(spec.slice_spec); | |||
| tensor_names.Add(spec.name); | |||
| } | |||
| else | |||
| { | |||
| var tensor = maybe_tensor.GetValueA(); | |||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||
| tensor_dtypes.Add(tensor.dtype); | |||
| slice_specs.Add(slice_spec); | |||
| tensor_names.Add(checkpoint_key); | |||
| @@ -268,9 +245,9 @@ namespace Tensorflow.Checkpoint | |||
| public class MultiDeviceSaver | |||
| { | |||
| private Dictionary<string, SingleDeviceSaver> _single_device_savers; | |||
| private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers; | |||
| private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; | |||
| private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys; | |||
| private IDictionary<string, (RestoreFunc, RestoreFunc)> _registered_savers; | |||
| private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; | |||
| private Dictionary<RestoreFunc, IList<(string, string)>> _restore_fn_to_keys; | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -280,24 +257,28 @@ namespace Tensorflow.Checkpoint | |||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | |||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | |||
| { | |||
| _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); | |||
| _restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>(); | |||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||
| _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||
| Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | |||
| foreach(var pair in serialized_tensors) | |||
| { | |||
| var obj = pair.Key; | |||
| var tensor_dict = pair.Value; | |||
| IFunctionHolder restore_fn; | |||
| RestoreFunc restore_fn; | |||
| if(obj == Trackable.None) | |||
| { | |||
| restore_fn = new FunctionHolder<object?>(() => null); | |||
| restore_fn = new RestoreFunc(x => null); | |||
| } | |||
| else | |||
| { | |||
| restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x => | |||
| restore_fn = new RestoreFunc(x => | |||
| { | |||
| return obj._restore_from_tensors(x); | |||
| if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||
| { | |||
| return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||
| } | |||
| throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | |||
| }); | |||
| } | |||
| @@ -305,14 +286,14 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var checkpoint_key = item.Key; | |||
| IDictionary<string, Tensor> spec_to_tensor; | |||
| if(item.Value.DataType != typeof(IDictionary<string, Tensor>)) | |||
| if(item.Value.TryGet<Tensor>(out var t)) | |||
| { | |||
| spec_to_tensor = new Dictionary<string, Tensor>(); | |||
| spec_to_tensor[""] = item.Value.GetValueA(); | |||
| spec_to_tensor[""] = t; | |||
| } | |||
| else | |||
| { | |||
| spec_to_tensor = item.Value.GetValueB(); | |||
| spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||
| } | |||
| foreach(var spec in spec_to_tensor) | |||
| @@ -342,7 +323,7 @@ namespace Tensorflow.Checkpoint | |||
| _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); | |||
| _registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>(); | |||
| _registered_savers = new Dictionary<string, (RestoreFunc, RestoreFunc)>(); | |||
| if(registered_savers is not null && registered_savers.Count > 0) | |||
| { | |||
| // TODO: complete the implementation. | |||
| @@ -418,8 +399,8 @@ namespace Tensorflow.Checkpoint | |||
| IDictionary<string, Operation> restore_func() | |||
| { | |||
| Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
| Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||
| Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||
| Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||
| Dictionary<string, Operation> restore_ops = new(); | |||
| foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) | |||
| @@ -449,7 +430,7 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
| } | |||
| } | |||
| else | |||
| @@ -158,4 +158,4 @@ namespace Tensorflow | |||
| Dispose(false); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -11,4 +11,4 @@ public class AssertionError : TensorflowException | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -37,10 +37,10 @@ namespace Tensorflow.Train | |||
| var properties = this.GetType().GetProperties(); | |||
| foreach ( var property in properties ) | |||
| { | |||
| string name = property.Name; | |||
| object value = property.GetValue(this, null); | |||
| if(value is Function || value is ConcreteFunction) | |||
| if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction)) | |||
| { | |||
| string name = property.Name; | |||
| object value = property.GetValue(this, null); | |||
| functions[name] = (Trackable)value; | |||
| } | |||
| } | |||
| @@ -25,9 +25,9 @@ namespace Tensorflow | |||
| { | |||
| get | |||
| { | |||
| if(_op.DataType == typeof(Tensor)) | |||
| if(_op.TryGet<Tensor>(out var tensor)) | |||
| { | |||
| return _op.GetValueA(); | |||
| return tensor; | |||
| } | |||
| else | |||
| { | |||
| @@ -8,4 +8,4 @@ public record class AssetInfo | |||
| Dictionary<object, object> asset_initializers_by_resource, | |||
| Dictionary<AssetInfo, string> asset_filename_map, | |||
| Dictionary<object, object> asset_index | |||
| ); | |||
| ); | |||
| @@ -86,7 +86,7 @@ public class AugmentedGraphView: ObjectGraphView | |||
| return concrete_function; | |||
| } | |||
| public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
| public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||
| { | |||
| Trackable get_merged_trackable(Trackable x) | |||
| { | |||
| @@ -130,4 +130,4 @@ public class AugmentedGraphView: ObjectGraphView | |||
| { | |||
| return _children_cache[obj][name]; | |||
| } | |||
| } | |||
| } | |||
| @@ -30,4 +30,4 @@ public static class Constants | |||
| public static readonly string VARIABLES_DIRECTORY = "variables"; | |||
| public static readonly string VARIABLES_FILENAME = "variables"; | |||
| } | |||
| } | |||
| @@ -14,4 +14,4 @@ public class RevivedTypes | |||
| // TODO: complete the implementation. | |||
| return null; | |||
| } | |||
| } | |||
| } | |||
| @@ -6,4 +6,4 @@ public enum SaveType | |||
| { | |||
| SAVEDMODEL, | |||
| CHECKPOINT | |||
| } | |||
| } | |||
| @@ -18,13 +18,13 @@ public class SaveableView | |||
| { | |||
| private AugmentedGraphView _augmented_graph_view; | |||
| private SaveOptions _options; | |||
| private List<Trackable> _trackable_objects; | |||
| private IList<Trackable> _trackable_objects; | |||
| private List<Trackable> _nodes; | |||
| private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths; | |||
| private Dictionary<Trackable, int> _node_ids; | |||
| private IDictionary<Trackable, IEnumerable<TrackableReference>> _node_paths; | |||
| private IDictionary<Trackable, int> _node_ids; | |||
| private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | |||
| _slot_variables; | |||
| private Dictionary<Trackable, string> _object_names; | |||
| private IDictionary<Trackable, string> _object_names; | |||
| private List<object> _gradient_functions; // to be completed | |||
| private List<RegisteredGradient> _gradient_defs; // to be completed | |||
| private List<ConcreteFunction> _concrete_functions; | |||
| @@ -45,7 +45,7 @@ public class SaveableView | |||
| { | |||
| get => _nodes; | |||
| } | |||
| public Dictionary<Trackable, int> NodeIds | |||
| public IDictionary<Trackable, int> NodeIds | |||
| { | |||
| get => _node_ids; | |||
| } | |||
| @@ -53,7 +53,7 @@ public class SaveableView | |||
| { | |||
| get => _gradient_defs; | |||
| } | |||
| public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths | |||
| public IDictionary<Trackable, IEnumerable<TrackableReference>> NodePaths | |||
| { | |||
| get => _node_paths; | |||
| } | |||
| @@ -84,7 +84,7 @@ public class SaveableView | |||
| private void initialize_nodes_and_concrete_functions() | |||
| { | |||
| _nodes = _trackable_objects.ConvertAll(x => x); // deep copy | |||
| _nodes = _trackable_objects.ToList().ConvertAll(x => x); // deep copy | |||
| _gradient_functions = new(); | |||
| _gradient_defs = new(); | |||
| @@ -296,4 +296,4 @@ public class SaveableView | |||
| proto.Nodes.Add(object_proto); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -7,4 +7,4 @@ public static class TagConstants | |||
| public static readonly string EVAL = "eval"; | |||
| public static readonly string GPU = "gpu"; | |||
| public static readonly string TPU = "tpu"; | |||
| } | |||
| } | |||
| @@ -19,4 +19,4 @@ public class BuilderUtils | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -81,8 +81,8 @@ public static partial class SavedModelUtils | |||
| return (saved_nodes, node_paths); | |||
| } | |||
| private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | |||
| Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||
| private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList<Trackable>, | |||
| IDictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||
| ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||
| { | |||
| using (SaveContext.save_context(options)) | |||
| @@ -266,4 +266,4 @@ public static partial class SavedModelUtils | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -104,4 +104,4 @@ public class SignatureMap: Trackable | |||
| return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| } | |||
| @@ -54,4 +54,4 @@ public static partial class SavedModelUtils | |||
| { | |||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); | |||
| } | |||
| } | |||
| } | |||
| @@ -136,9 +136,8 @@ namespace Tensorflow | |||
| { | |||
| full_name = name + "_" + attr; | |||
| } | |||
| if(factory.DataType == typeof(ResourceVariable)) | |||
| if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||
| { | |||
| var variable = factory.GetValueA(); | |||
| foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | |||
| { | |||
| yield return op; | |||
| @@ -146,8 +145,8 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| var variable = factory.GetValueB(); | |||
| foreach (var op in saveable_objects_for_op(variable, variable.name)) | |||
| var saveable = factory.GetValue<MySaveableObject>(); | |||
| foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||
| { | |||
| yield return op; | |||
| } | |||
| @@ -236,14 +235,14 @@ namespace Tensorflow | |||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | |||
| IDictionary<string, Tensor> internal_dict; | |||
| if(maybe_tensor.DataType == typeof(Tensor)) | |||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| { | |||
| internal_dict= new Dictionary<string, Tensor>(); | |||
| internal_dict[""] = maybe_tensor.GetValueA(); | |||
| internal_dict[""] = tensor; | |||
| } | |||
| else | |||
| { | |||
| internal_dict = maybe_tensor.GetValueB(); | |||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
| } | |||
| foreach(var item in internal_dict) | |||
| @@ -287,7 +286,7 @@ namespace Tensorflow | |||
| var slice_spec = convert_to_string(spec.slice_spec); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor; | |||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValue<IDictionary<string, Tensor>>()[slice_spec] = spec.tensor; | |||
| } | |||
| else | |||
| { | |||
| @@ -318,14 +317,14 @@ namespace Tensorflow | |||
| var maybe_tensor = restored_tensors[name]; | |||
| IDictionary<string, Tensor> dict; | |||
| if(maybe_tensor.DataType == typeof(Tensor)) | |||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||
| { | |||
| dict = new Dictionary<string, Tensor>(); | |||
| dict[""] = maybe_tensor.GetValueA(); | |||
| dict[""] = tensor; | |||
| } | |||
| else | |||
| { | |||
| dict = maybe_tensor.GetValueB(); | |||
| dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||
| } | |||
| saveable_restored_tensors.Add(dict[slice_spec]); | |||
| } | |||
| @@ -38,4 +38,4 @@ public static class Constants | |||
| RNN_LAYER_IDENTIFIER, | |||
| SEQUENTIAL_IDENTIFIER | |||
| }; | |||
| } | |||
| } | |||
| @@ -1,11 +0,0 @@ | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public class KerasObjectWrapper | |||
| { | |||
| } | |||
| public class KerasObjectWrapper<T> | |||
| { | |||
| public T Item { get; set; } | |||
| } | |||
| @@ -3,19 +3,15 @@ using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Google.Protobuf; | |||
| using ICSharpCode.SharpZipLib.Zip; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using Tensorflow.IO; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| @@ -108,5 +104,59 @@ public partial class KerasSavedModelUtils | |||
| return metadata; | |||
| } | |||
| } | |||
| public static bool should_skip_serialization(object layer) | |||
| { | |||
| return false; | |||
| } | |||
| /// <summary> | |||
| /// Returns extra trackable objects to attach to the serialized layer. | |||
| /// </summary> | |||
| /// <param name="layer"></param> | |||
| /// <param name="serialization_cache"></param> | |||
| /// <returns></returns> | |||
| public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||
| var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| Dictionary<string, Trackable> res = new(); | |||
| res["variables"] = variables; | |||
| res["trainable_variables"] = trainable_variables; | |||
| res["non_trainable_variables"] = non_trainable_variables; | |||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
| return res; | |||
| } | |||
| /// <summary> | |||
| /// Returns dict of wrapped layer call function and losses in tf.functions. | |||
| /// </summary> | |||
| /// <param name="layer"></param> | |||
| /// <param name="serialization_cache"></param> | |||
| /// <returns></returns> | |||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | |||
| // skip the process because of lack of APIs of `Layer`. | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| } | |||
| @@ -1,66 +0,0 @@ | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public partial class KerasSavedModelUtils | |||
| { | |||
| public static bool should_skip_serialization(object layer) | |||
| { | |||
| return false; | |||
| } | |||
| /// <summary> | |||
| /// Returns extra trackable objects to attach to the serialized layer. | |||
| /// </summary> | |||
| /// <param name="layer"></param> | |||
| /// <param name="serialization_cache"></param> | |||
| /// <returns></returns> | |||
| public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||
| var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| Dictionary<string, Trackable> res = new(); | |||
| res["variables"] = variables; | |||
| res["trainable_variables"] = trainable_variables; | |||
| res["non_trainable_variables"] = non_trainable_variables; | |||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
| return res; | |||
| } | |||
| /// <summary> | |||
| /// Returns dict of wrapped layer call function and losses in tf.functions. | |||
| /// </summary> | |||
| /// <param name="layer"></param> | |||
| /// <param name="serialization_cache"></param> | |||
| /// <returns></returns> | |||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | |||
| // skip the process because of lack of APIs of `Layer`. | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| } | |||
| @@ -34,5 +34,4 @@ public abstract class SavedModelSaver | |||
| return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | |||
| .ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| } | |||
| @@ -162,4 +162,4 @@ public class InputLayerSavedModelSaver: SavedModelSaver | |||
| return JsonConvert.SerializeObject(info); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -44,4 +44,4 @@ public class SaveOptionsContext: IDisposable | |||
| { | |||
| KerasSavedModelUtils.ShouldHaveTraces = _old_value; | |||
| } | |||
| } | |||
| } | |||
| @@ -73,7 +73,7 @@ public class SequentialModelTest | |||
| { | |||
| TrainDir = "mnist", | |||
| OneHot = false, | |||
| ValidationSize = 10000, | |||
| ValidationSize = 50000, | |||
| }).Result; | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| @@ -119,13 +119,13 @@ public class SequentialModelTest | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | |||
| var num_epochs = 1; | |||
| var batch_size = 16; | |||
| var batch_size = 8; | |||
| var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); | |||
| model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | |||
| model.save("./pb_elex_sequential", save_format: "tf"); | |||
| model.save("./pb_alex_sequential", save_format: "tf"); | |||
| // The saved model can be test with the following python code: | |||
| #region alexnet_python_code | |||
| @@ -136,7 +136,7 @@ public class SequentialModelTest | |||
| // return -a | |||
| //if __name__ == '__main__': | |||
| // model = tf.keras.models.load_model("./pb_elex_sequential") | |||
| // model = tf.keras.models.load_model("./pb_alex_sequential") | |||
| // model.summary() | |||
| // num_classes = 5 | |||