using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;
namespace Tensorflow.Checkpoint;
/// <summary>
/// Saves and restores a `Trackable` object and its dependencies.
/// </summary>
/// <remarks>
/// Only the save path is implemented in this class. The serialized object graph
/// is attached to the tensors being written, and the most recently built save
/// operation is remembered together with the object graph it was built for.
/// NOTE(review): several declarations below (`_object_map`, the tuple return types,
/// `feed_dict`) appear without generic type arguments in this copy of the file; they
/// are reproduced verbatim - confirm against the original source.
/// </remarks>
public class TrackableSaver
{
    // View over the object graph that is serialized on every save.
    private readonly ObjectGraphView _graph_view;
    // Save operation built by the last call to `save_cached_when_graph_building`.
    private Tensor _cached_save_operation;
    // Object graph that `_cached_save_operation` was built for.
    private TrackableObjectGraph _last_save_object_graph;
    // Feed tensors created lazily by `save` when a session is used.
    private Tensor? _object_graph_feed_tensor = null;
    private Tensor? _file_prefix_feed_tensor = null;
    // Never assigned in this class; forwarded as-is to `SaveUtil.serialize_graph_view`.
    private Dictionary? _object_map = null;
    private object? _cache = null;

    /// <summary>
    /// Creates a saver for the objects reachable through <paramref name="graph_view"/>.
    /// </summary>
    /// <param name="graph_view">The object graph view to serialize when saving.</param>
    public TrackableSaver(ObjectGraphView graph_view)
    {
        _graph_view = graph_view;
        // TODO: cache when not executing eagerly.
        // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
        // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
    }

    /// <summary>
    /// Serializes the graph view and attaches the object graph proto to the result.
    /// </summary>
    /// <param name="object_graph_tensor">
    /// Tensor to feed the serialized object graph through. When null, a string
    /// constant holding the serialized graph is created instead.
    /// </param>
    /// <returns>
    /// The serialized tensors, the feed additions, the registered savers and the
    /// object graph proto, in that order.
    /// </returns>
    private (IDictionary>, IDictionary, IDictionary>, TrackableObjectGraph)
        gather_serialized_tensors(Tensor? object_graph_tensor = null)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) =
            SaveUtil.serialize_graph_view(_graph_view, _object_map, cache: _cache);
        // TODO: cache.
        // NOTE(review): `ToString()` on a protobuf message is presumably its text/JSON
        // form rather than the binary wire format - confirm this is what readers expect.
        if (object_graph_tensor is null)
        {
            // tensorflow python: `with ops.device("/cpu:0"):`
            object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
        }
        else
        {
            feed_additions[object_graph_tensor] = graph_proto.ToString();
        }

        // The object graph is stored under the `Trackable.None` entry, and only when
        // that entry already exists.
        // NOTE(review): if the entry is missing the object graph is not stored at all;
        // the Python implementation appears to create the entry - confirm.
        if (serialized_tensors.TryGetValue(Trackable.None, out var root_tensors))
        {
            Debug.Assert(!root_tensors.ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
            root_tensors[Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
        }
        return (serialized_tensors, feed_additions, registered_savers, graph_proto);
    }

    /// <summary>
    /// Builds the save operation for a tensor file prefix, or reuses the cached one.
    /// </summary>
    /// <param name="file_prefix">Tensor holding the checkpoint file prefix.</param>
    /// <param name="object_graph_tensor">Optional feed tensor for the object graph.</param>
    /// <param name="options">Checkpoint options passed on to the saver.</param>
    /// <returns>The save operation's output tensor and the feed additions.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when <c>options.experimental_enable_async_checkpoint</c> is set.
    /// </exception>
    private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor? object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);
        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException("Asynchronous checkpointing (experimental_enable_async_checkpoint) is not supported yet.");
        }

        // Rebuild the save op when the object graph changed; it is always rebuilt
        // when executing eagerly or inside a function.
        if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
        {
            var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
            var save_op = saver.save(file_prefix, options);
            // tensorflow python: `with ops.device("/cpu:0"):`
            // The identity is created under a control dependency on the save op.
            using (ops.control_dependencies(new object[] { save_op }))
            {
                _cached_save_operation = array_ops.identity(file_prefix);
            }
            _last_save_object_graph = graph_proto;
        }
        return (_cached_save_operation, feed_additions);
    }

    /// <summary>
    /// Builds the save operation for a string file prefix, or reuses the cached one.
    /// </summary>
    /// <param name="file_prefix">Checkpoint file prefix.</param>
    /// <param name="object_graph_tensor">Optional feed tensor for the object graph.</param>
    /// <param name="options">Checkpoint options passed on to the saver.</param>
    /// <returns>The save operation's output tensor and the feed additions.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when <c>options.experimental_enable_async_checkpoint</c> is set.
    /// </exception>
    private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor? object_graph_tensor, CheckpointOptions options)
    {
        var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);
        if (options.experimental_enable_async_checkpoint)
        {
            throw new NotImplementedException("Asynchronous checkpointing (experimental_enable_async_checkpoint) is not supported yet.");
        }

        // Same caching rule as the tensor-prefix overload.
        if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
        {
            var saver = new MultiDeviceSaver(serialized_tensors, registered_savers);
            var save_op = saver.save(file_prefix, options);
            // tensorflow python: `with ops.device("/cpu:0"):`
            using (ops.control_dependencies(new object[] { save_op }))
            {
                _cached_save_operation = array_ops.identity(tf.constant(file_prefix));
            }
            _last_save_object_graph = graph_proto;
        }
        return (_cached_save_operation, feed_additions);
    }

    // TODO: parameter write_done_callback
    /// <summary>
    /// Saves a checkpoint of the tracked objects under <paramref name="file_prefix"/>.
    /// </summary>
    /// <param name="file_prefix">Prefix of the checkpoint files to write.</param>
    /// <param name="checkpoint_number">
    /// When given, appended to the prefix as <c>"{file_prefix}-{checkpoint_number}"</c>.
    /// </param>
    /// <param name="session">
    /// Session used to run the save operation outside eager mode and outside a
    /// function. Ignored otherwise. A new session is created when none is given.
    /// </param>
    /// <param name="options">Checkpoint options; a default instance is used when null.</param>
    /// <returns>
    /// The result of running the save operation in the session, or the save
    /// operation's output tensor when no session is used.
    /// </returns>
    public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
        CheckpointOptions? options = null)
    {
        options ??= new CheckpointOptions();
        Dictionary feed_dict = new();
        // Query the shared `tf.Context`, as `save_cached_when_graph_building` does.
        // A freshly constructed `Context` carries its own state and does not reflect
        // the mode the program is actually running in.
        bool use_session = !tf.Context.executing_eagerly() && !ops.inside_function();
        if (checkpoint_number is not null)
        {
            file_prefix = $"{file_prefix}-{checkpoint_number}";
        }

        Tensor file_prefix_tensor;
        Tensor? object_graph_tensor;
        if (use_session)
        {
            // The two feed tensors are created together on first use and reused afterwards.
            if (_object_graph_feed_tensor is null)
            {
                // In python there is `with ops.device("/cpu:0")`.
                _object_graph_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
                _file_prefix_feed_tensor = constant_op.constant("", TF_DataType.TF_STRING);
            }
            object_graph_tensor = _object_graph_feed_tensor;
            file_prefix_tensor = _file_prefix_feed_tensor;
            feed_dict[file_prefix_tensor] = file_prefix;
        }
        else
        {
            // In python there is `with ops.device("/cpu:0")`.
            file_prefix_tensor = ops.convert_to_tensor(file_prefix, TF_DataType.TF_STRING);
            object_graph_tensor = null;
        }

        var (save_path, new_feed_additions) =
            save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);
        if (new_feed_additions is not null)
        {
            foreach (var pair in new_feed_additions)
            {
                feed_dict.Add(pair.Key, pair.Value);
            }
        }

        if (!use_session)
        {
            // Any session passed by the caller is ignored on this path.
            return save_path;
        }

        // In python it uses `get_session`. Because a session is always created here,
        // there is no "missing default session" case to report.
        // NOTE(review): a session created here is never disposed - confirm ownership.
        session ??= new Session();
        var feeds = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
        return session.run(save_path, feeds);
    }
}