| @@ -12,9 +12,9 @@ namespace Tensorflow.Checkpoint; | |||||
| public static class CheckPointUtils | public static class CheckPointUtils | ||||
| { | { | ||||
| private static string _ESCAPE_CHAR = "."; | private static string _ESCAPE_CHAR = "."; | ||||
| public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>, | |||||
| public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>, | |||||
| IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>, | ||||
| Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | |||||
| IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) | |||||
| { | { | ||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | ||||
| Dictionary<Trackable, string> object_names = new(); | Dictionary<Trackable, string> object_names = new(); | ||||
| @@ -149,4 +149,4 @@ public static class CheckPointUtils | |||||
| // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | // object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i); | ||||
| // } | // } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -2,4 +2,4 @@ | |||||
| public record class CheckpointOptions( | public record class CheckpointOptions( | ||||
| string? experimental_io_device = null, | string? experimental_io_device = null, | ||||
| bool experimental_enable_async_checkpoint = false); | |||||
| bool experimental_enable_async_checkpoint = false); | |||||
| @@ -45,7 +45,7 @@ public class ObjectGraphView: TrackableView, ICloneable | |||||
| get => _attached_dependencies; | get => _attached_dependencies; | ||||
| } | } | ||||
| public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||||
| public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||||
| { | { | ||||
| return base._descendants_with_paths(); | return base._descendants_with_paths(); | ||||
| } | } | ||||
| @@ -61,4 +61,4 @@ public class ObjectGraphView: TrackableView, ICloneable | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -58,7 +58,7 @@ namespace Tensorflow.Checkpoint | |||||
| return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); | return (serialized_tensors, feed_additions, registered_savers, object_graph_proto); | ||||
| } | } | ||||
| private static (List<TrackableData>, Dictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map) | |||||
| private static (IList<TrackableData>, IDictionary<Trackable, int>) gather_trackable_data(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map) | |||||
| { | { | ||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | ||||
| Dictionary<Trackable, string> object_names = new(); | Dictionary<Trackable, string> object_names = new(); | ||||
| @@ -173,7 +173,7 @@ namespace Tensorflow.Checkpoint | |||||
| tensor_dict[checkpoint_key] = maybe_tensor; | tensor_dict[checkpoint_key] = maybe_tensor; | ||||
| if(maybe_tensor.GetValueA() is SaveSpec) | |||||
| if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>()) | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; | ||||
| @@ -13,7 +13,7 @@ namespace Tensorflow.Checkpoint; | |||||
| public static class SaveUtilV1 | public static class SaveUtilV1 | ||||
| { | { | ||||
| public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names, | |||||
| public static (IDictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names, | |||||
| IDictionary<Trackable, Trackable>? object_map = null) | IDictionary<Trackable, Trackable>? object_map = null) | ||||
| { | { | ||||
| // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, | // According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md, | ||||
| @@ -44,7 +44,7 @@ public static class SaveUtilV1 | |||||
| return (checkpoint_factory_map, null); | return (checkpoint_factory_map, null); | ||||
| } | } | ||||
| public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||||
| public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||||
| IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, | IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, | ||||
| object? saveables_cache = null) | object? saveables_cache = null) | ||||
| { | { | ||||
| @@ -73,7 +73,7 @@ public static class SaveUtilV1 | |||||
| } | } | ||||
| } | } | ||||
| public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view, | |||||
| public static (IList<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view, | |||||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | ||||
| { | { | ||||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | ||||
| @@ -129,7 +129,8 @@ public static class SaveUtilV1 | |||||
| return object_graph_proto; | return object_graph_proto; | ||||
| } | } | ||||
| private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects, | |||||
| private static (IList<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph( | |||||
| IList<Trackable> trackable_objects, | |||||
| TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, | TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, | ||||
| IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, | IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, | ||||
| bool call_with_mapped_captures, object? saveables_cache = null) | bool call_with_mapped_captures, object? saveables_cache = null) | ||||
| @@ -150,7 +151,7 @@ public static class SaveUtilV1 | |||||
| return (named_saveable_objects, feed_additions, null); | return (named_saveable_objects, feed_additions, null); | ||||
| } | } | ||||
| public static (List<MySaveableObject>, object?) generate_saveable_objects( | |||||
| public static (IList<MySaveableObject>, object?) generate_saveable_objects( | |||||
| IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map, | IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map, | ||||
| TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids, | TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids, | ||||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | ||||
| @@ -178,13 +179,13 @@ public static class SaveUtilV1 | |||||
| // TODO: oneflow python has a process with callable `saveable_factory`. | // TODO: oneflow python has a process with callable `saveable_factory`. | ||||
| List<MySaveableObject> saveables = new(); | List<MySaveableObject> saveables = new(); | ||||
| if (maybe_saveable.DataType == typeof(MySaveableObject)) | |||||
| if (maybe_saveable.TryGet<MySaveableObject>(out var s)) | |||||
| { | { | ||||
| saveables.Add(maybe_saveable.GetValueB()); | |||||
| saveables.Add(s); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); | |||||
| saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue<BaseResourceVariable>() as Trackable, key)); | |||||
| } | } | ||||
| foreach (var saveable in saveables) | foreach (var saveable in saveables) | ||||
| @@ -219,4 +220,4 @@ public record class CheckpointFactoryData | |||||
| Maybe<BaseResourceVariable, MySaveableObject> factory, | Maybe<BaseResourceVariable, MySaveableObject> factory, | ||||
| string name, | string name, | ||||
| string checkpoint_key | string checkpoint_key | ||||
| ); | |||||
| ); | |||||
| @@ -52,7 +52,7 @@ public class TrackableView | |||||
| /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. | /// Returns a list of all nodes and its paths from self.root using a breadth first traversal. | ||||
| /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths | /// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths | ||||
| /// </summary> | /// </summary> | ||||
| protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths() | |||||
| protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths() | |||||
| { | { | ||||
| List<Trackable> bfs_sorted = new(); | List<Trackable> bfs_sorted = new(); | ||||
| Queue<Trackable> to_visit = new(); | Queue<Trackable> to_visit = new(); | ||||
| @@ -14,112 +14,91 @@ using Tensorflow.Training; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using System.Xml.Linq; | using System.Xml.Linq; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using RestoreFunc = System.Func<object, object>; | |||||
| namespace Tensorflow.Checkpoint | namespace Tensorflow.Checkpoint | ||||
| { | { | ||||
| /// <summary> | |||||
| /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. | |||||
| /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. | |||||
| /// </summary> | |||||
| public interface IFunctionHolder | |||||
| { | |||||
| int ArgCount { get; } | |||||
| object DynamicInvoke(params object[] args); | |||||
| } | |||||
| internal record class FunctionHolder<TR>(Func<TR> Func): IFunctionHolder | |||||
| { | |||||
| public int ArgCount => 0; | |||||
| public object DynamicInvoke(params object[] args) | |||||
| { | |||||
| return Func.DynamicInvoke(args); | |||||
| } | |||||
| public TR Invoke() | |||||
| { | |||||
| return Func.Invoke(); | |||||
| } | |||||
| } | |||||
| internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder | |||||
| { | |||||
| public int ArgCount => 1; | |||||
| public object DynamicInvoke(params object[] args) | |||||
| { | |||||
| return Func.DynamicInvoke(args); | |||||
| } | |||||
| } | |||||
| internal record class FunctionHolder<TA1, TA2, TR>(Func<TA1, TA2, TR> Func) : IFunctionHolder | |||||
| { | |||||
| public int ArgCount => 2; | |||||
| public object DynamicInvoke(params object[] args) | |||||
| { | |||||
| return Func.DynamicInvoke(args); | |||||
| } | |||||
| } | |||||
| internal record class FunctionHolder<TA1, TA2, TA3, TR>(Func<TA1, TA2, TA3, TR> Func) : IFunctionHolder | |||||
| { | |||||
| public int ArgCount => 3; | |||||
| public object DynamicInvoke(params object[] args) | |||||
| { | |||||
| return Func.DynamicInvoke(args); | |||||
| } | |||||
| } | |||||
| public class Maybe<TA, TB> | public class Maybe<TA, TB> | ||||
| { | { | ||||
| private TA? _valueA = default(TA); | private TA? _valueA = default(TA); | ||||
| private TB? _valueB = default(TB); | private TB? _valueB = default(TB); | ||||
| private Type _type; | private Type _type; | ||||
| private bool _assigned = false; | |||||
| private bool _assignedTA; | |||||
| public Maybe(TA value) | public Maybe(TA value) | ||||
| { | { | ||||
| _valueA = value; | _valueA = value; | ||||
| _type= typeof(TA); | _type= typeof(TA); | ||||
| _assigned = true; | |||||
| _assignedTA = true; | |||||
| } | } | ||||
| public Maybe(TB value) | public Maybe(TB value) | ||||
| { | { | ||||
| _valueB = value; | _valueB = value; | ||||
| _type = typeof(TB); | _type = typeof(TB); | ||||
| _assigned = true; | |||||
| _assignedTA = false; | |||||
| } | } | ||||
| public Type DataType => _type; | public Type DataType => _type; | ||||
| public TA GetValueA() | |||||
| /// <summary> | |||||
| /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. | |||||
| /// It returns | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="res"></param> | |||||
| /// <returns></returns> | |||||
| public bool TryGet<T>(out T? res) | |||||
| { | { | ||||
| if(!_assigned || DataType != typeof(TA)) | |||||
| if(_valueA is T && _valueB is not T) | |||||
| { | { | ||||
| throw new TypeError("Cannot get the data because of wrong specified type."); | |||||
| res = (T)(object)_valueA; | |||||
| return _assignedTA; | |||||
| } | } | ||||
| return _valueA; | |||||
| } | |||||
| public TB GetValueB() | |||||
| { | |||||
| if (!_assigned || DataType != typeof(TB)) | |||||
| else if(_valueA is not T && _valueB is T) | |||||
| { | { | ||||
| throw new TypeError("Cannot get the data because of wrong specified type."); | |||||
| res = (T)(object)_valueB; | |||||
| return !_assignedTA; | |||||
| } | } | ||||
| return _valueB; | |||||
| res = default(T); | |||||
| return false; | |||||
| } | } | ||||
| public object GetValue() | |||||
| public bool IsTypeOrDeriveFrom<T>() | |||||
| { | { | ||||
| if (!_assigned) | |||||
| if (_valueA is T && _valueB is not T) | |||||
| { | { | ||||
| throw new TypeError("Cannot get the data because of wrong specified type."); | |||||
| return _assignedTA; | |||||
| } | } | ||||
| if(DataType == typeof(TA) && _valueA is not null) | |||||
| else if (_valueA is not T && _valueB is T) | |||||
| { | { | ||||
| return _valueA; | |||||
| return !_assignedTA; | |||||
| } | } | ||||
| else if(DataType == typeof(TB) && _valueB is not null) | |||||
| else if (_valueA is T && _valueB is T) | |||||
| { | { | ||||
| return _valueB; | |||||
| return true; | |||||
| } | } | ||||
| else if(DataType == typeof(TA)) | |||||
| else | |||||
| { | { | ||||
| return _valueA; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| public T GetValue<T>() | |||||
| { | |||||
| if (_valueA is T && _valueB is not T) | |||||
| { | |||||
| return (T)(object)_valueA; | |||||
| } | |||||
| else if (_valueA is not T && _valueB is T) | |||||
| { | |||||
| return (T)(object)_valueB; | |||||
| } | |||||
| else if (_valueA is T && _valueB is T) | |||||
| { | |||||
| throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return _valueB; | |||||
| throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); | |||||
| } | } | ||||
| } | } | ||||
| @@ -170,9 +149,8 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| var slice_spec = slice.Key; | var slice_spec = slice.Key; | ||||
| var maybe_tensor = slice.Value; | var maybe_tensor = slice.Value; | ||||
| if(maybe_tensor.DataType == typeof(SaveSpec)) | |||||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||||
| { | { | ||||
| var spec = maybe_tensor.GetValueB(); | |||||
| var tensor_value = spec.tensor; | var tensor_value = spec.tensor; | ||||
| if (tensor_value is not null) | if (tensor_value is not null) | ||||
| { | { | ||||
| @@ -183,7 +161,7 @@ namespace Tensorflow.Checkpoint | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var tensor = maybe_tensor.GetValueA(); | |||||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||||
| tensor_names.Add(checkpoint_key); | tensor_names.Add(checkpoint_key); | ||||
| tensors.Add(tensor); | tensors.Add(tensor); | ||||
| slice_specs.Add(slice_spec); | slice_specs.Add(slice_spec); | ||||
| @@ -215,16 +193,15 @@ namespace Tensorflow.Checkpoint | |||||
| var slice_spec = slice.Key; | var slice_spec = slice.Key; | ||||
| var maybe_tensor = slice.Value; | var maybe_tensor = slice.Value; | ||||
| // TODO: deal with other types. Currently only `SaveSpec` is allowed. | // TODO: deal with other types. Currently only `SaveSpec` is allowed. | ||||
| if(maybe_tensor.DataType == typeof(SaveSpec)) | |||||
| if(maybe_tensor.TryGet<SaveSpec>(out var spec)) | |||||
| { | { | ||||
| var spec = maybe_tensor.GetValueB(); | |||||
| tensor_dtypes.Add(spec.dtype); | tensor_dtypes.Add(spec.dtype); | ||||
| slice_specs.Add(spec.slice_spec); | slice_specs.Add(spec.slice_spec); | ||||
| tensor_names.Add(spec.name); | tensor_names.Add(spec.name); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var tensor = maybe_tensor.GetValueA(); | |||||
| var tensor = maybe_tensor.GetValue<Tensor>(); | |||||
| tensor_dtypes.Add(tensor.dtype); | tensor_dtypes.Add(tensor.dtype); | ||||
| slice_specs.Add(slice_spec); | slice_specs.Add(slice_spec); | ||||
| tensor_names.Add(checkpoint_key); | tensor_names.Add(checkpoint_key); | ||||
| @@ -268,9 +245,9 @@ namespace Tensorflow.Checkpoint | |||||
| public class MultiDeviceSaver | public class MultiDeviceSaver | ||||
| { | { | ||||
| private Dictionary<string, SingleDeviceSaver> _single_device_savers; | private Dictionary<string, SingleDeviceSaver> _single_device_savers; | ||||
| private IDictionary<string, (IFunctionHolder, IFunctionHolder)> _registered_savers; | |||||
| private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; | |||||
| private Dictionary<IFunctionHolder, IList<(string, string)>> _restore_fn_to_keys; | |||||
| private IDictionary<string, (RestoreFunc, RestoreFunc)> _registered_savers; | |||||
| private Dictionary<(string, string), RestoreFunc> _keys_to_restore_fn; | |||||
| private Dictionary<RestoreFunc, IList<(string, string)>> _restore_fn_to_keys; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -280,24 +257,28 @@ namespace Tensorflow.Checkpoint | |||||
| public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | public MultiDeviceSaver(IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors, | ||||
| IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_capture = false) | ||||
| { | { | ||||
| _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); | |||||
| _restore_fn_to_keys = new Dictionary<IFunctionHolder, IList<(string, string)>>(); | |||||
| _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); | |||||
| _restore_fn_to_keys = new Dictionary<RestoreFunc, IList<(string, string)>>(); | |||||
| Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | Dictionary<string, IDictionary<string, IDictionary<string, Tensor>>> tensors_by_device= new(); | ||||
| foreach(var pair in serialized_tensors) | foreach(var pair in serialized_tensors) | ||||
| { | { | ||||
| var obj = pair.Key; | var obj = pair.Key; | ||||
| var tensor_dict = pair.Value; | var tensor_dict = pair.Value; | ||||
| IFunctionHolder restore_fn; | |||||
| RestoreFunc restore_fn; | |||||
| if(obj == Trackable.None) | if(obj == Trackable.None) | ||||
| { | { | ||||
| restore_fn = new FunctionHolder<object?>(() => null); | |||||
| restore_fn = new RestoreFunc(x => null); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x => | |||||
| restore_fn = new RestoreFunc(x => | |||||
| { | { | ||||
| return obj._restore_from_tensors(x); | |||||
| if(x is IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) | |||||
| { | |||||
| return obj._restore_from_tensors(x as IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>); | |||||
| } | |||||
| throw new TypeError($"Expected `IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>` as input, got{x.GetType()}."); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -305,14 +286,14 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| var checkpoint_key = item.Key; | var checkpoint_key = item.Key; | ||||
| IDictionary<string, Tensor> spec_to_tensor; | IDictionary<string, Tensor> spec_to_tensor; | ||||
| if(item.Value.DataType != typeof(IDictionary<string, Tensor>)) | |||||
| if(item.Value.TryGet<Tensor>(out var t)) | |||||
| { | { | ||||
| spec_to_tensor = new Dictionary<string, Tensor>(); | spec_to_tensor = new Dictionary<string, Tensor>(); | ||||
| spec_to_tensor[""] = item.Value.GetValueA(); | |||||
| spec_to_tensor[""] = t; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| spec_to_tensor = item.Value.GetValueB(); | |||||
| spec_to_tensor = item.Value.GetValue<IDictionary<string, Tensor>>(); | |||||
| } | } | ||||
| foreach(var spec in spec_to_tensor) | foreach(var spec in spec_to_tensor) | ||||
| @@ -342,7 +323,7 @@ namespace Tensorflow.Checkpoint | |||||
| _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); | _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); | ||||
| _registered_savers = new Dictionary<string, (IFunctionHolder, IFunctionHolder)>(); | |||||
| _registered_savers = new Dictionary<string, (RestoreFunc, RestoreFunc)>(); | |||||
| if(registered_savers is not null && registered_savers.Count > 0) | if(registered_savers is not null && registered_savers.Count > 0) | ||||
| { | { | ||||
| // TODO: complete the implementation. | // TODO: complete the implementation. | ||||
| @@ -418,8 +399,8 @@ namespace Tensorflow.Checkpoint | |||||
| IDictionary<string, Operation> restore_func() | IDictionary<string, Operation> restore_func() | ||||
| { | { | ||||
| Dictionary<IFunctionHolder, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
| Dictionary<IFunctionHolder, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||||
| Dictionary<RestoreFunc, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> restore_fn_inputs = new(); | |||||
| Dictionary<RestoreFunc, int> restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); | |||||
| Dictionary<string, Operation> restore_ops = new(); | Dictionary<string, Operation> restore_ops = new(); | ||||
| foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) | foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) | ||||
| @@ -449,7 +430,7 @@ namespace Tensorflow.Checkpoint | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; | |||||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||||
| } | } | ||||
| } | } | ||||
| else | else | ||||
| @@ -158,4 +158,4 @@ namespace Tensorflow | |||||
| Dispose(false); | Dispose(false); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -11,4 +11,4 @@ public class AssertionError : TensorflowException | |||||
| { | { | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -37,10 +37,10 @@ namespace Tensorflow.Train | |||||
| var properties = this.GetType().GetProperties(); | var properties = this.GetType().GetProperties(); | ||||
| foreach ( var property in properties ) | foreach ( var property in properties ) | ||||
| { | { | ||||
| string name = property.Name; | |||||
| object value = property.GetValue(this, null); | |||||
| if(value is Function || value is ConcreteFunction) | |||||
| if(property.PropertyType == typeof(Function) || property.PropertyType == typeof(ConcreteFunction)) | |||||
| { | { | ||||
| string name = property.Name; | |||||
| object value = property.GetValue(this, null); | |||||
| functions[name] = (Trackable)value; | functions[name] = (Trackable)value; | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,9 +25,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_op.DataType == typeof(Tensor)) | |||||
| if(_op.TryGet<Tensor>(out var tensor)) | |||||
| { | { | ||||
| return _op.GetValueA(); | |||||
| return tensor; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -8,4 +8,4 @@ public record class AssetInfo | |||||
| Dictionary<object, object> asset_initializers_by_resource, | Dictionary<object, object> asset_initializers_by_resource, | ||||
| Dictionary<AssetInfo, string> asset_filename_map, | Dictionary<AssetInfo, string> asset_filename_map, | ||||
| Dictionary<object, object> asset_index | Dictionary<object, object> asset_index | ||||
| ); | |||||
| ); | |||||
| @@ -86,7 +86,7 @@ public class AugmentedGraphView: ObjectGraphView | |||||
| return concrete_function; | return concrete_function; | ||||
| } | } | ||||
| public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||||
| public override (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal() | |||||
| { | { | ||||
| Trackable get_merged_trackable(Trackable x) | Trackable get_merged_trackable(Trackable x) | ||||
| { | { | ||||
| @@ -130,4 +130,4 @@ public class AugmentedGraphView: ObjectGraphView | |||||
| { | { | ||||
| return _children_cache[obj][name]; | return _children_cache[obj][name]; | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -30,4 +30,4 @@ public static class Constants | |||||
| public static readonly string VARIABLES_DIRECTORY = "variables"; | public static readonly string VARIABLES_DIRECTORY = "variables"; | ||||
| public static readonly string VARIABLES_FILENAME = "variables"; | public static readonly string VARIABLES_FILENAME = "variables"; | ||||
| } | |||||
| } | |||||
| @@ -14,4 +14,4 @@ public class RevivedTypes | |||||
| // TODO: complete the implementation. | // TODO: complete the implementation. | ||||
| return null; | return null; | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -6,4 +6,4 @@ public enum SaveType | |||||
| { | { | ||||
| SAVEDMODEL, | SAVEDMODEL, | ||||
| CHECKPOINT | CHECKPOINT | ||||
| } | |||||
| } | |||||
| @@ -18,13 +18,13 @@ public class SaveableView | |||||
| { | { | ||||
| private AugmentedGraphView _augmented_graph_view; | private AugmentedGraphView _augmented_graph_view; | ||||
| private SaveOptions _options; | private SaveOptions _options; | ||||
| private List<Trackable> _trackable_objects; | |||||
| private IList<Trackable> _trackable_objects; | |||||
| private List<Trackable> _nodes; | private List<Trackable> _nodes; | ||||
| private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths; | |||||
| private Dictionary<Trackable, int> _node_ids; | |||||
| private IDictionary<Trackable, IEnumerable<TrackableReference>> _node_paths; | |||||
| private IDictionary<Trackable, int> _node_ids; | |||||
| private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>> | ||||
| _slot_variables; | _slot_variables; | ||||
| private Dictionary<Trackable, string> _object_names; | |||||
| private IDictionary<Trackable, string> _object_names; | |||||
| private List<object> _gradient_functions; // to be completed | private List<object> _gradient_functions; // to be completed | ||||
| private List<RegisteredGradient> _gradient_defs; // to be completed | private List<RegisteredGradient> _gradient_defs; // to be completed | ||||
| private List<ConcreteFunction> _concrete_functions; | private List<ConcreteFunction> _concrete_functions; | ||||
| @@ -45,7 +45,7 @@ public class SaveableView | |||||
| { | { | ||||
| get => _nodes; | get => _nodes; | ||||
| } | } | ||||
| public Dictionary<Trackable, int> NodeIds | |||||
| public IDictionary<Trackable, int> NodeIds | |||||
| { | { | ||||
| get => _node_ids; | get => _node_ids; | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ public class SaveableView | |||||
| { | { | ||||
| get => _gradient_defs; | get => _gradient_defs; | ||||
| } | } | ||||
| public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths | |||||
| public IDictionary<Trackable, IEnumerable<TrackableReference>> NodePaths | |||||
| { | { | ||||
| get => _node_paths; | get => _node_paths; | ||||
| } | } | ||||
| @@ -84,7 +84,7 @@ public class SaveableView | |||||
| private void initialize_nodes_and_concrete_functions() | private void initialize_nodes_and_concrete_functions() | ||||
| { | { | ||||
| _nodes = _trackable_objects.ConvertAll(x => x); // deep copy | |||||
| _nodes = _trackable_objects.ToList().ConvertAll(x => x); // deep copy | |||||
| _gradient_functions = new(); | _gradient_functions = new(); | ||||
| _gradient_defs = new(); | _gradient_defs = new(); | ||||
| @@ -296,4 +296,4 @@ public class SaveableView | |||||
| proto.Nodes.Add(object_proto); | proto.Nodes.Add(object_proto); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -7,4 +7,4 @@ public static class TagConstants | |||||
| public static readonly string EVAL = "eval"; | public static readonly string EVAL = "eval"; | ||||
| public static readonly string GPU = "gpu"; | public static readonly string GPU = "gpu"; | ||||
| public static readonly string TPU = "tpu"; | public static readonly string TPU = "tpu"; | ||||
| } | |||||
| } | |||||
| @@ -19,4 +19,4 @@ public class BuilderUtils | |||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -81,8 +81,8 @@ public static partial class SavedModelUtils | |||||
| return (saved_nodes, node_paths); | return (saved_nodes, node_paths); | ||||
| } | } | ||||
| private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>, | |||||
| Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||||
| private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList<Trackable>, | |||||
| IDictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||||
| ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | ||||
| { | { | ||||
| using (SaveContext.save_context(options)) | using (SaveContext.save_context(options)) | ||||
| @@ -266,4 +266,4 @@ public static partial class SavedModelUtils | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -104,4 +104,4 @@ public class SignatureMap: Trackable | |||||
| return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); | return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -54,4 +54,4 @@ public static partial class SavedModelUtils | |||||
| { | { | ||||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); | return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY)); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -136,9 +136,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| full_name = name + "_" + attr; | full_name = name + "_" + attr; | ||||
| } | } | ||||
| if(factory.DataType == typeof(ResourceVariable)) | |||||
| if(factory.TryGet<BaseResourceVariable>(out var variable)) | |||||
| { | { | ||||
| var variable = factory.GetValueA(); | |||||
| foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) | ||||
| { | { | ||||
| yield return op; | yield return op; | ||||
| @@ -146,8 +145,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var variable = factory.GetValueB(); | |||||
| foreach (var op in saveable_objects_for_op(variable, variable.name)) | |||||
| var saveable = factory.GetValue<MySaveableObject>(); | |||||
| foreach (var op in saveable_objects_for_op(saveable, saveable.name)) | |||||
| { | { | ||||
| yield return op; | yield return op; | ||||
| } | } | ||||
| @@ -236,14 +235,14 @@ namespace Tensorflow | |||||
| string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | string spec_name = name + TrackableUtils.escape_local_name(tensor_name); | ||||
| IDictionary<string, Tensor> internal_dict; | IDictionary<string, Tensor> internal_dict; | ||||
| if(maybe_tensor.DataType == typeof(Tensor)) | |||||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
| { | { | ||||
| internal_dict= new Dictionary<string, Tensor>(); | internal_dict= new Dictionary<string, Tensor>(); | ||||
| internal_dict[""] = maybe_tensor.GetValueA(); | |||||
| internal_dict[""] = tensor; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| internal_dict = maybe_tensor.GetValueB(); | |||||
| internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||||
| } | } | ||||
| foreach(var item in internal_dict) | foreach(var item in internal_dict) | ||||
| @@ -287,7 +286,7 @@ namespace Tensorflow | |||||
| var slice_spec = convert_to_string(spec.slice_spec); | var slice_spec = convert_to_string(spec.slice_spec); | ||||
| if (!string.IsNullOrEmpty(slice_spec)) | if (!string.IsNullOrEmpty(slice_spec)) | ||||
| { | { | ||||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor; | |||||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValue<IDictionary<string, Tensor>>()[slice_spec] = spec.tensor; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -318,14 +317,14 @@ namespace Tensorflow | |||||
| var maybe_tensor = restored_tensors[name]; | var maybe_tensor = restored_tensors[name]; | ||||
| IDictionary<string, Tensor> dict; | IDictionary<string, Tensor> dict; | ||||
| if(maybe_tensor.DataType == typeof(Tensor)) | |||||
| if(maybe_tensor.TryGet<Tensor>(out var tensor)) | |||||
| { | { | ||||
| dict = new Dictionary<string, Tensor>(); | dict = new Dictionary<string, Tensor>(); | ||||
| dict[""] = maybe_tensor.GetValueA(); | |||||
| dict[""] = tensor; | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| dict = maybe_tensor.GetValueB(); | |||||
| dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); | |||||
| } | } | ||||
| saveable_restored_tensors.Add(dict[slice_spec]); | saveable_restored_tensors.Add(dict[slice_spec]); | ||||
| } | } | ||||
| @@ -38,4 +38,4 @@ public static class Constants | |||||
| RNN_LAYER_IDENTIFIER, | RNN_LAYER_IDENTIFIER, | ||||
| SEQUENTIAL_IDENTIFIER | SEQUENTIAL_IDENTIFIER | ||||
| }; | }; | ||||
| } | |||||
| } | |||||
| @@ -1,11 +0,0 @@ | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public class KerasObjectWrapper | |||||
| { | |||||
| } | |||||
| public class KerasObjectWrapper<T> | |||||
| { | |||||
| public T Item { get; set; } | |||||
| } | |||||
| @@ -3,19 +3,15 @@ using System.Collections.Generic; | |||||
| using System.IO; | using System.IO; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using ICSharpCode.SharpZipLib.Zip; | |||||
| using Tensorflow.Checkpoint; | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow.ModelSaving; | using Tensorflow.ModelSaving; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Exceptions; | |||||
| using Tensorflow.IO; | |||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | namespace Tensorflow.Keras.Saving.SavedModel; | ||||
| @@ -108,5 +104,59 @@ public partial class KerasSavedModelUtils | |||||
| return metadata; | return metadata; | ||||
| } | } | ||||
| } | |||||
| public static bool should_skip_serialization(object layer) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns extra trackable objects to attach to the serialized layer. | |||||
| /// </summary> | |||||
| /// <param name="layer"></param> | |||||
| /// <param name="serialization_cache"></param> | |||||
| /// <returns></returns> | |||||
| public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
| { | |||||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||||
| var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| })); | |||||
| var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| })); | |||||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| })); | |||||
| Dictionary<string, Trackable> res = new(); | |||||
| res["variables"] = variables; | |||||
| res["trainable_variables"] = trainable_variables; | |||||
| res["non_trainable_variables"] = non_trainable_variables; | |||||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||||
| return res; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns dict of wrapped layer call function and losses in tf.functions. | |||||
| /// </summary> | |||||
| /// <param name="layer"></param> | |||||
| /// <param name="serialization_cache"></param> | |||||
| /// <returns></returns> | |||||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
| { | |||||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | |||||
| // skip the process because of lack of APIs of `Layer`. | |||||
| return new Dictionary<string, Trackable>(); | |||||
| } | |||||
| } | |||||
| @@ -1,66 +0,0 @@ | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Tensorflow.Keras.Engine; | |||||
| using Tensorflow.Train; | |||||
| using Tensorflow.Training; | |||||
| namespace Tensorflow.Keras.Saving.SavedModel; | |||||
| public partial class KerasSavedModelUtils | |||||
| { | |||||
| public static bool should_skip_serialization(object layer) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns extra trackable objects to attach to the serialized layer. | |||||
| /// </summary> | |||||
| /// <param name="layer"></param> | |||||
| /// <param name="serialization_cache"></param> | |||||
| /// <returns></returns> | |||||
| public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
| { | |||||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||||
| var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| })); | |||||
| var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| })); | |||||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||||
| { | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||||
| })); | |||||
| Dictionary<string, Trackable> res = new(); | |||||
| res["variables"] = variables; | |||||
| res["trainable_variables"] = trainable_variables; | |||||
| res["non_trainable_variables"] = non_trainable_variables; | |||||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||||
| return res; | |||||
| } | |||||
| /// <summary> | |||||
| /// Returns dict of wrapped layer call function and losses in tf.functions. | |||||
| /// </summary> | |||||
| /// <param name="layer"></param> | |||||
| /// <param name="serialization_cache"></param> | |||||
| /// <returns></returns> | |||||
| public static IDictionary<string, Trackable> wrap_layer_functions(Layer layer, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||||
| { | |||||
| // TODO: deal with type `RevivedLayer` and `Sequential`. | |||||
| // skip the process because of lack of APIs of `Layer`. | |||||
| return new Dictionary<string, Trackable>(); | |||||
| } | |||||
| } | |||||
| @@ -34,5 +34,4 @@ public abstract class SavedModelSaver | |||||
| return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) | ||||
| .ToDictionary(x => x.Key, x => x.Value); | .ToDictionary(x => x.Key, x => x.Value); | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -162,4 +162,4 @@ public class InputLayerSavedModelSaver: SavedModelSaver | |||||
| return JsonConvert.SerializeObject(info); | return JsonConvert.SerializeObject(info); | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -44,4 +44,4 @@ public class SaveOptionsContext: IDisposable | |||||
| { | { | ||||
| KerasSavedModelUtils.ShouldHaveTraces = _old_value; | KerasSavedModelUtils.ShouldHaveTraces = _old_value; | ||||
| } | } | ||||
| } | |||||
| } | |||||
| @@ -73,7 +73,7 @@ public class SequentialModelTest | |||||
| { | { | ||||
| TrainDir = "mnist", | TrainDir = "mnist", | ||||
| OneHot = false, | OneHot = false, | ||||
| ValidationSize = 10000, | |||||
| ValidationSize = 50000, | |||||
| }).Result; | }).Result; | ||||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | ||||
| @@ -119,13 +119,13 @@ public class SequentialModelTest | |||||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | ||||
| var num_epochs = 1; | var num_epochs = 1; | ||||
| var batch_size = 16; | |||||
| var batch_size = 8; | |||||
| var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); | var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); | ||||
| model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); | ||||
| model.save("./pb_elex_sequential", save_format: "tf"); | |||||
| model.save("./pb_alex_sequential", save_format: "tf"); | |||||
| // The saved model can be test with the following python code: | // The saved model can be test with the following python code: | ||||
| #region alexnet_python_code | #region alexnet_python_code | ||||
| @@ -136,7 +136,7 @@ public class SequentialModelTest | |||||
| // return -a | // return -a | ||||
| //if __name__ == '__main__': | //if __name__ == '__main__': | ||||
| // model = tf.keras.models.load_model("./pb_elex_sequential") | |||||
| // model = tf.keras.models.load_model("./pb_alex_sequential") | |||||
| // model.summary() | // model.summary() | ||||
| // num_classes = 5 | // num_classes = 5 | ||||