From 83906b8f798d7faa99784da7d66489ca51dae4fd Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Mon, 30 Jan 2023 13:42:51 +0800 Subject: [PATCH] Add lacked implementations (mainly MultiDeviceSaver). --- .../Checkpoint/CheckpointOptions.cs | 2 +- .../Checkpoint/ObjectGraphView.cs | 9 +- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 23 +- .../Checkpoint/SaveUtilV1.cs | 27 +- .../Checkpoint/TrackableView.cs | 5 +- .../Checkpoint/checkpoint.cs | 9 +- .../Checkpoint/functional_saver.cs | 515 +++++++++++++++++- .../SavedModel/ISerializedAttributes.cs | 35 ++ .../Training/AutoTrackable.cs | 3 +- .../Saving/SavedModel/AugmentedGraphView.cs | 109 +++- .../Saving/SavedModel/SaveableView.cs | 6 +- .../Training/Saving/SavedModel/save.cs | 16 +- .../SavedModel/signature_serialization.cs | 99 +++- .../Saving/saveable_object_util.py.cs | 156 +++++- src/TensorFlowNET.Core/Training/Trackable.cs | 48 +- .../Training/TrackableUtils.cs | 28 +- .../Training/data_structures.cs | 3 +- .../Variables/BaseResourceVariable.cs | 3 + .../Variables/ResourceVariable.cs | 9 + src/TensorFlowNET.Keras/Engine/Functional.cs | 3 +- .../Engine/Layer.Serialize.cs | 7 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 24 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.cs | 3 +- .../Saving/SavedModel/Save.cs | 9 +- .../Saving/SavedModel/SaveImpl.cs | 4 +- .../Saving/SavedModel/base_serialization.cs | 7 +- .../Saving/SavedModel/layer_serialization.cs | 28 +- .../SavedModel/serialized_attributes.cs | 2 +- test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 4 +- 30 files changed, 1037 insertions(+), 161 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs index d8297ea3..f14b5ce7 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs @@ -1,5 +1,5 @@ namespace Tensorflow.Checkpoint; public record class CheckpointOptions( - string experimental_io_device = null, + string? experimental_io_device = null, bool experimental_enable_async_checkpoint = false); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs index 2ad55448..cb01b539 100644 --- a/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Serilog.Debugging; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; namespace Tensorflow.Checkpoint; @@ -21,9 +22,9 @@ public class ObjectGraphView: TrackableView, ICloneable return new ObjectGraphView(Root, _attached_dependencies); } - public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public virtual List list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) { - List res = base.children(obj, save_type) + List res = base.children(obj, save_type, serialization_cache) .Select(x => new TrackableReference(x.Key, x.Value)).ToList(); // Check the reference, not value. if (obj == Root && _attached_dependencies is not null) @@ -34,9 +35,9 @@ public class ObjectGraphView: TrackableView, ICloneable return res; } - public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public override IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? serialization_cache = null) { - return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer); + return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer); } public IEnumerable? AttachedDependencies diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index dc2a92fb..e646f1f0 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint ); public static class SaveUtil { - public static (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -117,16 +117,16 @@ namespace Tensorflow.Checkpoint /// /// /// - private static IDictionary> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) { - Dictionary> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach(var td in tensor_trackables) { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; var trackable = td.object_to_save; - IDictionary tensor_dict; + IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); @@ -147,12 +147,12 @@ namespace Tensorflow.Checkpoint return serialized_tensors; } - private static IDictionary get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { var trackable = trackable_data.object_to_save; // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. - IDictionary ret_tensor_dict; + IDictionary>> ret_tensor_dict; if (call_with_mapped_captures) { throw new NotImplementedException(); @@ -162,8 +162,8 @@ namespace Tensorflow.Checkpoint ret_tensor_dict = trackable.serialize_to_tensors(); } - // TODO: revise the types and complete it - Dictionary tensor_dict = new(); + // TODO: deal with the type `SaveSpce` (currently it will never be it). + Dictionary>> tensor_dict = new(); foreach(var pair in ret_tensor_dict) { var local_name = TrackableUtils.escape_local_name(pair.Key); @@ -172,9 +172,10 @@ namespace Tensorflow.Checkpoint tensor_dict[checkpoint_key] = maybe_tensor; - if(maybe_tensor is SaveSpec) + if(maybe_tensor.GetValueA() is SaveSpec) { - ((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + throw new NotImplementedException(); + //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; } if(object_graph_proto is not null) @@ -198,7 +199,7 @@ namespace Tensorflow.Checkpoint /// /// /// - private static (Trackable, IDictionary) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { Dictionary object_names = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 44fa5c5d..d8e251ec 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -174,25 +174,20 @@ public static class SaveUtilV1 { var name = factory_data.name; var key = factory_data.checkpoint_key; - var saveable_factory = factory_data.factory; - + var maybe_saveable = factory_data.factory; + // TODO: oneflow python has a process with callable `saveable_factory`. - var maybe_saveable = saveable_factory; - IEnumerable savesbles; - if (maybe_saveable is MySaveableObject) - { - savesbles = new List() { (MySaveableObject)maybe_saveable }; - } - else if (maybe_saveable is Tensor) + List saveables = new(); + if (maybe_saveable.DataType == typeof(MySaveableObject)) { - savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key); + saveables.Add(maybe_saveable.GetValueB()); } else { - throw new TypeError("Unexpected type."); + saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValueA() as Trackable, key)); } - foreach (var saveable in savesbles) + foreach (var saveable in saveables) { if (!saveable.name.Contains(key)) { @@ -204,11 +199,11 @@ public static class SaveUtilV1 // skip the process of PythonState - named_saveable_objects.AddRange(savesbles); + named_saveable_objects.AddRange(saveables); if(!fill_object_proto) continue; - - // skip the process of TrackableSaveable + + // skip the process of `TrackableSaveable` because of lack of APIs. object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor() { Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) }); @@ -221,7 +216,7 @@ public static class SaveUtilV1 public record class CheckpointFactoryData ( - object factory, + Maybe factory, string name, string checkpoint_key ); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs index 69bf76fd..f89dc10d 100644 --- a/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs +++ b/src/TensorFlowNET.Core/Checkpoint/TrackableView.cs @@ -2,6 +2,7 @@ using Tensorflow.Train; using System.Collections.Generic; using System.IO; +using Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow.Checkpoint; @@ -18,13 +19,13 @@ public class TrackableView _root_ref = obj; } - public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT) + public virtual IDictionary children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { obj._maybe_initialize_trackable(); Dictionary children = new(); // Note: in python the return type of `Trackable._trackable_children` is not fixed. // Therefore it uses `convert_to_trackable` to have an extra process. - foreach (var pair in obj._trackable_children(save_type)) + foreach (var pair in obj._trackable_children(save_type, cache)) { children[pair.Key] = pair.Value; } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 79109489..c9bee0db 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -33,7 +33,7 @@ public class TrackableSaver } - private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -125,7 +125,7 @@ public class TrackableSaver } Dictionary feed_dict = new(); - bool use_session = (!new Context().executing_eagerly() && !ops.inside_function()); + bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); if (checkpoint_number is not null) { file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}"; @@ -133,6 +133,7 @@ public class TrackableSaver Tensor file_prefix_tensor; Tensor object_graph_tensor; + string file_prefix_to_save; if (use_session) { if (_object_graph_feed_tensor is null) @@ -145,16 +146,18 @@ public class TrackableSaver object_graph_tensor = _object_graph_feed_tensor; file_prefix_tensor = _file_prefix_feed_tensor; feed_dict[file_prefix_tensor] = file_prefix; + file_prefix_to_save = ""; } else { // In python there is `with ops.device("/cpu:0")`. file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING); object_graph_tensor = null; + file_prefix_to_save = file_prefix; } var (save_path, new_feed_additions) = - save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options); + save_cached_when_graph_building(file_prefix_to_save, object_graph_tensor, options); if (new_feed_additions is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 759cbd66..c4a03985 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -6,9 +6,254 @@ using Tensorflow.Train; using static Tensorflow.ApiDef.Types; using static Tensorflow.CostGraphDef.Types; using static Tensorflow.OptimizerOptions.Types; +using static Tensorflow.Binding; +using System.Text.RegularExpressions; +using System.Linq; +using Tensorflow.Operations; +using Tensorflow.Training; +using Tensorflow.Graphs; namespace Tensorflow.Checkpoint { + /// + /// `FunctionHolder` is a series of containers to help dynamically call some dotnet functions. + /// Note that this API does not gurantee performance. Besides, it is not supposed to be exposed to users. + /// + public interface IFunctionHolder + { + int ArgCount { get; } + object DynamicInvoke(params object[] args); + } + internal record class FunctionHolder(Func Func): IFunctionHolder + { + public int ArgCount => 0; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 1; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 2; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + internal record class FunctionHolder(Func Func) : IFunctionHolder + { + public int ArgCount => 3; + public object DynamicInvoke(params object[] args) + { + return Func.DynamicInvoke(args); + } + } + public class Maybe + { + private TA? _valueA = default(TA); + private TB? _valueB = default(TB); + private Type _type; + private bool _assigned = false; + public Maybe(TA value) + { + _valueA = value; + _type= typeof(TA); + _assigned = true; + } + public Maybe(TB value) + { + _valueB = value; + _type = typeof(TB); + _assigned = true; + } + + public Type DataType => _type; + + public TA GetValueA() + { + if(!_assigned || DataType != typeof(TA)) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + return _valueA; + } + public TB GetValueB() + { + if (!_assigned || DataType != typeof(TB)) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + return _valueB; + } + public object GetValue() + { + if (!_assigned) + { + throw new TypeError("Cannot get the data because of wrong specified type."); + } + if(DataType == typeof(TA) && _valueA is not null) + { + return _valueA; + } + else if(DataType == typeof(TB) && _valueB is not null) + { + return _valueB; + } + else if(DataType == typeof(TA)) + { + return _valueA; + } + else + { + return _valueB; + } + } + + public static implicit operator Maybe(TA a) + { + return new Maybe(a); + } + public static implicit operator Maybe(TB b) + { + return new Maybe(b); + } + } + internal class SingleDeviceSaver + { + private IDictionary>> _tensor_slice_dict; + public SingleDeviceSaver(IDictionary>> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict; + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public SingleDeviceSaver(IDictionary> tensor_slice_dict) + { + _tensor_slice_dict = tensor_slice_dict.ToDictionary( + x => x.Key, x => x.Value.ToDictionary( + y => y.Key, y => new Maybe(y.Value)) + as IDictionary>); + } + public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensors = new(); + List slice_specs = new(); + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.DataType == typeof(SaveSpec)) + { + var spec = maybe_tensor.GetValueB(); + var tensor_value = spec.tensor; + if (tensor_value is not null) + { + tensor_names.Add(spec.name); + tensors.Add(tensor_value); + slice_specs.Add(spec.slice_spec); + } + } + else + { + var tensor = maybe_tensor.GetValueA(); + tensor_names.Add(checkpoint_key); + tensors.Add(tensor); + slice_specs.Add(slice_spec); + } + } + } + // TODO: specify the device. + return tf.io.save_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensors.ToArray()); + } + + public Operation? save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix, TF_DataType.TF_STRING), options); + + public IDictionary> restore(Tensor file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + List tensor_names = new(); + List tensor_dtypes = new(); + List slice_specs = new(); + + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice in tensor_slices) + { + var slice_spec = slice.Key; + var maybe_tensor = slice.Value; + // TODO: deal with other types. Currently only `SaveSpec` is allowed. + if(maybe_tensor.DataType == typeof(SaveSpec)) + { + var spec = maybe_tensor.GetValueB(); + tensor_dtypes.Add(spec.dtype); + slice_specs.Add(spec.slice_spec); + tensor_names.Add(spec.name); + } + else + { + var tensor = maybe_tensor.GetValueA(); + tensor_dtypes.Add(tensor.dtype); + slice_specs.Add(slice_spec); + tensor_names.Add(checkpoint_key); + } + } + } + + string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; + + // tf python has code `with ops.device(restore_device):` here. + tf.device(restore_device); // may be risky. + var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + + Dictionary> restored_tensor_dict = new(); + int idx = 0; + foreach(var pair in _tensor_slice_dict) + { + var checkpoint_key = pair.Key; + var tensor_slices = pair.Value; + foreach(var slice_spec in tensor_slices.Keys) + { + var restored_tensor = restored_tensors[idx++]; + if (!restored_tensor_dict.ContainsKey(checkpoint_key)) + { + restored_tensor_dict[checkpoint_key] = new Dictionary(); + } + restored_tensor_dict[checkpoint_key][slice_spec] = restored_tensor; + } + } + return restored_tensor_dict; + } + + public IDictionary> restore(string file_prefix, CheckpointOptions? options = null) => restore(tf.constant(file_prefix)); + } /// /// Saves checkpoints directly from multiple devices. /// Note that this is a low-level utility which stores Tensors in the keys @@ -17,20 +262,280 @@ namespace Tensorflow.Checkpoint /// public class MultiDeviceSaver { - public MultiDeviceSaver(IDictionary> serialized_tensors, + private Dictionary _single_device_savers; + private IDictionary _registered_savers; + private Dictionary<(string, string), IFunctionHolder> _keys_to_restore_fn; + private Dictionary> _restore_fn_to_keys; + /// + /// + /// + /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. + /// + /// + public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) { + _keys_to_restore_fn = new Dictionary<(string, string), IFunctionHolder>(); + _restore_fn_to_keys = new Dictionary>(); + Dictionary>> tensors_by_device= new(); + + foreach(var pair in serialized_tensors) + { + var obj = pair.Key; + var tensor_dict = pair.Value; + IFunctionHolder restore_fn; + if(obj is null) + { + restore_fn = new FunctionHolder(() => null); + } + else + { + restore_fn = null; + // TODO: implement obj._restore_from_tensors + } + + foreach(var item in tensor_dict) + { + var checkpoint_key = item.Key; + IDictionary spec_to_tensor; + if(item.Value.DataType != typeof(IDictionary)) + { + spec_to_tensor = new Dictionary(); + spec_to_tensor[""] = item.Value.GetValueA(); + } + else + { + spec_to_tensor = item.Value.GetValueB(); + } + + foreach(var spec in spec_to_tensor) + { + var slice_spec = spec.Key; + var tensor = spec.Value; + if(_keys_to_restore_fn.ContainsKey((checkpoint_key, slice_spec))) + { + throw new ValueError("Recieved multiple tensors with the same checkpoint key and " + + $"slice spec. This is invalid because one will overwrite the " + + $"other in the checkpoint. This indicates a bug in the Checkpoint key-generation."); + } + _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; + _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); + + // skip the process of device name because lack of API. + var host_device = tensor.Device; + var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>()); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + internal_dict[checkpoint_key] = new Dictionary(); + } + internal_dict[checkpoint_key][slice_spec] = tensor; + } + } + } + + _single_device_savers = tensors_by_device.ToDictionary(x => x.Key, x => new SingleDeviceSaver(x.Value)); + _registered_savers = new Dictionary(); + if(registered_savers is not null && registered_savers.Count > 0) + { + // TODO: complete the implementation. + throw new NotImplementedException(); + } } - public Operation? save(string file_prefix, CheckpointOptions? options= null) + public Operation save(string file_prefix, CheckpointOptions? options= null) { - throw new NotImplementedException(); + if(options is null) + { + options = new CheckpointOptions(); + } + + tf.device("CPU"); // may be risky. + // TODO: optimize the implementation with new APIs adding to `string_ops`. + string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; + var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + + Operation save_fn() + { + List saved_prefixes= new(); + foreach(var saver in _registered_savers) + { + // TODO: implementi it later. + throw new NotImplementedException(); + } + + int num_shards = _single_device_savers.Count; + List sharded_saves = new(); + var num_shards_tensor = constant_op.constant(num_shards, name: "num_shards"); + string? last_device = null; + int shard = 0; + foreach(var pair in _single_device_savers.OrderBy(x => x.Key)) + { + var device = pair.Key; + var saver = pair.Value; + last_device = device; + // skip the extra process of device name because of lack of API. + tf.device(device); + var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + saved_prefixes.Add(shard_prefix); + sharded_saves.Add(saver.save(shard_prefix, options)); + } + using (var controller = ops.control_dependencies(sharded_saves.ToArray())) + { + string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; + tf.device(merge_device); + return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); + } + } + + if(tf.Context.executing_eagerly() && _single_device_savers.Count > 1) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return save_fn(); + } } - public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) + public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); + + public IDictionary restore(string file_prefix, CheckpointOptions? options = null) + { + if(options is null) + { + options = new CheckpointOptions(); + } + + IDictionary restore_func() + { + Dictionary>>> restore_fn_inputs = new(); + Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); + Dictionary restore_ops = new(); + + foreach(var single_saver in _single_device_savers.OrderBy(x => x.Key)) + { + var device = single_saver.Key; + var saver = single_saver.Value; + tf.device(device); + var restored_tensor_dict = saver.restore(file_prefix, options); + + foreach(var pair in restored_tensor_dict) + { + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach(var item in slice_and_tensor) + { + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) + { + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = new Maybe>(dict); + } + else + { + internal_dict[checkpoint_key].GetValueB()[slice_spec] = tensor; + } + } + else + { + internal_dict[checkpoint_key] = new Maybe>(tensor); + } + restore_fn_input_count[restore_fn]--; + + if (restore_fn_input_count[restore_fn] == 0) + { + Dictionary>> restored_tensors = new(); + foreach(var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if(ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } + } + } + } + } + + foreach(var item in _registered_savers) + { + throw new NotImplementedException(); + } + return restore_ops; + } + + // TODO: complete the implementation. Currently skip it because of lack of API. + bool has_custom_device_saver = false; + + if (tf.Context.executing_eagerly() && (_single_device_savers.Count > 1 || has_custom_device_saver)) + { + // TODO: implement it. Currently `autograph` does not support the function with non parameter. + throw new NotImplementedException(); + } + else + { + return restore_func(); + } + } + + /// + /// Serializes to a SaverDef referencing the current graph. + /// + public SaverDef to_proto() + { + var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); + var save_tensor = _traced_save(filename_tensor); + var restore_op = _traced_restore(filename_tensor).op; + return new SaverDef() + { + FilenameTensorName = filename_tensor.name, + SaveTensorName = save_tensor.name, + RestoreOpName = restore_op.name, + Version = SaverDef.Types.CheckpointFormatVersion.V2 + }; + } + + [AutoGraph] + private Tensor _traced_save(Tensor file_prefix) + { + var save_op = save(file_prefix.StringData()[0]); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[]{ save_op })) + { + return array_ops.identity(file_prefix); + } + } + + [AutoGraph] + private Tensor _traced_restore(Tensor file_prefix) + { + var restore_op = restore(file_prefix.StringData()[0]); + tf.device("cpu:0"); + using (ops.control_dependencies(new object[] { restore_op })) + { + return array_ops.identity(file_prefix); + } + } + + private static Tensor registered_saver_filename(string filename, string saver_name) + { + return tf.constant($"{filename}-{saver_name}"); + } + private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) { - throw new NotImplementedException(); + return filename_tensor; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs new file mode 100644 index 00000000..ae8a1ab1 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/SavedModel/ISerializedAttributes.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using Tensorflow.Train; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public interface ISerializedAttributes + { + IDictionary Functions { get; } + + IDictionary CheckpointableObjects { get; } + + /// + /// Returns functions to attach to the root object during serialization. + /// + IDictionary FunctionsToSerialize { get; } + + /// + /// Returns objects to attach to the root object during serialization. + /// + IDictionary ObjectsToSerialize{get; } + + /// + /// Saves function dictionary, and validates dictionary values. + /// + /// + IDictionary set_and_validate_functions(IDictionary function_dict); + + /// + /// Saves objects to a dictionary, and validates the values. + /// + /// + IDictionary set_and_validate_objects(IDictionary object_dict); + } +} diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 5dd9784f..4d5a664e 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -24,7 +25,7 @@ namespace Tensorflow.Train } } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) { if(save_type != SaveType.SAVEDMODEL) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs index 82da2ee9..97162651 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs @@ -4,57 +4,130 @@ using Tensorflow.Train; using System.Collections.Generic; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; namespace Tensorflow; public class AugmentedGraphView: ObjectGraphView { - // private object _children_cache; - // private object _serialization_cache; + private Dictionary> _children_cache; + private Dictionary> _serialization_cache; private List _untraces_functions; + private Dictionary _wrapped_functions; public AugmentedGraphView(Trackable root): base(root) { - _untraces_functions = new(); + _children_cache= new Dictionary>(); + _serialization_cache = new Dictionary>(); + _untraces_functions = new List(); + _wrapped_functions = new Dictionary(); } - public void set_signature(object signature_map, object wrapped_functions) + public void set_signature(SignatureMap signature_map, IDictionary wrapped_functions) { - // TODO: cache list_children(Root); + var name = SignatureSerializationUtils.SIGNATURE_ATTRIBUTE_NAME; + if (!_children_cache.ContainsKey(Root)) + { + _children_cache[Root] = new Dictionary(); + } + _children_cache[Root][name] = signature_map; + _wrapped_functions = _wrapped_functions.Concat(wrapped_functions).ToDictionary(x => x.Key, x => x.Value); } - public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL) + public override List list_children(Trackable obj, SaveType save_type = SaveType.SAVEDMODEL, IDictionary>? serialization_cache = null) { - Dictionary children = new(); - foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL)) + if(serialization_cache is not null) + { + throw new ValueError("Serialization cache should not be passed to `AugmentedGraphView.list_children`, please either remove the parameter or use `ObjectGraphView.list_children`."); + } + + if (!_children_cache.ContainsKey(obj)) + { + Dictionary children = new Dictionary(); + _children_cache[obj] = children; + foreach (var pair in base.list_children(obj, SaveType.SAVEDMODEL, _serialization_cache)) + { + var name = pair.Name; + var child = pair.Refer; + if(child is ConcreteFunction) + { + child = maybe_uncache_variable_captures((ConcreteFunction)child); + } + children[name] = child; + } + + if (obj is Function && children.Count == 0) + { + _untraces_functions.Add(((Function)obj).Name); + } + } + + List res = new(); + foreach(var pair in _children_cache[obj]) { - var name = pair.Name; - var child = pair.Refer; - children[name] = child; + res.Add(new TrackableReference(pair.Key, pair.Value)); } - if (obj is Function && children.Count == 0) + return res; + } + + private ConcreteFunction maybe_uncache_variable_captures(ConcreteFunction concrete_function) + { + if (_wrapped_functions.ContainsKey(concrete_function)) { - _untraces_functions.Add(((Function)obj).Name); + return _wrapped_functions[concrete_function]; } + // skip the process here because of lack of feature. + // In the future, we may add an attribute which could specify if the variable is supposed to be cached. + //foreach(var capture in concrete_function.CapturedInputs) + //{ - return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList(); + //} + return concrete_function; } public override (List, Dictionary>) breadth_first_traversal() { - // TODO: implement it if needed. + Trackable get_merged_trackable(Trackable x) + { + // TODO: complete it with new definitions `Asset` and `TrackableConstant`. + return x; + } + var trackable_objects = base.breadth_first_traversal(); + + foreach(var obj in _children_cache.Keys) + { + // skip the deletion of cache (maybe do it later). + foreach(var pair in _children_cache[obj]) + { + _children_cache[obj][pair.Key] = get_merged_trackable(pair.Value); + } + } + return base.breadth_first_traversal(); } public List<(string, Trackable)> list_dependencies(Trackable obj) { - // TODO: deal with cache. - return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList(); + IDictionary children; + if (!_children_cache.ContainsKey(obj)) + { + children= new Dictionary(); + } + else + { + children= _children_cache[obj]; + } + List<(string, Trackable)> res = new(); + foreach(var pair in obj.deserialization_dependencies(children)) + { + res.Add((pair.Key, pair.Value)); + } + return res; } public Trackable get_child(Trackable obj, string name) { - throw new NotImplementedException(); + return _children_cache[obj][name]; } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6a241f0e..6700e277 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -141,16 +141,16 @@ public class SaveableView foreach (var node in _nodes) { var node_id = _node_ids[node]; - List deps = new(); + List deps = new List(); + dependency_map.Add(node_id, deps); // TODO: deal with captured tensor. - string node_path; foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node)) { if (!_node_ids.ContainsKey(dep)) { - node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); + var node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]); throw new ValueError( $"Found an untracked dependency. Object {node_path} depends on {dep}, " + $"but this dependency isn't listed as a child. Please track this child by " + diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index cc839952..f3f273b8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -24,7 +24,7 @@ public static partial class SavedModelUtils }.Select(x => (int)x); public static (IList, IDictionary>) save_and_return_nodes(Trackable obj, - string export_dir, IDictionary? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) + string export_dir, ConcreteFunction? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false) { if (options is null) { @@ -41,9 +41,9 @@ public static partial class SavedModelUtils if (!experimental_skip_checkpoint) { - Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir); + SavedModelUtils.get_or_create_variables_dir(export_dir); CheckpointOptions ckpt_options = new(options.experimental_io_device); - object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); } BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); @@ -67,7 +67,7 @@ public static partial class SavedModelUtils } var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); - File.WriteAllText(path, saved_model.ToString()); + File.WriteAllBytes(path, saved_model.ToByteArray()); if (options.save_debug_info) { @@ -81,7 +81,7 @@ public static partial class SavedModelUtils private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List, Dictionary>) _build_meta_graph(Trackable obj, - IDictionary? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) + ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) { if (ops.inside_function()) { @@ -95,9 +95,9 @@ public static partial class SavedModelUtils } AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); - if (signatures is not null) + if (signatures is null) { - throw new NotImplementedException(); + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); } // TODO: process of aignatures and wrapped_functions @@ -125,7 +125,7 @@ public static partial class SavedModelUtils } private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, - IDictionary signatures, IEnumerable namespace_whitelist, + ConcreteFunction signatures, IEnumerable namespace_whitelist, bool save_custom_gradients) { var resource_initializers = saveable_view.get_concrete_resource_initializers(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs index 21272941..0d34907f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs @@ -1,15 +1,84 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Functions; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; namespace Tensorflow; +public static class SignatureSerializationUtils +{ + internal static readonly string DEFAULT_SIGNATURE_ATTR = "_default_save_signature"; + internal static readonly string SIGNATURE_ATTRIBUTE_NAME = "signatures"; + internal static readonly int _NUM_DISPLAY_NORMALIZED_SIGNATURES = 5; + public static SignatureMap create_signature_map(IDictionary signatures) + { + var signature_map = new SignatureMap(); + foreach (var pair in signatures) + { + var name = pair.Key; + var func = pair.Value; + Debug.Assert(func is ConcreteFunction); + // TODO: assert the `func.structured_outputs` and arg_keywords. + signature_map._add_signature(name, (ConcreteFunction)func); + } + + return signature_map; + } + + public static ConcreteFunction find_function_to_export(AugmentedGraphView graph_view) + { + var children = graph_view.list_children(graph_view.Root); + List possible_signatures = new(); + foreach (var item in children) + { + var name = item.Name; + var child = item.Refer; + if(child is not (Function or ConcreteFunction)) + { + continue; + } + if(name == DEFAULT_SIGNATURE_ATTR) + { + Debug.Assert(child is ConcreteFunction); + return (ConcreteFunction)child; + } + ConcreteFunction concrete = get_signature(child); + if(concrete is not null && valid_signature(concrete)) + { + possible_signatures.Add(concrete); + } + } + + if(possible_signatures.Count == 1) + { + var signature = get_signature(possible_signatures[0]); + if(signature is not null && valid_signature(signature)) + { + return signature; + } + } + return null; + } + + private static ConcreteFunction get_signature(Trackable function) + { + // TODO: implement it. + return null; + } + + private static bool valid_signature(ConcreteFunction concreate_function) + { + // TODO: implement it. + return false; + } +} + public class SignatureMap: Trackable { - private Dictionary _signatures; - private Dictionary _concrete_signatures; + private Dictionary _signatures; public SignatureMap() { @@ -18,7 +87,7 @@ public class SignatureMap: Trackable public void _add_signature(string name, ConcreteFunction concrete_function) { - _concrete_signatures[name] = concrete_function; + _signatures[name] = concrete_function; } public void _add_signature(string name, Function concrete_function) @@ -26,33 +95,13 @@ public class SignatureMap: Trackable _signatures[name] = concrete_function; } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) { if (save_type != SaveType.SAVEDMODEL) { return new Dictionary(); } - Dictionary res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value); - foreach (var pair in _concrete_signatures) - { - res[pair.Key] = pair.Value; - } - - return res; - } - - public static SignatureMap create_signature_map(IDictionary signatures) - { - var signature_map = new SignatureMap(); - foreach (var pair in signatures) - { - var name = pair.Key; - var func = pair.Value; - // TODO: assert the arg_keywords - signature_map._add_signature(name, func); - } - - return signature_map; + return _signatures.TakeWhile(x => x.Value is Function or ConcreteFunction).ToDictionary(x => x.Key, x => x.Value); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 622eed3a..7066b366 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -16,18 +16,38 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; using Tensorflow.Train; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow { - public static class saveable_object_util + /// + /// A SaveableObject that defines `Trackable` checkpointing steps. + /// + public class TrackableSaveable : MySaveableObject { - public class TrackableSaveable: MySaveableObject + private string _prefix; + private IEnumerable _local_names; + private Trackable _trackable; + private bool _call_with_mapped_captures; + // TODO: revise the implementation. Currently the parameter of constructor of this class and its base class has conflict. + public TrackableSaveable(Trackable obj, IEnumerable specs, string name, IEnumerable local_names, + string prefix, bool call_with_mapped_captures = false) : base((object)obj as Tensor, specs.ToArray(), name) { - + _prefix = prefix; + _trackable = obj; + _local_names = local_names; + _call_with_mapped_captures = call_with_mapped_captures; } + + // TODO: complete this class. + } + public static class saveable_object_util + { /// /// Returns the variables and names that will be used for a Saver. /// @@ -57,7 +77,7 @@ namespace Tensorflow } /// - /// Create `SaveableObject`s from an operation. + /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// /// /// @@ -79,6 +99,74 @@ namespace Tensorflow } } + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(Trackable obj, string name) + { + // The `op` maybe `Variable` or `Trackable`. + if (obj is BaseResourceVariable) + { + var variable = obj as BaseResourceVariable; + if (variable.InGraphMode) + { + yield return new ResourceVariableSaveable(variable.GraphElement, "", name); + } + else + { + Debug.Assert(variable is ResourceVariable); + yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); + } + } + else + { + foreach(var pair in saveable_objects_from_trackable(obj)) + { + var attr = pair.Key; + var factory = pair.Value; + string full_name; + if(attr == Trackable.Constants.VARIABLE_VALUE_KEY) + { + full_name = name; + } + else + { + full_name = name + "_" + attr; + } + if(factory.DataType == typeof(ResourceVariable)) + { + var variable = factory.GetValueA(); + foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name)) + { + yield return op; + } + } + else + { + var variable = factory.GetValueB(); + foreach (var op in saveable_objects_for_op(variable, variable.name)) + { + yield return op; + } + } + } + } + } + + /// + /// Create `SaveableObject`s from an operation. + /// + /// + /// + /// + public static IEnumerable saveable_objects_for_op(MySaveableObject obj, string name) + { + yield return obj; + } + public static Dictionary op_list_to_dict(IVariableV1[] op_list, bool convert_variable_to_tensor = true) { op_list = op_list.OrderBy(x => x.Name).ToArray(); @@ -127,16 +215,55 @@ namespace Tensorflow return names_to_saveables; } - public static IDictionary saveable_objects_from_trackable(Trackable obj) + public static IDictionary> saveable_objects_from_trackable(Trackable obj) { - // TODO: complete the implementation. - return obj.gather_saveables_for_checkpoint(); + // skip the process of type `PythonState` + + if (trackable_has_serialize_to_tensor(obj)) + { + var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME; + // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. + var tensor_dict = obj.serialize_to_tensors(); + + List specs = new(); + List local_names = new(); + string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; + foreach(var pair in tensor_dict) + { + var tensor_name = pair.Key; + var maybe_tensor = pair.Value; + local_names.Add(tensor_name); + string spec_name = name + TrackableUtils.escape_local_name(tensor_name); + + IDictionary internal_dict; + if(maybe_tensor.DataType == typeof(Tensor)) + { + internal_dict= new Dictionary(); + internal_dict[""] = maybe_tensor.GetValueA(); + } + else + { + internal_dict = maybe_tensor.GetValueB(); + } + + foreach(var item in internal_dict) + { + specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); + } + } + Dictionary> res = new(); + res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); + return res; + } + else + { + return obj.gather_saveables_for_checkpoint(); + } } public static bool trackable_has_serialize_to_tensor(Trackable obj) { - // TODO: implement it. - return false; + return obj.GetType().GetMethod("serialize_to_tensors").DeclaringType != typeof(Trackable); } internal static string convert_to_string(string x) @@ -158,27 +285,28 @@ namespace Tensorflow public Trackable Obj => _obj; public IList mySaveables=> _saveables; - public override IDictionary serialize_to_tensors() + public override IDictionary>> serialize_to_tensors() { - return saveable_objects_to_tensor_dict(_saveables); + return saveable_object_to_tensor_dict(_saveables); } /// /// Converts a list of SaveableObjects to a tensor dictionary. /// /// - public static Dictionary saveable_objects_to_tensor_dict(IList saveables) + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) { - Dictionary tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { foreach(var spec in saveable.specs) { + // skip the check that if `spec` is callable. var name = saveable_object_util.convert_to_string(spec.name); var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { - throw new NotImplementedException(); + tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; } else { diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 2646fb8d..a677044a 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -16,7 +16,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Checkpoint; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; using Tensorflow.Training; using static Tensorflow.Binding; @@ -39,8 +42,8 @@ namespace Tensorflow.Train protected IList _unconditional_checkpoint_dependencies; - protected IDictionary _self_saveable_object_factories = - new Dictionary(); + protected IDictionary> _self_saveable_object_factories = + new Dictionary>(); private bool _manual_tracking = true; private static Trackable _none = new Function(); @@ -94,9 +97,13 @@ namespace Tensorflow.Train // assign again. It will add this variable to our dependencies, and if there // is a non-trivial restoration queued, it will handle that. This also // handles slot variables. - if (!args.Overwrite || new_variable is RefVariable) - return _track_checkpointable(new_variable, name: args.Name, - overwrite: args.Overwrite); + if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) + { + var temp = new_variable as Trackable; + var res = _track_trackable(temp, args.Name, args.Overwrite); + Debug.Assert(res is IVariableV1); + return res as IVariableV1; + } else return new_variable; } @@ -122,13 +129,16 @@ namespace Tensorflow.Train /// public void _maybe_initialize_trackable() { + if(_unconditional_checkpoint_dependencies is not null) + { + return; + } _self_update_uid = -1; _unconditional_checkpoint_dependencies = new List(); _unconditional_dependency_names = new Dictionary(); } - // TODO: cache - public virtual IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public virtual IDictionary _trackable_children(SaveType save_type, IDictionary>? cache) { _maybe_initialize_trackable(); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); @@ -139,8 +149,8 @@ namespace Tensorflow.Train _maybe_initialize_trackable(); if (!_manual_tracking) return trackable; var new_reference = new TrackableReference(name, trackable); - var current_object = _lookupup_dependency(name); - + var current_object = _lookup_dependency(name); + if(current_object is null) { _unconditional_checkpoint_dependencies.Add(new_reference); @@ -170,7 +180,7 @@ namespace Tensorflow.Train // TODO: complete the implementation. } - public virtual Trackable? _lookupup_dependency(string name) + public virtual Trackable? _lookup_dependency(string name) { if (_unconditional_dependency_names.TryGetValue(name, out var dependency)) return dependency; else return null; @@ -199,8 +209,8 @@ namespace Tensorflow.Train return (new Dictionary(), new Dictionary()); } - public virtual List export_to_saved_model_graph(IDictionary? object_map = null, - IDictionary? tensor_map = null, SaveOptions? options = null) + public virtual List export_to_saved_model_graph(IDictionary object_map, + IDictionary tensor_map, SaveOptions? options = null) { var (self_object_map, self_tensor_map) = map_resources(options); foreach (var pair in self_object_map) @@ -215,9 +225,17 @@ namespace Tensorflow.Train return self_tensor_map.Keys.ToList(); } - public virtual IDictionary gather_saveables_for_checkpoint() + public virtual IDictionary> gather_saveables_for_checkpoint() { - return _self_saveable_object_factories; + if (saveable_object_util.trackable_has_serialize_to_tensor(this)) + { + // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). + throw new NotImplementedException(); + } + else + { + return _self_saveable_object_factories; + } } /// @@ -229,7 +247,7 @@ namespace Tensorflow.Train /// /// /// - public virtual IDictionary serialize_to_tensors() + public virtual IDictionary>> serialize_to_tensors() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Training/TrackableUtils.cs b/src/TensorFlowNET.Core/Training/TrackableUtils.cs index 99020702..390d95c7 100644 --- a/src/TensorFlowNET.Core/Training/TrackableUtils.cs +++ b/src/TensorFlowNET.Core/Training/TrackableUtils.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using Tensorflow.Exceptions; using Tensorflow.Train; @@ -22,7 +23,7 @@ public static class TrackableUtils private static string _ESCAPE_CHAR = "."; private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"; private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"; - private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; + internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; public static string object_path_to_string(IEnumerable node_path_arr) { return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name))); @@ -145,4 +146,27 @@ public static class TrackableUtils return $"root.{string.Join(".", paths.Select(x => x.Name))}"; } } + + /// + /// Returns the substring after the "/.ATTIBUTES/" in the checkpoint key. + /// + /// + /// + /// + public static string extract_local_name(string key, string? prefix = null) + { + if(prefix is null) + { + prefix = ""; + } + var search_key = OBJECT_ATTRIBUTES_NAME + "/" + prefix; + try + { + return key.Substring(key.IndexOf(search_key) + search_key.Length); + } + catch(ArgumentOutOfRangeException) + { + return key; + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index d4e9c401..6e3336c9 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -9,6 +9,7 @@ using System.Runtime.InteropServices; using System.Text; using Tensorflow.Functions; using Tensorflow.Keras; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; using Tensorflow.Train; using static Tensorflow.ApiDef.Types; @@ -243,7 +244,7 @@ namespace Tensorflow.Training _last_wrapped_list_snapshot = new List(_storage); } - public override IDictionary _trackable_children(SaveType save_type, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) { check_external_modification(); if (_non_append_mutation_value) diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index f217a052..756024db 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -4,6 +4,8 @@ using Tensorflow.Eager; using Tensorflow.Variables; using Tensorflow.Train; using static Tensorflow.Binding; +using System.Collections.Generic; +using Tensorflow.ModelSaving; namespace Tensorflow { @@ -20,6 +22,7 @@ namespace Tensorflow public string UniqueId => _unique_id; protected bool _in_graph_mode; + internal bool InGraphMode => _in_graph_mode; protected bool _trainable; public bool Trainable => _trainable; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index b31960c7..6093f810 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -17,7 +17,9 @@ using Google.Protobuf; using System; using System.Collections.Generic; +using Tensorflow.Checkpoint; using Tensorflow.NumPy; +using Tensorflow.Train; using static Tensorflow.Binding; namespace Tensorflow @@ -235,5 +237,12 @@ namespace Tensorflow { return _graph_element.eval(session); } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 61a8956a..7c8812ad 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; using Tensorflow.Train; using static Tensorflow.Binding; @@ -351,7 +352,7 @@ namespace Tensorflow.Keras.Engine return output_tensors; } - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) .ToDictionary(x => x.Key, x => x.Value); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index 1675fba1..ffb6f71b 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; @@ -9,16 +10,16 @@ public abstract partial class Layer { public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); - public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; + public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { IDictionary children; if (save_type == SaveType.SAVEDMODEL) { - // TODO: deal with cache. + Debug.Assert(cache is not null); children = TrackableSavedModelSaver.trackable_children(cache); } else diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index b9b01dae..a2f92ba8 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -88,9 +88,29 @@ namespace Tensorflow.Keras.Engine ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; - public Tensor[] input => inboundNodes[0].input_tensors; + public Tensor[] input + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].input_tensors; + } + return null; + } + } public Dictionary> NodesByDepth { get; set; } - public Shape OutputShape => inboundNodes[0].Outputs.shape; + public Shape OutputShape + { + get + { + if(inboundNodes is not null && inboundNodes.Count > 0) + { + return inboundNodes[0].Outputs.shape; + } + return null; + } + } protected List _self_tracked_trackables; public Layer(LayerArgs args) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 59f74cd2..59b205e4 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine bool include_optimizer = true, string save_format = "tf", SaveOptions? options = null, - IDictionary? signatures = null, + ConcreteFunction? signatures = null, bool save_traces = true) { if (save_format != "pb") diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 41f7788e..dfe5b05f 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Train; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Engine } } - public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary? cache = null) + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) { if(save_type == SaveType.SAVEDMODEL) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 76453ca0..6a6e418c 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public partial class KerasSavedModelUtils { - public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary? signatures, + public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures, SaveOptions? options, bool save_traces = true) { if (!overwrite && File.Exists(filepath)) @@ -54,12 +54,7 @@ public partial class KerasSavedModelUtils } var metadata = generate_keras_metadata(saved_nodes, node_paths); - using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate, - FileAccess.Write)) - { - var writer = new StreamWriter(f); - writer.Write(metadata.ToString()); - } + File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); if (!include_optimizer) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index 7168e25b..fc7eab3a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -19,7 +19,7 @@ public partial class KerasSavedModelUtils /// /// /// - public static IDictionary wrap_layer_objects(Layer layer, IDictionary serialization_cache) + public static IDictionary wrap_layer_objects(Layer layer, IDictionary> serialization_cache) { // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. @@ -55,7 +55,7 @@ public partial class KerasSavedModelUtils /// /// /// - public static IDictionary wrap_layer_functions(Layer layer, IDictionary serialization_cache) + public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) { // TODO: deal with type `RevivedLayer` and `Sequential`. diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index a399eaf1..0235f87b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -18,12 +18,12 @@ public abstract class SavedModelSaver public abstract string TrackingMetadata { get; } public abstract IDictionary objects_to_serialize( - IDictionary serialization_cache); + IDictionary> serialization_cache); public abstract IDictionary functions_to_serialize( - IDictionary serialization_cache); + IDictionary> serialization_cache); - public IDictionary trackable_children(IDictionary? serialization_cache) + public IDictionary trackable_children(IDictionary> serialization_cache) { if (!KerasSavedModelUtils.ShouldHaveTraces) { @@ -31,7 +31,6 @@ public abstract class SavedModelSaver } var children = objects_to_serialize(serialization_cache); - return children.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value)) .ToDictionary(x => x.Key, x => x.Value); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 7a0ddd21..b092b595 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -19,12 +19,12 @@ public class LayerSavedModelSaver: SavedModelSaver get => Constants.LAYER_IDENTIFIER; } - public override IDictionary objects_to_serialize(IDictionary serialization_cache) + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) { return get_serialized_attributes(serialization_cache).ObjectsToSerialize; } - public override IDictionary functions_to_serialize(IDictionary serialization_cache) + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) { return get_serialized_attributes(serialization_cache).FunctionsToSerialize; } @@ -33,11 +33,21 @@ public class LayerSavedModelSaver: SavedModelSaver /// Generates or retrieves serialized attributes from cache. /// /// - protected SerializedAttributes get_serialized_attributes(IDictionary serialization_cache) + protected ISerializedAttributes get_serialized_attributes(IDictionary> serialization_cache) { // TODO: deal with cache. + IDictionary keras_cache; + if(serialization_cache is not null && serialization_cache.ContainsKey(Constants.KERAS_CACHE_KEY)) + { + keras_cache = serialization_cache[Constants.KERAS_CACHE_KEY]; + } + else + { + serialization_cache![Constants.KERAS_CACHE_KEY] = keras_cache = new Dictionary(); + } + if (keras_cache.ContainsKey(_obj)) return keras_cache[_obj]; - var serialized_attr = SerializedAttributes.Create(_obj); + var serialized_attr = keras_cache[_obj] = SerializedAttributes.Create(_obj); // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`. if (KerasSavedModelUtils.should_skip_serialization(_obj)) @@ -56,7 +66,7 @@ public class LayerSavedModelSaver: SavedModelSaver /// Returns dictionary of serialized attributes. /// /// - private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary serialization_cache) + private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) { var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); @@ -75,7 +85,7 @@ public class LayerSavedModelSaver: SavedModelSaver metadata["trainable"] = _obj.Trainable; // metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["dtype"] = policy.serialize(_obj._dtype_policy) - metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape); + metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; @@ -92,8 +102,10 @@ public class LayerSavedModelSaver: SavedModelSaver } } - public static LayerConfig get_serialized(Layer obj) + public static IDictionary get_serialized(Layer obj) { - return generic_utils.serialize_keras_object(obj); + // TODO: complete the implmentation (need to revise `get_config`). + return new Dictionary(); + //return generic_utils.serialize_keras_object(obj); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index 804ea1a9..ac194c00 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Saving.SavedModel /// /// Class that tracks and validates all serialization attributes. /// - public abstract class SerializedAttributes + public abstract class SerializedAttributes: ISerializedAttributes { protected IDictionary _object_dict; protected IDictionary _function_dict; diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs index 9d1b3088..0f34ff10 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs @@ -50,11 +50,11 @@ public class SaveTest { TrainDir = "mnist", OneHot = false, - ValidationSize = 50000, + ValidationSize = 0, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - model.save("", save_format:"pb"); + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); } } \ No newline at end of file