Browse Source

Init the serialization of keras pb model.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
bb8168b5ca
49 changed files with 2347 additions and 13 deletions
  1. +22
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.cs
  2. +150
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
  4. +63
    -0
      src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
  5. +229
    -0
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  6. +109
    -0
      src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs
  7. +75
    -0
      src/TensorFlowNET.Core/Checkpoint/TrackableView.cs
  8. +14
    -0
      src/TensorFlowNET.Core/Exceptions/AssertionError.cs
  9. +62
    -1
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  10. +2
    -1
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  11. +9
    -2
      src/TensorFlowNET.Core/Functions/Function.cs
  12. +7
    -1
      src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
  13. +6
    -0
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  14. +9
    -1
      src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs
  15. +6
    -0
      src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs
  16. +15
    -0
      src/TensorFlowNET.Core/Training/AutoTrackable.cs
  17. +6
    -1
      src/TensorFlowNET.Core/Training/Optimizer.cs
  18. +14
    -0
      src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
  19. +11
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs
  20. +60
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs
  21. +33
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs
  22. +17
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs
  23. +9
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs
  24. +299
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  25. +10
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs
  26. +22
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs
  27. +256
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  28. +58
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs
  29. +52
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs
  30. +18
    -1
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  31. +78
    -1
      src/TensorFlowNET.Core/Training/Trackable.cs
  32. +148
    -0
      src/TensorFlowNET.Core/Training/TrackableUtils.cs
  33. +1
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  34. +18
    -0
      src/TensorFlowNET.Core/ops.cs
  35. +31
    -0
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  36. +3
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  37. +13
    -2
      src/TensorFlowNET.Keras/Engine/Model.Save.cs
  38. +6
    -0
      src/TensorFlowNET.Keras/Engine/Model.cs
  39. +12
    -0
      src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs
  40. +7
    -0
      src/TensorFlowNET.Keras/Protobuf/Versions.cs
  41. +41
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs
  42. +11
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs
  43. +115
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  44. +19
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
  45. +40
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
  46. +62
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  47. +33
    -0
      src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs
  48. +60
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveTest.cs
  49. +1
    -1
      test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj

+ 22
- 0
src/TensorFlowNET.Core/APIs/tf.compat.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Text;

namespace Tensorflow namespace Tensorflow
{ {
public partial class tensorflow public partial class tensorflow
@@ -23,6 +25,26 @@ namespace Tensorflow
public class CompatApi public class CompatApi
{ {
public CompatV1Api v1 { get; } = new CompatV1Api(); public CompatV1Api v1 { get; } = new CompatV1Api();

internal string as_text(string bytes_or_text, Encoding? encoding = null)
{
if(encoding is null) encoding = Encoding.UTF8;
return bytes_or_text;
}
internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
{
if(encoding is null) encoding = Encoding.UTF8;
return encoding.GetString(bytes_or_text);
}
internal string as_str(string bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}
internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}
} }


public bool executing_eagerly() public bool executing_eagerly()


+ 150
- 0
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -0,0 +1,150 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint;

public static class CheckPointUtils
{
private static string _ESCAPE_CHAR = ".";
public static (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>, Dictionary<Trackable, int>,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
Dictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new();
foreach (var pair in node_paths)
{
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
}

Dictionary<Trackable, int> node_ids = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
node_ids[trackable_objects[i]] = i;
}

var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
}

public static
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
{
var non_slot_objects = trackable_objects.ToList();
Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
slot_variables = new();
foreach (var trackable in non_slot_objects)
{
if (trackable is not Optimizer)
{
continue;
}

var optim = (Optimizer)trackable;
var slot_names = optim.get_slot_names();
foreach (var slot_name in slot_names)
{
for (int original_variable_node_id = 0;
original_variable_node_id < non_slot_objects.Count;
original_variable_node_id++)
{
var original_variable = non_slot_objects[original_variable_node_id];
IVariableV1 slot_variable;
if (original_variable is not IVariableV1)
{
slot_variable = null;
}
slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name);
if(slot_variable is null) continue;

// There're some problems about the inherits of `Variable` and `Trackable`.
throw new NotImplementedException();
}
}
}

return slot_variables;
}

public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
{
if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
{
return trackable;
}
else
{
return possible_res;
}
}

public static string get_full_name(Trackable var)
{
// TODO: This state is not correct, the whole framework need to be updated in the future.
if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var)))
{
return "";
}
// skip the check of attribute `_save_slice_info` .
// TODO: Need to be revised!!!
return ((ResourceVariable)(object)var).Name;
}

public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
{
HashSet<int> checkpointed_trackables = new();
Dictionary<int, HashSet<int>> parents = new();
for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
{
var object_proto = object_graph_proto.Nodes[i];
// skip the process of registered saver.
if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
{
checkpointed_trackables.Add(i);
}

foreach (var child_proto in object_proto.Children)
{
var child = child_proto.NodeId;
if (!parents.ContainsKey(child))
{
parents[child] = new HashSet<int>();
}

parents[child].Add(i);
}
}

Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable());
while (to_visit.Count > 0)
{
var trackable = to_visit.Dequeue();
if (!parents.ContainsKey(trackable)) continue;
var current_parents = parents[trackable];
foreach (var parent in current_parents)
{
checkpointed_trackables.Add(parent);
if (parents.ContainsKey(parent))
{
to_visit.Enqueue(parent);
}
}
parents.Remove(trackable);
}
// TODO: Complete it after supporting checkpoint.
// for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
// {
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}
}

+ 5
- 0
src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs View File

@@ -0,0 +1,5 @@
namespace Tensorflow.Checkpoint;

public record class CheckpointOptions(
string experimental_io_device = null,
bool experimental_enable_async_checkpoint = false);

+ 63
- 0
src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs View File

@@ -0,0 +1,63 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Serilog.Debugging;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;

public class ObjectGraphView: TrackableView, ICloneable
{
protected IEnumerable<TrackableReference>? _attached_dependencies;
// TODO: attached_dependencies
public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root)
{
_attached_dependencies = attached_dependencies;
}

public object Clone()
{
// TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__
return new ObjectGraphView(Root, _attached_dependencies);
}

public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
{
List<TrackableReference> res = base.children(obj, save_type)
.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
// Check the reference, not value.
if (obj == Root && _attached_dependencies is not null)
{
res.AddRange(_attached_dependencies);
}

return res;
}
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
{
return list_children(obj, save_type).ToDictionary(x => x.Name, x => x.Refer);
}
public IEnumerable<TrackableReference>? AttachedDependencies
{
get => _attached_dependencies;
}

public virtual (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{
return base._descendants_with_paths();
}

// TODO: complete the implementation
public void serialize_object_graph(object? saveables_cache = null)
{
throw new NotImplementedException();
}
// TODO: complete the implementation
public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
{
throw new NotImplementedException();
}
}

+ 229
- 0
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -0,0 +1,229 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;

namespace Tensorflow.Checkpoint;

public static class SaveUtilV1
{
public static (Dictionary<Trackable, IEnumerable<CheckpointFactoryData>>, object?) get_checkpoint_factories_and_keys(IDictionary<Trackable, string> object_names,
IDictionary<Trackable, Trackable>? object_map = null)
{
// According to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/registration/README.md,
// till now only internal registrations are allowed. So, we won't return a saver in this function.
// The implementation of this function should be updated if tensorflow update it.
Dictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map = new();
foreach (var pair in object_names)
{
var trackable = pair.Key;
var object_name = pair.Value;
var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
// skip the registration process.

List<CheckpointFactoryData> current_list = new();
foreach (var name_and_factory in saveable_object_util.saveable_objects_from_trackable(object_to_save))
{
// treat name as key_suffix.
var name = name_and_factory.Key;
var checkpoint_key = TrackableUtils.checkpoint_key(object_name, name);
current_list.Add(new CheckpointFactoryData(name_and_factory.Value, name, checkpoint_key));
}

checkpoint_factory_map[trackable] = current_list;
}

return (checkpoint_factory_map, null);
}

public static (List<MySaveableObject>, object?) frozen_saveables_and_savers(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
object? saveables_cache = null)
{

Graph target_context;
if (to_graph is not null)
{
using (to_graph.as_default())
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
// tensorflow python: `with ops.device("/cpu:0")`
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}
}
else
{
using (new ops.NullContextManager())
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
// tensorflow python: `with ops.device("/cpu:0")`
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}
}
}

public static (List<MySaveableObject>, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new();
foreach (var pair in node_paths)
{
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
}

Dictionary<Trackable, int> node_ids = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
node_ids[trackable_objects[i]] = i;
}

var slot_variables = CheckPointUtils.serialize_slot_variables(trackable_objects, node_ids, object_names);
var object_graph_proto = fill_object_graph_proto(graph_view, trackable_objects, node_ids, slot_variables);
var (named_saveable_objects, feed_additions, registered_savers) = add_attributes_to_object_graph(
trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures,
saveables_cache);
CheckPointUtils.add_checkpoint_values_check(object_graph_proto);
return (named_saveable_objects, object_graph_proto, feed_additions, registered_savers);
}

private static TrackableObjectGraph fill_object_graph_proto(ObjectGraphView graph_view, IList<Trackable> trackable_objects,
IDictionary<Trackable, int> node_ids,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
slot_variables)
{
TrackableObjectGraph object_graph_proto = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
var trackable = trackable_objects[i];
Debug.Assert(node_ids[trackable] == i);
TrackableObjectGraph.Types.TrackableObject object_proto;
if (slot_variables.TryGetValue(trackable, out var slots))
{
object_proto = new TrackableObjectGraph.Types.TrackableObject(slots);
}
else
{
object_proto = new TrackableObjectGraph.Types.TrackableObject();
}
object_graph_proto.Nodes.Add(object_proto);
foreach (var child in graph_view.list_children(trackable))
{
object_proto.Children.Add(new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference()
{ NodeId = node_ids[child.Refer], LocalName = child.Name });
}
}

return object_graph_proto;
}

private static (List<MySaveableObject>, object?, object?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
bool call_with_mapped_captures, object? saveables_cache = null)
{
int cnt = Math.Min(trackable_objects.Count, object_graph_proto.Nodes.Count);
for (int i = 0; i < cnt; i++)
{
Debug.Assert(node_ids[trackable_objects[i]] == i);
}

var (checkpoint_factory_map, unmmaped_registered_savers) =
get_checkpoint_factories_and_keys(object_names, object_map);
// skip the process of registered savers

var (named_saveable_objects, feed_additions) = generate_saveable_objects(checkpoint_factory_map,
object_graph_proto, node_ids, object_map, call_with_mapped_captures, saveables_cache);
return (named_saveable_objects, feed_additions, null);
}

public static (List<MySaveableObject>, object?) generate_saveable_objects(
IDictionary<Trackable, IEnumerable<CheckpointFactoryData>> checkpoint_factory_map,
TrackableObjectGraph? object_graph_proto, IDictionary<Trackable, int>? node_ids,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
List<MySaveableObject> named_saveable_objects = new();
foreach (var pair in checkpoint_factory_map)
{
var trackable = pair.Key;
var factory_data_list = pair.Value;
bool fill_object_proto = object_graph_proto is not null && node_ids is not null;
TrackableObjectGraph.Types.TrackableObject object_proto = null!;
if (fill_object_proto)
{
object_proto = object_graph_proto.Nodes[node_ids[trackable]];
}

var object_to_save = CheckPointUtils.get_mapped_trackable(trackable, object_map);
// skip cache

foreach (var factory_data in factory_data_list)
{
var name = factory_data.name;
var key = factory_data.checkpoint_key;
var saveable_factory = factory_data.factory;
// TODO: oneflow python has a process with callable `saveable_factory`.
var maybe_saveable = saveable_factory;
IEnumerable<MySaveableObject> savesbles;
if (maybe_saveable is MySaveableObject)
{
savesbles = new List<MySaveableObject>() { (MySaveableObject)maybe_saveable };
}
else if (maybe_saveable is Tensor)
{
savesbles = saveable_object_util.saveable_objects_for_op((Tensor)maybe_saveable, key);
}
else
{
throw new TypeError("Unexpected type.");
}

foreach (var saveable in savesbles)
{
if (!saveable.name.Contains(key))
{
throw new AssertionError($"The object {trackable} produced a SaveableObject with name " +
$"'{saveable.name}' for attribute '{name}'. Expected a name" +
$" containing '{key}'.");
}
}
// skip the process of PythonState
named_saveable_objects.AddRange(savesbles);
if(!fill_object_proto) continue;
// skip the process of TrackableSaveable

object_proto!.Attributes.Add(new TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor()
{ Name = name, CheckpointKey = key, FullName = CheckPointUtils.get_full_name(object_to_save) });
}
}

return (named_saveable_objects, null);
}
}

public record class CheckpointFactoryData
(
object factory,
string name,
string checkpoint_key
);

+ 109
- 0
src/TensorFlowNET.Core/Checkpoint/TrackableSaver.cs View File

@@ -0,0 +1,109 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;

namespace Tensorflow.Checkpoint;

public class TrackableSaver
{
private ObjectGraphView _graph_view;
private EagerTensor _cached_save_operation;
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
public TrackableSaver(ObjectGraphView graph_view)
{
_graph_view = graph_view;
// TODO: cache when not executing eagerly.
// including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`,
// `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache`
}

private void gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
throw new NotImplementedException();
}

private (EagerTensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
throw new NotImplementedException();
}
// TODO: parameter write_done_callback
public Tensor save(string file_prefix, int? checkpoint_number = null, Session? session = null,
CheckpointOptions? options = null)
{
if (options is null)
{
options = new CheckpointOptions();
}

Dictionary<Tensor, string> feed_dict = new();
bool use_session = (!new Context().executing_eagerly() && !ops.inside_function());
if (checkpoint_number is not null)
{
file_prefix = $"{file_prefix}-{checkpoint_number?.ToString()}";
}

Tensor file_prefix_tensor;
Tensor object_graph_tensor;
if (use_session)
{
if (_object_graph_feed_tensor is null)
{
// In python there is `with ops.device("/cpu:0")`.
_object_graph_feed_tensor = constant_op.constant("", dtypes.variant);
_file_prefix_feed_tensor = constant_op.constant("", dtypes.variant);
}

object_graph_tensor = _object_graph_feed_tensor;
file_prefix_tensor = _file_prefix_feed_tensor;
feed_dict[file_prefix_tensor] = file_prefix;
}
else
{
// In python there is `with ops.device("/cpu:0")`.
file_prefix_tensor = ops.convert_to_tensor(file_prefix, dtypes.variant);
object_graph_tensor = null;
}

var (save_path, new_feed_additions) =
save_cached_when_graph_building(file_prefix_tensor, object_graph_tensor, options);

if (new_feed_additions is not null)
{
foreach (var pair in new_feed_additions)
{
feed_dict.Add(pair.Key, pair.Value);
}
}
if(!use_session)
{
session = null;
}
else if (session is null)
{
session = new Session(); // In python it uses `get_session`.
}

if (session is not null)
{
var s = feed_dict.Select(x => new FeedItem(x.Key, x.Value)).ToArray();
return session.run((Tensor)save_path, s);
}
else if (use_session)
{
throw new RuntimeError($"Unable to save checkpoint to \"{file_prefix}\" " +
"in graph mode without a default session. Please use " +
"`with tf.Session():` to create a session.");
}
else
{
return save_path;
}
}
}

+ 75
- 0
src/TensorFlowNET.Core/Checkpoint/TrackableView.cs View File

@@ -0,0 +1,75 @@
using System;
using Tensorflow.Train;
using System.Collections.Generic;
using System.IO;

namespace Tensorflow.Checkpoint;

public class TrackableView
{
protected WeakReference<Trackable> _root_ref;
public TrackableView(Trackable obj)
{
_root_ref = new WeakReference<Trackable>(obj);
}

public TrackableView(WeakReference<Trackable> obj)
{
_root_ref = obj;
}
public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
{
obj._maybe_initialize_trackable();
// Note: in python the return type of `Trackable._trackable_children` is not fixed.
// Therefore it uses `convert_to_trackable` to have an extra process.
return obj._trackable_children(save_type);
}
public Trackable Root
{
get
{
if (_root_ref.TryGetTarget(out Trackable res))
{
return res;
}
else
{
throw new InvalidDataException(
"Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor.");
}
}
}
/// <summary>
/// Returns a list of all nodes and its paths from self.root using a breadth first traversal.
/// Corresponding to tensorflow/python/checkpoint/trackable_view.Trackable._descendants_with_paths
/// </summary>
protected (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
{
List<Trackable> bfs_sorted = new();
Queue<Trackable> to_visit = new();
Dictionary<Trackable, IEnumerable<TrackableReference>> node_paths = new();
node_paths[this.Root] = new List<TrackableReference>();
while (!to_visit.empty())
{
var current_trackable = to_visit.Dequeue();
bfs_sorted.Add(current_trackable);
var children_dict = this.children(current_trackable);
foreach (var name in children_dict.Keys)
{
var dependency = children_dict[name];
if (!node_paths.ContainsKey(dependency))
{
var list = new List<TrackableReference>(node_paths[current_trackable]);
list.Add(new TrackableReference(name, dependency));
node_paths[dependency] = list;
to_visit.Enqueue(dependency);
}
}
}

return (bfs_sorted, node_paths);
}
}

+ 14
- 0
src/TensorFlowNET.Core/Exceptions/AssertionError.cs View File

@@ -0,0 +1,14 @@
namespace Tensorflow.Exceptions;

public class AssertionError : TensorflowException
{
public AssertionError() : base()
{

}

public AssertionError(string message) : base(message)
{

}
}

+ 62
- 1
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -304,7 +304,7 @@ namespace Tensorflow
} }
} }


private static OpList stripped_op_list_for_graph(GraphDef graph_def)
public static OpList stripped_op_list_for_graph(GraphDef graph_def)
{ {
var used_ops = ops_used_by_graph_def(graph_def); var used_ops = ops_used_by_graph_def(graph_def);


@@ -345,5 +345,66 @@ namespace Tensorflow


return used_ops.ToArray(); return used_ops.ToArray();
} }

private static bool is_default_attr_value(OpDef op_def, string attr_name, AttrValue attr_value)
{
foreach (var attr_def in op_def.Attr)
{
if (attr_def.Name == attr_name)
{
if (attr_def.DefaultValue is null) return false;
// TODO: add new c_api `EqualAttrValueWrapper` and complete the check.
return true;
}
}

return false;
}

public static void strip_graph_default_valued_attrs(MetaGraphDef meta_graph_def)
{
Dictionary<string, FunctionDef> op_name_to_function = new();
foreach (var function_def in meta_graph_def.GraphDef.Library.Function)
{
op_name_to_function[function_def.Signature.Name] = function_def;
}

Action<NodeDef> _strip_node_default_valued_attrs = (node_def) =>
{
if (op_name_to_function.ContainsKey(node_def.Op)) return;

var op_def = op_def_registry.GetOpDef(node_def.Op);
if(op_def is null) return;

HashSet<string> attrs_to_strip = new();
foreach (var attr in node_def.Attr)
{
if (is_default_attr_value(op_def, attr.Key, attr.Value))
{
attrs_to_strip.Add(attr.Key);
}
}

foreach (var attr in attrs_to_strip)
{
node_def.Attr.Remove(attr);
}
};

foreach (var node_def in meta_graph_def.GraphDef.Node)
{
_strip_node_default_valued_attrs(node_def);
}

foreach (var function_def in meta_graph_def.GraphDef.Library.Function)
{
foreach (var function_node_def in function_def.NodeDef)
{
_strip_node_default_valued_attrs(function_node_def);
}
}

meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
}
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Framework.Models; using Tensorflow.Framework.Models;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Functions namespace Tensorflow.Functions
@@ -10,7 +11,7 @@ namespace Tensorflow.Functions
/// <summary> /// <summary>
/// ///
/// </summary> /// </summary>
public class ConcreteFunction
public class ConcreteFunction: Trackable
{ {
FuncGraph func_graph; FuncGraph func_graph;
ForwardBackwardCall forward_backward; ForwardBackwardCall forward_backward;


+ 9
- 2
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -1,16 +1,23 @@
using System; using System;
using Tensorflow.Train;


namespace Tensorflow namespace Tensorflow
{ {
public class Function
public class Function: Trackable
{ {
#pragma warning disable CS0169 // The field 'Function._handle' is never used #pragma warning disable CS0169 // The field 'Function._handle' is never used
private IntPtr _handle; private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used #pragma warning restore CS0169 // The field 'Function._handle' is never used

public string Name { get; set; }
public Function() public Function()
{ {


} }
public Function(string name)
{
Name = name;
}
} }
} }

+ 7
- 1
src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs View File

@@ -9,7 +9,13 @@ namespace Tensorflow.ModelSaving
/// </summary> /// </summary>
public class SaveOptions public class SaveOptions
{ {
bool save_debug_info;
public bool save_debug_info = false;
public IList<string>? namespace_white_list { get; set; } = null;
public IDictionary<string, object>? function_aliases { get; set; } = null;
public string? experimental_io_device { get; set; } = null;
// TODO: experimental
public Object? experimental_variable_polict { get; set; } = null;
public bool experimental_custom_gradients { get; set; } = true;
public SaveOptions(bool save_debug_info = false) public SaveOptions(bool save_debug_info = false)
{ {
this.save_debug_info = save_debug_info; this.save_debug_info = save_debug_info;


+ 6
- 0
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using System.Linq; using System.Linq;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.Train;
using static Tensorflow.CppShapeInferenceResult.Types; using static Tensorflow.CppShapeInferenceResult.Types;


namespace Tensorflow namespace Tensorflow
@@ -38,6 +39,11 @@ namespace Tensorflow
{ {
return var is ResourceVariable; return var is ResourceVariable;
} }
public static bool is_resource_variable(Trackable var)
{
return var is BaseResourceVariable;
}


/// <summary> /// <summary>
/// Creates a variable handle with information to do shape inference. /// Creates a variable handle with information to do shape inference.


+ 9
- 1
src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs View File

@@ -156,7 +156,7 @@ namespace Tensorflow {
/// Nodes[0] is considered the root node. /// Nodes[0] is considered the root node.
/// </summary> /// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes {
public pbc::RepeatedField<global::Tensorflow.SavedObject> Nodes {
get { return nodes_; } get { return nodes_; }
} }


@@ -286,6 +286,7 @@ namespace Tensorflow {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject(SavedObject other) : this() { public SavedObject(SavedObject other) : this() {
children_ = other.children_.Clone(); children_ = other.children_.Clone();
dependencies_ = other.dependencies_.Clone();
slotVariables_ = other.slotVariables_.Clone(); slotVariables_ = other.slotVariables_.Clone();
saveableObjects_ = other.saveableObjects_.Clone(); saveableObjects_ = other.saveableObjects_.Clone();
switch (other.KindCase) { switch (other.KindCase) {
@@ -328,6 +329,7 @@ namespace Tensorflow {
private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec private static readonly pb::FieldCodec<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> _repeated_children_codec
= pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser);
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>(); private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> children_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
private readonly pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> dependencies_ = new pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference>();
/// <summary> /// <summary>
/// Objects which this object depends on: named edges in the dependency /// Objects which this object depends on: named edges in the dependency
/// graph. /// graph.
@@ -338,6 +340,11 @@ namespace Tensorflow {
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children { public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Children {
get { return children_; } get { return children_; }
} }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference> Dependencies {
get { return dependencies_; }
}


/// <summary>Field number for the "slot_variables" field.</summary> /// <summary>Field number for the "slot_variables" field.</summary>
public const int SlotVariablesFieldNumber = 3; public const int SlotVariablesFieldNumber = 3;
@@ -617,6 +624,7 @@ namespace Tensorflow {
return; return;
} }
children_.Add(other.children_); children_.Add(other.children_);
dependencies_.Add(other.dependencies_);
slotVariables_.Add(other.slotVariables_); slotVariables_.Add(other.slotVariables_);
saveableObjects_.Add(other.saveableObjects_); saveableObjects_.Add(other.saveableObjects_);
switch (other.KindCase) { switch (other.KindCase) {


+ 6
- 0
src/TensorFlowNET.Core/Protobuf/TrackableObjectGraph.cs View File

@@ -198,6 +198,12 @@ namespace Tensorflow {
public TrackableObject() { public TrackableObject() {
OnConstruction(); OnConstruction();
} }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public TrackableObject(pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference> slot) {
OnConstruction();
slotVariables_ = slot;
}


partial void OnConstruction(); partial void OnConstruction();




+ 15
- 0
src/TensorFlowNET.Core/Training/AutoTrackable.cs View File

@@ -2,5 +2,20 @@
{ {
public abstract class AutoTrackable : Trackable public abstract class AutoTrackable : Trackable
{ {
public void _delete_tracking(string name)
{
_maybe_initialize_trackable();
if (_unconditional_dependency_names.ContainsKey(name))
{
_unconditional_dependency_names.Remove(name);
for (int i = _unconditional_checkpoint_dependencies.Count - 1; i >= 0; i--)
{
if (_unconditional_checkpoint_dependencies[i].Name == name)
{
_unconditional_checkpoint_dependencies.RemoveAt(i);
}
}
}
}
} }
} }

+ 6
- 1
src/TensorFlowNET.Core/Training/Optimizer.cs View File

@@ -351,7 +351,7 @@ namespace Tensorflow
/// <param name="var"></param> /// <param name="var"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
protected IVariableV1 get_slot(IVariableV1 var, string name)
internal IVariableV1 get_slot(IVariableV1 var, string name)
{ {
var named_slots = _slots.ContainsKey(name) ? _slots[name] : null; var named_slots = _slots.ContainsKey(name) ? _slots[name] : null;
if (named_slots == null) if (named_slots == null)
@@ -360,6 +360,11 @@ namespace Tensorflow
return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null; return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null;
} }


internal IEnumerable<string> get_slot_names()
{
return _slots.Keys;
}

private string _var_key(IVariableV1 var) private string _var_key(IVariableV1 var)
{ {
return $"{var.Op.graph.graph_key}.{var.Op.name}"; return $"{var.Op.graph.graph_key}.{var.Op.name}";


+ 14
- 0
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -48,4 +48,18 @@ namespace Tensorflow
validate_shape: restored_shapes == null && op.shape.IsFullyDefined); validate_shape: restored_shapes == null && op.shape.IsFullyDefined);
} }
} }

public class NoRestoreSaveable: MySaveableObject
{
public NoRestoreSaveable(Tensor tensor, string name, TF_DataType dtype = TF_DataType.DtInvalid, string? device = null) : base(tensor,
new SaveSpec[] { new SaveSpec(tensor, "", name, dtype) }, name)
{
}

public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{
return control_flow_ops.no_op();
}
}
} }

+ 11
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/AssetInfo.cs View File

@@ -0,0 +1,11 @@
using System.Collections.Generic;

namespace Tensorflow;

public record class AssetInfo
(
List<AssetFileDef> asset_defs,
Dictionary<object, object> asset_initializers_by_resource,
Dictionary<AssetInfo, string> asset_filename_map,
Dictionary<object, object> asset_index
);

+ 60
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/AugmentedGraphView.cs View File

@@ -0,0 +1,60 @@
using System;
using Tensorflow.Checkpoint;
using Tensorflow.Train;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;

namespace Tensorflow;

public class AugmentedGraphView: ObjectGraphView
{
// private object _children_cache;
// private object _serialization_cache;
private List<string> _untraces_functions;
public AugmentedGraphView(Trackable root): base(root)
{
_untraces_functions = new();
}

public void set_signature(object signature_map, object wrapped_functions)
{
// TODO: cache
list_children(Root);
}
public override List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT)
{
Dictionary<string, Trackable> children = new();
foreach (var pair in base.list_children(obj, save_type))
{
var name = pair.Name;
var child = pair.Refer;
children[name] = child;
}

if (obj is Function && children.Count == 0)
{
_untraces_functions.Add(((Function)obj).Name);
}

return children.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
}

public override (List<Trackable>, Dictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{
// TODO: implement it if needed.
return base.breadth_first_traversal();
}

public List<(string, Trackable)> list_dependencies(Trackable obj)
{
// TODO: deal with cache.
return obj.deserialization_dependencies(null).Select(x => (x.Key, x.Value)).ToList();
}

public Trackable get_child(Trackable obj, string name)
{
throw new NotImplementedException();
}
}

+ 33
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/Constants.cs View File

@@ -0,0 +1,33 @@
namespace Tensorflow;

public static class Constants
{
public static readonly string ASSETS_DIRECTORY = "assets";
public static readonly string ASSETS_KEY = "saved_model_assets";

public static readonly string DEBUG_DIRECTORY = "debug";

public static readonly string DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb";

public static readonly string EXTRA_ASSETS_DIRECTORY = "assets.extra";

public static readonly string FINGERPRINT_FILENAME = "fingerprint.pb";

public static readonly string INIT_OP_SIGNATURE_KEY = "__saved_model_init_op";

public static readonly string LEGACY_INIT_OP_KEY = "legacy_init_op";

public static readonly string MAIN_OP_KEY = "saved_model_main_op";

public static readonly string SAVED_MODEL_FILENAME_PB = "saved_model.pb";
public static readonly string SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt";

public static readonly int SAVED_MODEL_SCHEMA_VERSION = 1;

public static readonly string TRAIN_OP_KEY = "saved_model_train_op";

public static readonly string TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op";

public static readonly string VARIABLES_DIRECTORY = "variables";
public static readonly string VARIABLES_FILENAME = "variables";
}

+ 17
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs View File

@@ -0,0 +1,17 @@
using Tensorflow.Train;

namespace Tensorflow;

public class RevivedTypes
{
/// <summary>
/// Create a SavedUserObject from a trackable object.
/// </summary>
/// <param name="obj"></param>
/// <returns></returns>
public static SavedUserObject? serialize(Trackable obj)
{
// TODO: complete the implementation.
return null;
}
}

+ 9
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveType.cs View File

@@ -0,0 +1,9 @@
using System;

namespace Tensorflow;

public enum SaveType
{
SAVEDMODEL,
CHECKPOINT
}

+ 299
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -0,0 +1,299 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;

namespace Tensorflow;

public class SaveableView
{
private AugmentedGraphView _augmented_graph_view;
private SaveOptions _options;
private List<Trackable> _trackable_objects;
private List<Trackable> _nodes;
private Dictionary<Trackable, IEnumerable<TrackableReference>> _node_paths;
private Dictionary<Trackable, int> _node_ids;
private IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
_slot_variables;
private Dictionary<Trackable, string> _object_names;
private List<object> _gradient_functions; // to be completed
private List<RegisteredGradient> _gradient_defs; // to be completed
private List<ConcreteFunction> _concrete_functions;
private Dictionary<Tensor, int> _captured_tensor_node_ids;
private Dictionary<Trackable, IDictionary<string, ConcreteFunction>> _saveable_objects_map;
private Dictionary<Trackable, string> _obj_to_registered_saver;

public AugmentedGraphView AugmentedGraphView
{
get => _augmented_graph_view;
}
public Trackable Root
{
get => _nodes[0];
}
public List<Trackable> Nodes
{
get => _nodes;
}
public Dictionary<Trackable, int> NodeIds
{
get => _node_ids;
}
public List<RegisteredGradient> GradientDefs
{
get => _gradient_defs;
}
public Dictionary<Trackable, IEnumerable<TrackableReference>> NodePaths
{
get => _node_paths;
}
public SaveableView(AugmentedGraphView augmented_graph_view, SaveOptions options)
{
_augmented_graph_view = augmented_graph_view;
_options = options;

(_trackable_objects, _node_paths, _node_ids, _slot_variables, _object_names) =
CheckPointUtils.objects_ids_and_slot_variables_and_paths(_augmented_graph_view);
// TODO: deal with untraced functions.
initialize_save_and_restore_functions();
initialize_nodes_and_concrete_functions();

_captured_tensor_node_ids = new();
}

private void initialize_save_and_restore_functions()
{
// TODO: deal with the return value of `get_checkpoint_factories_and_keys`.
SaveUtilV1.get_checkpoint_factories_and_keys(_object_names);
// skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver.
_obj_to_registered_saver = new();
_saveable_objects_map = new();
}

private void initialize_nodes_and_concrete_functions()
{
_nodes = _trackable_objects.ConvertAll(x => x); // deep copy
_gradient_functions = new();
_gradient_defs = new();

// TODO: deal with the condition that obj in `_saveable_objects_map`.
// foreach (var obj in _nodes)
// {
//
// }

foreach (var obj in _nodes)
{
if (obj is ConcreteFunction)
{
_concrete_functions.Add((ConcreteFunction)obj);
}
}
}

public List<ConcreteFunction> get_concrete_resource_initializers()
{
// TODO: complete the implementation.
return new List<ConcreteFunction>();
}
public (Dictionary<Trackable, Trackable>, Dictionary<Tensor, Tensor>, AssetInfo) map_resources()
{
Debug.Assert(!tf.Context.executing_eagerly());

Dictionary<Trackable, Trackable> object_map = new();
Dictionary<Tensor, Tensor> tensor_map = new();

AssetInfo assetInfo = new(new List<AssetFileDef>(), new Dictionary<object, object>(),
new Dictionary<AssetInfo, string>(), new Dictionary<object, object>());

foreach (var node_id in dependency_sorted_node_ids())
{
var obj = _nodes[node_id];
var tensors = obj.export_to_saved_model_graph(object_map, tensor_map, _options);
// TODO: deal with Asset (if obj is Asset)
foreach (var tensor in tensors)
{
_captured_tensor_node_ids[tensor] = node_id;
}
}

return (object_map, tensor_map, assetInfo);
}

/// <summary>
/// Returns topologically sorted nodes, sorted by dependencies.
/// </summary>
public List<int> dependency_sorted_node_ids()
{
Dictionary<int, IEnumerable<int>> dependency_map = new();
foreach (var node in _nodes)
{
var node_id = _node_ids[node];
List<int> deps = new();
// TODO: deal with captured tensor.

string node_path;
foreach (var (_, dep) in _augmented_graph_view.list_dependencies(node))
{
if (!_node_ids.ContainsKey(dep))
{
node_path = TrackableUtils.pretty_print_node_path(_node_paths[node]);
throw new ValueError(
$"Found an untracked dependency. Object {node_path} depends on {dep}, " +
$"but this dependency isn't listed as a child. Please track this child by " +
$"overriding `_trackable_children` or use `._track_trackable`.");
}
deps.Add(_node_ids[dep]);
}
}

try
{
return TrackableUtils.order_by_dependency(dependency_map);
}
catch (TrackableUtils.CyclicDependencyError err)
{
List<string> pretty_printed_nodes = new();
List<string> pretty_printed_dependencies = new();

foreach (var pair in err.LeftOverDependencyMap)
{
var x = pair.Key;
var deps = pair.Value;
var node_path = TrackableUtils.pretty_print_node_path(_node_paths[_nodes[x]]);
pretty_printed_nodes.Add($"\tNode {x.ToString()} = {node_path} (type {_nodes[x]})");
pretty_printed_dependencies.Add(
$"\tNode {x.ToString()} depends on nodes [{string.Join(", ", deps.Select(x => x.ToString()))}]");
}

throw new ValueError($"There is one or more dependency cycle in the saved Trackable object. " +
$"Saving cannot continue until this cycle is resolved." +
$"\n>> Unresolved nodes:\n{string.Join("\n", pretty_printed_nodes)}" +
$"\n>> Unresolved cyclic dependencies:\n{string.Join("\n", pretty_printed_dependencies)}");
}
}

/// <summary>
/// Corresponding to tensorflow/python/saved_model/save.py/_serialize_object_graph
/// </summary>
/// <param name="asset_index"></param>
/// <returns></returns>
public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index, SaveOptions options)
{
SavedObjectGraph proto = new();
fill_object_graph_proto(proto);
// TODO: complete the process of concrete functions.

int cnt = Math.Min(_nodes.Count, proto.Nodes.Count);
for (int i = 0; i < cnt; i++)
{
var obj = _nodes[i];
var obj_proto = proto.Nodes[i];
write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x),
options);
}

return proto;
}

private static void write_object_proto(Trackable obj, SavedObject proto,
IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn, SaveOptions options)
{
// skip the process of type Asset
if (resource_variable_ops.is_resource_variable(obj))
{
// TODO: complete it.
throw new NotImplementedException();
}
else if (obj is Function)
{
// TODO: complete it.
throw new NotImplementedException();
}
else if (obj is ConcreteFunction)
{
// TODO: complete it.
throw new NotImplementedException();
}
// skip the process of type `_CapturedTensor` and `CapturableResource`.
else
{
var registered_type_proto = RevivedTypes.serialize(obj);
if (registered_type_proto is null)
{
registered_type_proto = new SavedUserObject()
{
Identifier = obj.ObjectIdentifier,
Version = new VersionDef()
{
Producer = 1,
MinConsumer = 1,
BadConsumers = { }
}
};
}

proto.UserObject = new SavedUserObject(registered_type_proto);
}
// TODO: try get the registered_name from `registration`.
}

public void fill_object_graph_proto(SavedObjectGraph proto)
{
for (int node_id = 0; node_id < _nodes.Count; node_id++)
{
var node = _nodes[node_id];
Debug.Assert(_node_ids[node] == node_id);
SavedObject object_proto = new();
if (_slot_variables.TryGetValue(node, out var value))
{
object_proto.SlotVariables.AddRange(value);
}
// skip the check of type `_CapturedTensor`
foreach (var child in _augmented_graph_view.list_children(node))
{
var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
child_proto.NodeId = _node_ids[child.Refer];
child_proto.LocalName = child.Name;
object_proto.Children.Add(child_proto);
}

foreach (var pair in _augmented_graph_view.list_dependencies(node))
{
var child_proto = new TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference();
child_proto.NodeId = _node_ids[pair.Item2];
child_proto.LocalName = pair.Item1;
object_proto.Dependencies.Add(child_proto);
}

if (_saveable_objects_map.ContainsKey(node))
{
// TODO: complete it.
throw new NotImplementedException();
}
else if(_obj_to_registered_saver.ContainsKey(node))
{
// TODO: complete it.
// We now skip it for the lack of `SavedObject.registered_saver` API.
throw new NotImplementedException();
}

proto.Nodes.Add(object_proto);
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/TagConstants.cs View File

@@ -0,0 +1,10 @@
namespace Tensorflow;

public static class TagConstants
{
public static readonly string SERVING = "serve";
public static readonly string TRAINING = "train";
public static readonly string EVAL = "eval";
public static readonly string GPU = "gpu";
public static readonly string TPU = "tpu";
}

+ 22
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/builder.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using static Tensorflow.Binding;

namespace Tensorflow;

public class BuilderUtils
{
public static void copy_assets_to_destination_dir(IDictionary<AssetInfo, string> asset_filename_map,
string destination_dir, HashSet<string>? saved_files = null)
{
if (saved_files is null) saved_files = new HashSet<string>();

var asset_destination_dir = SavedModelUtils.get_or_create_assets_dir(destination_dir);

// TODO: complete the implementation of this function.
if (asset_filename_map is not null && asset_filename_map.Count > 0)
{
throw new NotImplementedException();
}
}
}

+ 256
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -0,0 +1,256 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Google.Protobuf;
using Tensorflow.Checkpoint;
using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;

namespace Tensorflow;

public static partial class SavedModelUtils
{
private static readonly IEnumerable<int> byte_swappable = new List<TF_DataType>()
{
dtypes.float16, dtypes.float32, dtypes.float64, TF_DataType.TF_BFLOAT16,
dtypes.complex64, dtypes.complex128, TF_DataType.TF_UINT16, dtypes.uint32,
dtypes.uint64, TF_DataType.TF_INT16, dtypes.int32, dtypes.int64, TF_DataType.TF_QINT16,
TF_DataType.TF_QUINT16, TF_DataType.TF_QINT32
}.Select(x => (int)x);
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) save_and_return_nodes(Trackable obj,
string export_dir, IDictionary<string, ConcreteFunction>? signatures, SaveOptions? options = null, bool experimental_skip_checkpoint = false)
{
if (options is null)
{
options = new SaveOptions();
}

var saved_model = new Tensorflow.SavedModel();
var meta_graph_def = new MetaGraphDef();
saved_model.MetaGraphs.Add(meta_graph_def);

var (_, exported_graph, object_saver, asset_info, saved_nodes, node_paths) =
_build_meta_graph(obj, signatures, options, meta_graph_def);
saved_model.SavedModelSchemaVersion = Tensorflow.Constants.SAVED_MODEL_SCHEMA_VERSION;

if (!experimental_skip_checkpoint)
{
Tensorflow.SavedModelUtils.get_or_create_variables_dir(export_dir);
CheckpointOptions ckpt_options = new(options.experimental_io_device);
object_saver.save(Tensorflow.SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options);
}
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir);

if (tf.Context.executing_eagerly())
{
// tensorflow python has a check of `context.async_wait()` here.
}
// TODO: deal with `pywrap_saved_model.Save(export_dir)`.

var saved_model_serialized = saved_model.ToString();

// This is a state depending on some py-c APIs. Here we temporarily set it as `true`.
if (true)
{
var fingerprint_path = Path.Combine(tf.compat.as_str(export_dir),
tf.compat.as_str(Constants.FINGERPRINT_FILENAME));
// TODO: add c api and complete the fingerprint def.
var fingerprint_proto = "";
File.WriteAllText(fingerprint_path, fingerprint_proto);
}

var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB));
File.WriteAllText(path, saved_model.ToString());

if (options.save_debug_info)
{
throw new NotImplementedException();
}
ops.dismantle_graph(exported_graph);

return (saved_nodes, node_paths);
}

private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, List<Trackable>,
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
IDictionary<string, ConcreteFunction>? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
{
if (ops.inside_function())
{
throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " +
"Move the call to the outer eagerly-executed context.");
}

if (meta_graph_def is null)
{
meta_graph_def = new MetaGraphDef();
}

AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
if (signatures is not null)
{
throw new NotImplementedException();
}
// TODO: process of aignatures and wrapped_functions

SaveableView saveable_view = new SaveableView(augmented_graph_view, options);
TrackableSaver object_saver = new TrackableSaver(augmented_graph_view);
var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures,
options.namespace_white_list, options.experimental_custom_gradients);
if (options.function_aliases is not null)
{
var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases;
foreach (var pair in options.function_aliases)
{
var alias = pair.Key;
var func = pair.Value;
// TODO: complete it.
throw new NotImplementedException();
}
}

var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options);
meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto);

return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths);
}

private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view,
IDictionary<string, ConcreteFunction> signatures, IEnumerable<string> namespace_whitelist,
bool save_custom_gradients)
{
var resource_initializers = saveable_view.get_concrete_resource_initializers();
var exported_graph = new Graph();

Dictionary<Trackable, Trackable> object_map;
Dictionary<Tensor, Tensor> tensor_map;
AssetInfo asset_info;
using (var g = exported_graph.as_default())
{
(object_map, tensor_map, asset_info) = saveable_view.map_resources();
// TODO: deal with signatures.
if (save_custom_gradients)
{
// TODO: trace gradient functions.
}

foreach (var resource_initializer_function in resource_initializers)
{
// List<Trackable> asset_dependencies = new();
// TODO: deal with initializers
}
// using(ops.control_dependencies(...))
var init_op = control_flow_ops.no_op();
if (meta_graph_def.CollectionDef.ContainsKey(Tensorflow.Constants.MAIN_OP_KEY))
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY].NodeList.Value.Append(init_op.name);
}
else
{
meta_graph_def.CollectionDef[Tensorflow.Constants.MAIN_OP_KEY] = new CollectionDef();
}
// Lack `CopyFrom` API
// meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]
}
foreach (var obj in object_map.Values)
{
obj._maybe_initialize_trackable();
}

var (named_saveable_objects, registered_savers) =
SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false);
// TODO: complete the save of checkpoints with `MultiDeviceSaver`.

saveable_view.dependency_sorted_node_ids();

var graph_def = exported_graph.as_graph_def(true);
graph_def.Library.RegisteredGradients.AddRange(saveable_view.GradientDefs);
verify_ops(graph_def, namespace_whitelist);

meta_graph_def.GraphDef = new GraphDef(graph_def);
meta_graph_def.MetaInfoDef.Tags.Add(TagConstants.SERVING);
meta_graph_def.MetaInfoDef.TensorflowVersion = tf.VERSION;
// TODO: add git version.
meta_graph_def.MetaInfoDef.TensorflowGitVersion = "";
meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
meta_graph_def.MetaInfoDef.StrippedOpList.MergeFrom(meta_graph.stripped_op_list_for_graph(meta_graph_def.GraphDef));
meta_graph_def.AssetFileDef.AddRange(asset_info.asset_defs);
// TODO: deal with signatures here.
meta_graph.strip_graph_default_valued_attrs(meta_graph_def);

if (!BitConverter.IsLittleEndian)
{
swap_function_tensor_content(meta_graph_def);
}

return (asset_info, exported_graph);
}

private static void verify_ops(GraphDef graph_def, IEnumerable<string>? namespace_whitelist)
{
return;
// if (namespace_whitelist is null || !namespace_whitelist.Any())
// {
// return;
// }
// skip the check for the lack of `meta_graph.ops_used_by_graph_def`.
}

public static void swap_function_tensor_content(MetaGraphDef meta_graph_def)
{
var functions = meta_graph_def.GraphDef.Library.Function;
foreach (var function in functions)
{
var node_def = function.NodeDef;
foreach (var node in node_def)
{
if (node.Op == "Const")
{
var tensor = node.Attr["value"].Tensor;
byte_swap_tensor_content(tensor);
}
}
}
}

public static void byte_swap_tensor_content(TensorProto tensor)
{
if (byte_swappable.Contains((int)tensor.Dtype))
{
var tshape = tensor.TensorShape.Dim;
var tensor_bytes = tensor.TensorContent;
if (tensor_bytes is not null && !tensor_bytes.IsEmpty)
{
long tensor_size = 1;
foreach (var sz in tshape)
{
tensor_size *= sz.Size;
}

var chunksize = tensor_bytes.Length / tensor_size;
List<byte> reversed_bytes = new();
for (int i = 0; i < tensor_bytes.Length; i += (int)chunksize)
{
var current = tensor_bytes.Skip(i).Take((int)chunksize).Reverse();
reversed_bytes.AddRange(current);
}
tensor.TensorContent = ByteString.CopyFrom(reversed_bytes.ToArray());
}
}
}
}

+ 58
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/signature_serialization.cs View File

@@ -0,0 +1,58 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Train;

namespace Tensorflow;

public class SignatureMap: Trackable
{
private Dictionary<string, Function> _signatures;
private Dictionary<string, ConcreteFunction> _concrete_signatures;

public SignatureMap()
{
_signatures = new();
}

public void _add_signature(string name, ConcreteFunction concrete_function)
{
_concrete_signatures[name] = concrete_function;
}
public void _add_signature(string name, Function concrete_function)
{
_signatures[name] = concrete_function;
}

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
{
if (save_type != SaveType.SAVEDMODEL)
{
return new Dictionary<string, Trackable>();
}

Dictionary<string, Trackable> res = _signatures.ToDictionary(x => x.Key, x => (Trackable)x.Value);
foreach (var pair in _concrete_signatures)
{
res[pair.Key] = pair.Value;
}

return res;
}

public static SignatureMap create_signature_map(IDictionary<string, ConcreteFunction> signatures)
{
var signature_map = new SignatureMap();
foreach (var pair in signatures)
{
var name = pair.Key;
var func = pair.Value;
// TODO: assert the arg_keywords
signature_map._add_signature(name, func);
}

return signature_map;
}
}

+ 52
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs View File

@@ -0,0 +1,52 @@
using System.IO;
using System.Security.Cryptography.X509Certificates;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow;

public static partial class SavedModelUtils
{
/// <summary>
/// Return variables sub-directory, or create one if it doesn't exist.
/// </summary>
/// <returns></returns>
public static string get_or_create_variables_dir(string export_dir)
{
var variables_dir = get_variables_dir(export_dir);
Directory.CreateDirectory(variables_dir);
return variables_dir;
}

/// <summary>
/// Return variables sub-directory in the SavedModel.
/// </summary>
/// <param name="export_dir"></param>
/// <returns></returns>
public static string get_variables_dir(string export_dir)
{
return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY));
}

/// <summary>
/// Return assets sub-directory, or create one if it doesn't exist.
/// </summary>
/// <param name="export_dir"></param>
/// <returns></returns>
public static string get_or_create_assets_dir(string export_dir)
{
var assets_destination_dir = get_assets_dir(export_dir);
Directory.CreateDirectory(assets_destination_dir);
return assets_destination_dir;
}

/// <summary>
/// Return path to asset directory in the SavedModel.
/// </summary>
/// <param name="export_dir"></param>
/// <returns></returns>
public static string get_assets_dir(string export_dir)
{
return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.ASSETS_DIRECTORY));
}
}

+ 18
- 1
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -17,12 +17,17 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public class saveable_object_util
public static class saveable_object_util
{ {
public class TrackableSaveable: MySaveableObject
{
}
/// <summary> /// <summary>
/// Returns the variables and names that will be used for a Saver. /// Returns the variables and names that will be used for a Saver.
/// </summary> /// </summary>
@@ -121,5 +126,17 @@ namespace Tensorflow


return names_to_saveables; return names_to_saveables;
} }

public static IDictionary<string, ResourceVariable> saveable_objects_from_trackable(Trackable obj)
{
// TODO: complete the implementation.
return obj.gather_saveables_for_checkpoint();
}

public static bool trackable_has_serialize_to_tensor(Trackable obj)
{
// TODO: implement it.
return false;
}
} }
} }

+ 78
- 1
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -14,14 +14,38 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.ModelSaving;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Train namespace Tensorflow.Train
{ {
public abstract class Trackable public abstract class Trackable
{ {
/// <summary>
/// Corresponding to tensorflow/python/trackable/constants.py
/// </summary>
public static class Constants
{
public static readonly string OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH";
public static readonly string VARIABLE_VALUE_KEY = "VARIABLE_VALUE";
public static readonly string OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON";
}
protected int _self_update_uid; protected int _self_update_uid;
protected IDictionary<string, Trackable> _unconditional_dependency_names =
new Dictionary<string, Trackable>();

protected IList<TrackableReference> _unconditional_checkpoint_dependencies = new List<TrackableReference>();


protected IDictionary<string, ResourceVariable> _self_saveable_object_factories =
new Dictionary<string, ResourceVariable>();
public virtual string ObjectIdentifier
{
get => "_generic_user_object";
}
/// <summary> /// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`. /// Restore-on-create for a variable be saved with this `Checkpointable`.
/// </summary> /// </summary>
@@ -73,10 +97,63 @@ namespace Tensorflow.Train
/// <summary> /// <summary>
/// Initialize dependency management. /// Initialize dependency management.
/// </summary> /// </summary>
protected void _maybe_initialize_trackable()
public void _maybe_initialize_trackable()
{ {
// _self_unconditional_checkpoint_dependencies = [] // _self_unconditional_checkpoint_dependencies = []
_self_update_uid = -1; _self_update_uid = -1;
} }

// TODO: cache
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, object>? cache = null)
{
_maybe_initialize_trackable();
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
}

public static Trackable convert_to_trackable(object obj, object? parent = null)
{
if (obj is Trackable)
{
return (Trackable)obj;
}
else
{
throw new NotImplementedException();
}
}

public virtual IDictionary<string, Trackable> deserialization_dependencies(IDictionary<string, Trackable> children)
{
return new Dictionary<string, Trackable>();
}

public virtual (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(
SaveOptions? save_options)
{
return (new Dictionary<Trackable, Trackable>(), new Dictionary<Tensor, Tensor>());
}

public virtual List<Tensor> export_to_saved_model_graph(IDictionary<Trackable, Trackable>? object_map = null,
IDictionary<Tensor, Tensor>? tensor_map = null, SaveOptions? options = null)
{
var (self_object_map, self_tensor_map) = map_resources(options);
foreach (var pair in self_object_map)
{
object_map.Add(pair);
}
foreach (var pair in self_tensor_map)
{
tensor_map.Add(pair);
}

return self_tensor_map.Keys.ToList();
}

public virtual IDictionary<string, ResourceVariable> gather_saveables_for_checkpoint()
{
return _self_saveable_object_factories;
}
} }

public record class TrackableReference(string Name, Trackable Refer);
} }

+ 148
- 0
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -0,0 +1,148 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Exceptions;
using Tensorflow.Train;

namespace Tensorflow.Training;

public static class TrackableUtils
{
public class CyclicDependencyError: System.Exception
{
public IDictionary<int, IEnumerable<int>> LeftOverDependencyMap { get; }
public CyclicDependencyError(IDictionary<int, IEnumerable<int>> leftover_dependency_map): base()
{
LeftOverDependencyMap = leftover_dependency_map;
}
public CyclicDependencyError(IDictionary<int, List<int>> leftover_dependency_map): base()
{
LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable());
}
}
private static string _ESCAPE_CHAR = ".";
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
private static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
{
return string.Join("/", node_path_arr.Select(x => escape_local_name(x.Name)));
}

public static string escape_local_name(string name)
{
return name.Replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR).Replace("/", _ESCAPE_CHAR + "S");
}
public static string checkpoint_key(string object_path, string local_name)
{
var key_suffix = escape_local_name(local_name);
if (local_name == SERIALIZE_TO_TENSORS_NAME)
{
key_suffix = "";
}

return $"{object_path}/{OBJECT_ATTRIBUTES_NAME}/{key_suffix}";
}

/// <summary>
/// Topologically sorts the keys of a map so that dependencies appear first.
/// Uses Kahn's algorithm: https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
/// </summary>
/// <param name="dependency_map"></param>
/// <exception cref="ValueError"></exception>
public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map)
{
Dictionary<int, HashSet<int>> reverse_dependency_map = new();
foreach (var pair in dependency_map)
{
foreach (var dep in pair.Value)
{
if (reverse_dependency_map.ContainsKey(dep))
{
reverse_dependency_map[dep].Add(pair.Key);
}
else
{
reverse_dependency_map[dep] = new HashSet<int>();
reverse_dependency_map[dep].Add(pair.Key);
}
}
}
// Validate that all values in the dependency map are also keys.
var unknown_keys = reverse_dependency_map.Keys.Except(dependency_map.Keys);
if (unknown_keys.Count() > 0)
{
throw new ValueError(
$"Found values in the dependency map which are not keys: {string.Join(", ", unknown_keys.Select(x => x.ToString()))}");
}
// Generate the list sorted by objects without dependencies -> dependencies.
// The returned list will reverse this.
List<int> reversed_dependency_arr = new();

Queue<int> to_visit = new();
foreach (var x in dependency_map.Keys)
{
if (!reverse_dependency_map.ContainsKey(x))
{
to_visit.Enqueue(x);
}
}

while (to_visit.Count > 0)
{
var x = to_visit.Dequeue();
reversed_dependency_arr.Add(x);
foreach (var dep in dependency_map[x].Distinct())
{
var edges = reverse_dependency_map[dep];
edges.Remove(x);
if (edges.Count == 0)
{
to_visit.Enqueue(dep);
if (!reverse_dependency_map.Remove(dep))
{
throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map");
}
}
}
}

if (reverse_dependency_map.Count > 0)
{
Dictionary<int, List<int>> leftover_dependency_map = new();
foreach (var pair in reverse_dependency_map)
{
foreach (var x in pair.Value)
{
if (leftover_dependency_map.ContainsKey(x))
{
leftover_dependency_map[x].Add(pair.Key);
}
else
{
leftover_dependency_map[x] = new List<int>() { pair.Key };
}
}
}

throw new CyclicDependencyError(leftover_dependency_map);
}

reversed_dependency_arr.Reverse();
return reversed_dependency_arr;
}

public static string pretty_print_node_path(IEnumerable<TrackableReference> paths)
{
if (paths.Count() == 0)
{
return "root object";
}
else
{
return $"root.{string.Join(".", paths.Select(x => x.Name))}";
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -2,6 +2,7 @@
using System; using System;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Variables; using Tensorflow.Variables;
using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow


+ 18
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -566,5 +566,23 @@ namespace Tensorflow
else else
throw new NotImplementedException(""); throw new NotImplementedException("");
} }

public static bool inside_function()
{
return get_default_graph().building_function;
}

public static void dismantle_graph(Graph graph)
{
}

public class NullContextManager: IDisposable
{
public void Dispose()
{
}
}
} }
} }

+ 31
- 0
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -0,0 +1,31 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Keras.Engine;

public abstract partial class Layer
{
public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);

public string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;

public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, object>? cache = null)
{
IDictionary<string, Trackable> children;
if (save_type == SaveType.SAVEDMODEL)
{
// TODO: deal with cache.
children = TrackableSavedModelSaver.trackable_children(cache);
}
else
{
children = new Dictionary<string, Trackable>();
}

return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value);
}
}

+ 3
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -49,6 +49,8 @@ namespace Tensorflow.Keras.Engine
public bool Built => built; public bool Built => built;
public bool Trainable => args.Trainable; public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType; public TF_DataType DType => args.DType;
public bool AutoCast => args.Autocast;
public IRegularizer ActivityRegularizer => args.ActivityRegularizer;


/// <summary> /// <summary>
/// A stateful layer is a layer whose updates are run during inference too, /// A stateful layer is a layer whose updates are run during inference too,
@@ -162,7 +164,7 @@ namespace Tensorflow.Keras.Engine
/// </summary> /// </summary>
/// <param name="inputs"></param> /// <param name="inputs"></param>
/// <param name="state"></param> /// <param name="state"></param>
/// <param name="is_training"></param>
/// <param name="training"></param>
/// <returns></returns> /// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {


+ 13
- 2
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -1,5 +1,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving; using Tensorflow.ModelSaving;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
@@ -18,9 +20,18 @@ namespace Tensorflow.Keras.Engine
bool overwrite = true, bool overwrite = true,
bool include_optimizer = true, bool include_optimizer = true,
string save_format = "tf", string save_format = "tf",
SaveOptions options = null)
SaveOptions? options = null,
IDictionary<string, ConcreteFunction>? signatures = null,
bool save_traces = true)
{ {
saver.save(this, filepath);
if (save_format != "pb")
{
saver.save(this, filepath);
}
else
{
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
}
} }
} }
} }

+ 6
- 0
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -35,6 +35,12 @@ namespace Tensorflow.Keras.Engine
bool _base_model_initialized; bool _base_model_initialized;
bool stop_training; bool stop_training;
DataHandler data_handler; DataHandler data_handler;
public OptimizerV2 Optimizer
{
get => optimizer;
set => optimizer = value;
}


public Model(ModelArgs args) public Model(ModelArgs args)
: base(args) : base(args)


+ 12
- 0
src/TensorFlowNET.Keras/Protobuf/SavedMetadata.cs View File

@@ -194,6 +194,18 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf {
OnConstruction(); OnConstruction();
} }


[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public SavedObject(int nodeId, string nodePath,
global::ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef version, string identifier, string metadata)
{
OnConstruction();
nodeId_ = nodeId;
nodePath_ = nodePath;
identifier_ = identifier;
metadata_ = metadata;
version_ = version;
}

partial void OnConstruction(); partial void OnConstruction();


[global::System.Diagnostics.DebuggerNonUserCodeAttribute] [global::System.Diagnostics.DebuggerNonUserCodeAttribute]


+ 7
- 0
src/TensorFlowNET.Keras/Protobuf/Versions.cs View File

@@ -74,6 +74,13 @@ namespace ThirdParty.Tensorflow.Python.Keras.Protobuf {
public VersionDef() { public VersionDef() {
OnConstruction(); OnConstruction();
} }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public VersionDef(int producer, int minConsumer) {
OnConstruction();
producer_ = producer;
minConsumer_ = minConsumer;
}


partial void OnConstruction(); partial void OnConstruction();




+ 41
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/Constants.cs View File

@@ -0,0 +1,41 @@
using System.Collections.Generic;

namespace Tensorflow.Keras.Saving.SavedModel;

public static class Constants
{
/// <summary>
/// Namespace used to store all attributes added during serialization.
/// e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an
/// object loaded from `tf.saved_model.load()`.
/// </summary>
public static readonly string KERAS_ATTR = "keras_api";
/// <summary>
/// Keys for the serialization cache.
/// Maps to the keras serialization dict {Layer --> SerializedAttributes object}
/// </summary>
public static readonly string KERAS_CACHE_KEY = "keras_serialized_attributes";
/// <summary>
/// Name of Keras metadata file stored in the SavedModel.
/// </summary>
public static readonly string SAVED_METADATA_PATH = "keras_metadata.pb";
public static readonly string INPUT_LAYER_IDENTIFIER = "_tf_keras_input_layer";
public static readonly string LAYER_IDENTIFIER = "_tf_keras_layer";
public static readonly string METRIC_IDENTIFIER = "_tf_keras_metric";
public static readonly string MODEL_IDENTIFIER = "_tf_keras_model";
public static readonly string NETWORK_IDENTIFIER = "_tf_keras_network";
public static readonly string RNN_LAYER_IDENTIFIER = "_tf_keras_rnn_layer";
public static readonly string SEQUENTIAL_IDENTIFIER = "_tf_keras_sequential";

public static readonly IList<string> KERAS_OBJECT_IDENTIFIERS = new List<string>()
{
INPUT_LAYER_IDENTIFIER,
LAYER_IDENTIFIER,
METRIC_IDENTIFIER,
MODEL_IDENTIFIER,
NETWORK_IDENTIFIER,
RNN_LAYER_IDENTIFIER,
SEQUENTIAL_IDENTIFIER
};
}

+ 11
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/KerasObjectWrapper.cs View File

@@ -0,0 +1,11 @@
namespace Tensorflow.Keras.Saving.SavedModel;

public class KerasObjectWrapper
{
}

public class KerasObjectWrapper<T>
{
public T Item { get; set; }
}

+ 115
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -0,0 +1,115 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Google.Protobuf;
using ICSharpCode.SharpZipLib.Zip;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Functions;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Exceptions;
using Tensorflow.IO;
using Tensorflow.Keras.Optimizers;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, IDictionary<string, ConcreteFunction>? signatures,
SaveOptions? options, bool save_traces = true)
{
if (!overwrite && File.Exists(filepath))
{
throw new Exception("The file already exists but is not allowed to overwrite it.");
}

if (save_traces)
{
if(should_skip_serialization(model))
{
throw new NotImplementedException();
}
}

OptimizerV2? orig_optimizer = null;
if (!include_optimizer)
{
orig_optimizer = model.Optimizer;
model.Optimizer = null;
model._delete_tracking("optimizer");
}

IList<Trackable> saved_nodes;
IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths;
// skip two scopes of python
using (KerasSavedModelUtils.keras_option_scope(save_traces))
{
(saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options);
}

var metadata = generate_keras_metadata(saved_nodes, node_paths);
using (var f = new FileStream(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), FileMode.OpenOrCreate,
FileAccess.Write))
{
var writer = new StreamWriter(f);
writer.Write(metadata.ToString());
}

if (!include_optimizer)
{
model.Optimizer = orig_optimizer!;
}
}

public static SavedMetadata generate_keras_metadata(IList<Trackable> saved_nodes,
IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths)
{
var metadata = new SavedMetadata();
for (int i = 0; i < saved_nodes.Count; i++)
{
var node = saved_nodes[i];
if (node is not Layer)
{
continue;
}

Layer layer = (Layer)node;

var path = node_paths[node];
string node_path;
if (path is null)
{
node_path = "root";
}
else
{
node_path = $"root.{string.Join(".", path.Select(x => x.Name))}";
}
ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new()
{
NodeId = i,
NodePath = node_path,
Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef()
{
Producer = 2,
MinConsumer = 1,
BadConsumers = { }
},
Identifier = layer.ObjectIdentifier,
Metadata = layer.TrackingMetadata
};
}

return metadata;
}

}

+ 19
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs View File

@@ -0,0 +1,19 @@
using System.Collections.Generic;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static bool should_skip_serialization(object layer)
{
return false;
}

public static IDictionary<string, KerasObjectWrapper> wrap_layer_objects(Layer layer, object serialization_cache)
{
// TODO: process the loss

return null;
}
}

+ 40
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs View File

@@ -0,0 +1,40 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Engine;
using Newtonsoft.Json;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel;

public abstract class SavedModelSaver
{
private Trackable _obj;
public SavedModelSaver(Trackable obj)
{
_obj = obj;
}

public abstract string ObjectIdentifier { get; }
public abstract string TrackingMetadata { get; }

public abstract IDictionary<string, CheckpointableBase> objects_to_serialize(
IDictionary<string, object> serialization_cache);

public abstract IDictionary<string, Function> functions_to_serialize(
IDictionary<string, object> serialization_cache);

public IDictionary<string, Trackable> trackable_children(IDictionary<string, object>? serialization_cache)
{
if (!KerasSavedModelUtils.ShouldHaveTraces)
{
return new Dictionary<string, Trackable>();
}

var children = objects_to_serialize(serialization_cache);

return children.ToDictionary(x => x.Key, x => (Trackable)x.Value)
.Concat(functions_to_serialize(serialization_cache).ToDictionary(x => x.Key, x => (Trackable)x.Value))
.ToDictionary(x => x.Key, x => x.Value);
}
}

+ 62
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -0,0 +1,62 @@
using System.Collections.Generic;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;

namespace Tensorflow.Keras.Saving.SavedModel;

public class LayerSavedModelSaver: SavedModelSaver
{
private Layer _obj;
public LayerSavedModelSaver(Layer obj): base(obj)
{
_obj = obj;
}
public override string ObjectIdentifier
{
get => Constants.LAYER_IDENTIFIER;
}

public override IDictionary<string, CheckpointableBase> objects_to_serialize(IDictionary<string, object> serialization_cache)
{
throw new System.NotImplementedException();
}

public override IDictionary<string, Function> functions_to_serialize(IDictionary<string, object> serialization_cache)
{
throw new System.NotImplementedException();
}

public override string TrackingMetadata
{
get
{
JObject metadata = new JObject();
metadata["name"] = _obj.Name;
metadata["trainable"] = _obj.Trainable;
// metadata["expects_training_arg"] = _obj._expects_training_arg;
// metadata["dtype"] = policy.serialize(_obj._dtype_policy)
metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape);
// metadata["stateful"] = _obj.stateful;
// metadata["must_restore_from_config"] = _obj.must_restore_from_config;
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config;
metadata["autocast"] = _obj.AutoCast;
metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings
{
// Handle conflicts by using values from obj2
MergeArrayHandling = MergeArrayHandling.Merge
});
// skip the check of `input_spec` and `build_input_shape` for the lack of members.
// skip the check of `activity_regularizer` for the type problem.
return metadata.ToString();
}
}

public static LayerConfig get_serialized(Layer obj)
{
return generic_utils.serialize_keras_object(obj);
}
}

+ 33
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/utils.cs View File

@@ -0,0 +1,33 @@
using System;

namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static bool ShouldHaveTraces { get; internal set; }

public static SaveOptionsContext keras_option_scope(bool save_traces)
{
var res = new SaveOptionsContext(ShouldHaveTraces);
ShouldHaveTraces = save_traces;
return res;
}
}

/// <summary>
/// Implementation of this class is different with that of python.
/// But it could be used with `using` the same as `with` of python.
/// </summary>
public class SaveOptionsContext: IDisposable
{
public bool _old_value;
public SaveOptionsContext(bool old_value)
{
_old_value = true;
}

public void Dispose()
{
KerasSavedModelUtils.ShouldHaveTraces = _old_value;
}
}

+ 60
- 0
test/TensorFlowNET.Keras.UnitTest/SaveTest.cs View File

@@ -0,0 +1,60 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;

namespace TensorFlowNET.Keras.UnitTest;

// class MNISTLoader
// {
// public MNISTLoader()
// {
// var mnist = new MnistModelLoader()
//
// }
// }

[TestClass]
public class SaveTest
{
[TestMethod]
public void Test()
{
var inputs = new KerasInterface().Input((28, 28, 1));
var x = new Flatten(new FlattenArgs()).Apply(inputs);
x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
x = new LayersApi().Dense(units: 10).Apply(x);
var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
var model = new KerasInterface().Model(inputs, outputs);
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"});

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 50000,
}).Result;
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.save("", save_format:"pb");
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj View File

@@ -47,7 +47,7 @@


<ItemGroup> <ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" /> <PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" /> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.8" /> <PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.8" /> <PackageReference Include="MSTest.TestFramework" Version="2.2.8" />


Loading…
Cancel
Save