
Add essential components of SavedModel format loading.

pull/989/head
Yaohui Liu 2 years ago
parent commit 5df6e5ddb5
36 changed files with 1945 additions and 63 deletions
  1. src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs (+9, -0)
  2. src/TensorFlowNET.Core/Checkpoint/checkpoint.cs (+217, -0)
  3. src/TensorFlowNET.Core/Checkpoint/restore.cs (+80, -0)
  4. src/TensorFlowNET.Core/IO/gfile.cs (+12, -0)
  5. src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs (+1, -1)
  6. src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs (+1, -0)
  7. src/TensorFlowNET.Core/Operations/resource_variable_ops.cs (+1, -1)
  8. src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs (+23, -0)
  9. src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs (+8, -1)
  10. src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs (+6, -6)
  11. src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs (+0, -1)
  12. src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs (+601, -0)
  13. src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs (+122, -0)
  14. src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs (+0, -1)
  15. src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs (+0, -1)
  16. src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs (+15, -0)
  17. src/TensorFlowNET.Core/Training/Trackable.cs (+14, -1)
  18. src/TensorFlowNET.Core/Training/TrackableUtils.cs (+1, -0)
  19. src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs (+6, -2)
  20. src/TensorFlowNET.Core/Variables/ResourceVariable.cs (+18, -0)
  21. src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs (+1, -1)
  22. src/TensorFlowNET.Keras/Engine/Functional.cs (+5, -0)
  23. src/TensorFlowNET.Keras/Engine/Layer.cs (+19, -1)
  24. src/TensorFlowNET.Keras/Engine/Model.Save.cs (+1, -1)
  25. src/TensorFlowNET.Keras/Engine/Model.cs (+8, -0)
  26. src/TensorFlowNET.Keras/Engine/Sequential.cs (+10, -5)
  27. src/TensorFlowNET.Keras/Layers/Activation/ELU.cs (+1, -2)
  28. src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs (+1, -2)
  29. src/TensorFlowNET.Keras/Layers/Activation/SELU.cs (+1, -2)
  30. src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs (+521, -33)
  31. src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs (+1, -1)
  32. src/TensorFlowNET.Keras/Saving/SavedModel/load.cs (+96, -0)
  33. src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs (+69, -0)
  34. src/TensorFlowNET.Keras/Utils/generic_utils.cs (+29, -0)
  35. test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs (+45, -0)
  36. test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs (+2, -0)

src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs (+9, -0)

@@ -149,4 +149,13 @@ public static class CheckPointUtils
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}

/// <summary>
/// Traverse the object graph and list all accessible objects.
/// </summary>
/// <param name="object_graph_view"></param>
public static IList<Trackable> list_objects(ObjectGraphView graph_view)
{
return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
}
}
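
For orientation, a minimal usage sketch of the new helper (hedged: `root` stands for any already-constructed Trackable, for example the object returned by a loader; it is not part of this diff):

using System.Collections.Generic;
using Tensorflow.Checkpoint;
using Tensorflow.Train;

// `root` is assumed to be some Trackable whose object graph we want to enumerate.
var graph_view = new ObjectGraphView(root);
// Every object reachable from `root`, dependencies included.
IList<Trackable> reachable = CheckPointUtils.list_objects(graph_view);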

src/TensorFlowNET.Core/Checkpoint/checkpoint.cs (+217, -0)

@@ -6,8 +6,10 @@ using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Train;
using Tensorflow.Exceptions;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding;
using Tensorflow.Operations;

namespace Tensorflow.Checkpoint;

@@ -21,8 +23,20 @@ public class TrackableSaver
private TrackableObjectGraph _last_save_object_graph;
private Tensor? _object_graph_feed_tensor = null;
private Tensor? _file_prefix_feed_tensor = null;
private Tensor? _file_prefix_placeholder = null;
private Dictionary<Trackable, Trackable>? _object_map = null;
private object? _cache = null;
public Tensor? FilePrefixPlaceHolder
{
get
{
return _file_prefix_placeholder;
}
set
{
_file_prefix_placeholder = value;
}
}
public TrackableSaver(ObjectGraphView graph_view)
{
_graph_view = graph_view;
@@ -192,4 +206,207 @@ public class TrackableSaver
return save_path;
}
}

public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
{
if (options is null)
{
options = new CheckpointOptions();
}
if(save_path is null)
{
return new InitializationOnlyStatus(_graph_view, ops.uid());
}

CheckpointReader reader = new CheckpointReader(save_path);
bool graph_building = tf.Context.executing_eagerly();
Dictionary<string, TF_DataType> dtype_map = null;
if (!graph_building)
{
dtype_map = reader.VariableToDataTypeMap;
}
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY);

Dictionary<Tensor, string> file_prefix_feed_dict;
Tensor file_prefix_tensor;
if (graph_building)
{
if(_file_prefix_placeholder is null)
{
tf.device("/cpu:0");
_file_prefix_placeholder = constant_op.constant("model");
}
file_prefix_tensor = _file_prefix_placeholder;
file_prefix_feed_dict = new();
file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
}
else
{
tf.device("/cpu:0");
file_prefix_tensor = constant_op.constant(save_path);
file_prefix_feed_dict = null;
}
TrackableObjectGraph object_graph_proto = new();
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
object_graph_proto: object_graph_proto,
save_path: save_path,
save_path_tensor: file_prefix_tensor,
reader: reader,
restore_op_cache: null,
graph_view: _graph_view,
options: options,
saveables_cache: null
);

throw new NotImplementedException();
}
}

internal class CheckpointRestoreCoordinator
{
private CheckpointOptions _options;
private TrackableObjectGraph _object_graph_proto;
private int _restore_uid;
private HashSet<int> _matched_proto_ids;
private Tensor _save_path_tensor;
private string _save_path_string;
private CheckpointReader _reader;
private Dictionary<string, TF_DataType> _dtype_map;
private Dictionary<string, Shape> _shape_map;
private ObjectGraphView _graph_view;
private Dictionary<int, IList<SlotVariableRestoration>> _slot_restorations;
private bool _expect_partial_attr;
private List<Operation> _restore_ops;
private List<Trackable> _all_trackables;
private Dictionary<int, Trackable> _object_by_proto_id;

public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
{
// TODO(Rinne): cache.
_options = options;
_object_graph_proto = object_graph_proto;
_restore_uid = ops.uid();
_save_path_tensor = save_path_tensor;
_save_path_string = save_path;
_reader = reader;
if(_reader is null)
{
_reader = new CheckpointReader(save_path);
}
_dtype_map = _reader.VariableToDataTypeMap;
_shape_map = _reader.VariableToShapeMap;
_graph_view = graph_view;
_restore_ops = new List<Operation>();
_all_trackables = new List<Trackable>();
_matched_proto_ids = new HashSet<int>();
_object_by_proto_id = new Dictionary<int, Trackable>();
_slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>();

_expect_partial_attr = false;
for(int i = 0; i < _object_graph_proto.Nodes.Count; i++)
{
var node = _object_graph_proto.Nodes[i];
foreach(var slot_reference in node.SlotVariables)
{
_slot_restorations.SetDefault(slot_reference.OriginalVariableNodeId, new List<SlotVariableRestoration>())
.Add(new SlotVariableRestoration(i, slot_reference.SlotVariableNodeId, slot_reference.SlotName));
}
}

// skip the deleter and cache.
}

public bool ExpectPartial
{
get
{
return _expect_partial_attr;
}
set
{
_expect_partial_attr = value;
}
}

public List<Trackable> AllTrackables => _all_trackables;
public HashSet<int> MatchedProtoIds => _matched_proto_ids;
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
public int RestoreUid => _restore_uid;

public void new_restore_ops(IEnumerable<Operation> new_ops)
{
_restore_ops.AddRange(new_ops);
// skip the callback.
}

public List<Operation> restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null)
{
throw new NotImplementedException();
}
}

public abstract class LoadStatus
{
public abstract void assert_consumed();
public abstract void assert_existing_objects_matched();
public abstract void assert_nontrivial_match();
public abstract void run_restore_ops(Session? session = null);
public abstract void initialize_or_restore(Session? session = null);
public virtual LoadStatus expect_partial()
{
return this;
}
}

public class InitializationOnlyStatus: LoadStatus
{
private int _restore_uid;
private ObjectGraphView _object_graph_view;
private Trackable _root;
public InitializationOnlyStatus(ObjectGraphView object_graph_view, int restore_uid)
{
_restore_uid = restore_uid;
_object_graph_view = object_graph_view;
_root = object_graph_view.Root;
}
public override void assert_consumed()
{
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
}
public override void assert_existing_objects_matched()
{
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
}
public override void assert_nontrivial_match()
{
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
}
public override void run_restore_ops(Session? session = null)
{
throw new AssertionError("No checkpoint specified, so no restore ops are available "
+ "(save_path=None to Saver.restore).");
}
public override void initialize_or_restore(Session? session = null)
{
if (tf.Context.executing_eagerly())
{
return;
}
if(session is null)
{
session = new Session();
}
var trackable_objects = CheckPointUtils.list_objects(_object_graph_view);
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
}

public class CheckpointLoadStatus
{
public CheckpointLoadStatus()
{

}
}
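
To show how the LoadStatus hierarchy above is intended to be consumed, a hedged sketch (the `root` variable is an assumption, and restore() with a non-null path still ends in NotImplementedException at this stage of the port):

// Assumed setup: a TrackableSaver over some Trackable `root`.
var saver = new TrackableSaver(new ObjectGraphView(root));

// Passing a null save_path restores nothing and yields an InitializationOnlyStatus,
// whose assert_* methods throw AssertionError by design.
LoadStatus status = saver.restore(null);
status.initialize_or_restore();   // returns immediately when executing eagerly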

src/TensorFlowNET.Core/Checkpoint/restore.cs (+80, -0)

@@ -0,0 +1,80 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;

internal class CheckpointPosition
{
private CheckpointRestoreCoordinator _checkpoint;
private int _proto_id;
private bool _skip_restore;
public CheckpointPosition(CheckpointRestoreCoordinator checkpoint, int proto_id)
{
_checkpoint = checkpoint;
_proto_id = proto_id;
_skip_restore = false;
}

public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];

public void restore(Trackable trackable)
{
using (ops.init_scope())
{
if (bind_project(trackable))
{

}
}
}

/// <summary>
/// Set a checkpoint<->object correspondence.
/// </summary>
/// <param name="trackable"></param>
/// <returns></returns>
public bool bind_project(Trackable trackable)
{
_checkpoint.AllTrackables.Add(trackable);
_checkpoint.MatchedProtoIds.Add(_proto_id);
if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment))
{
// skip the `logging.warning`.
return false;
}
else
{
_checkpoint.ObjectByProtoId[_proto_id] = trackable;
return true;
}
}

public void gather_ops_or_named_saveables()
{
// skip the registered_saver


}

/// <summary>
/// Restore the bound Trackable and dependencies (may be deferred).
/// </summary>
private void _restore_descendants()
{
Queue<(CheckpointPosition, Trackable)> visit_queue = new();
visit_queue.Enqueue((this, this.Trackable));

}

private void _single_restore()
{
var trackable = this.Trackable;
trackable._maybe_initialize_trackable();
if(_checkpoint.RestoreUid > trackable.UpdateUid)
{

}
}
}

src/TensorFlowNET.Core/IO/gfile.cs (+12, -0)

@@ -16,8 +16,10 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.IO
{
@@ -63,5 +65,15 @@ namespace Tensorflow.IO
dirs.AddRange(Directory.GetFiles(dir));
return dirs.ToArray();
}

public string join(params string[] paths)
{
Debug.Assert(paths.Length >= 1);
if (paths[0].Substring(1).Contains("://"))
{
throw new NotImplementedException("The combination of urls has not been implemented.");
}
return Path.Combine(paths);
}
}
}
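
A quick illustration of the new join helper; a sketch only, since URL-style paths are explicitly rejected for now:

using static Tensorflow.Binding;

var path = tf.io.gfile.join("saved_model_dir", "variables", "variables.index");
// Path.Combine semantics, so the platform separator is used;
// a URL such as "gs://bucket/model" currently throws NotImplementedException.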

src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs (+1, -1)

@@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Common

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var axis = serializer.Deserialize(reader, typeof(long[]));
var axis = serializer.Deserialize(reader, typeof(int[]));
if (axis is null)
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");


src/TensorFlowNET.Core/ModelSaving/ModelSaver.cs (+1, -0)

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow.ModelSaving
{


src/TensorFlowNET.Core/Operations/resource_variable_ops.cs (+1, -1)

@@ -17,8 +17,8 @@
using System;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types;



src/TensorFlowNET.Core/Training/Saving/SavedModel/LoadOptions.cs (+23, -0)

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public record class LoadOptions
{
public bool allow_partial_checkpoint;
public string experimental_io_device;
public bool experimental_skip_checkpoint;
public VariablePolicy experimental_variable_policy;

public LoadOptions(bool allow_partial_checkpoint = false, string experimental_io_device = null,
bool experimental_skip_checkpoint = false, string experimental_variable_policy = null)
{
this.allow_partial_checkpoint = allow_partial_checkpoint;
this.experimental_io_device = experimental_io_device;
this.experimental_skip_checkpoint = experimental_skip_checkpoint;
this.experimental_variable_policy = VariablePolicy.from_obj(experimental_variable_policy);
}
}
}
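
For reference, a construction sketch; every parameter defaults to the value shown in the constructor above:

// Defaults: partial checkpoints disallowed, checkpoint loading not skipped, no variable policy.
var options = new LoadOptions();

// Explicit form, e.g. to tolerate missing checkpoint values and keep variable devices:
var partial = new LoadOptions(allow_partial_checkpoint: true,
                              experimental_variable_policy: "save_variable_devices");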

src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs (+8, -1)

@@ -1,4 +1,5 @@
using Tensorflow.Train;
using System;
using Tensorflow.Train;

namespace Tensorflow;

@@ -14,4 +15,10 @@ public class RevivedTypes
// TODO: complete the implementation.
return null;
}

public static Tuple<Trackable, Action<object, object, object>> deserialize(object proto)
{
// TODO: complete the implementation.
return null;
}
}

src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs → src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveOptions.cs (+6, -6)

@@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.ModelSaving
namespace Tensorflow
{
/// <summary>
/// Options for saving to SavedModel.
@@ -35,7 +35,7 @@ namespace Tensorflow.ModelSaving

public bool save_variable_devices()
{
return this != VariablePolicy.None;
return this != None;
}

/// <summary>
@@ -45,14 +45,14 @@ namespace Tensorflow.ModelSaving
/// <returns></returns>
public static VariablePolicy from_obj(object obj)
{
if (obj is null) return VariablePolicy.None;
if (obj is null) return None;
if (obj is VariablePolicy) return (VariablePolicy)obj;
var key = obj.ToString().ToLower();
return key switch
{
null => VariablePolicy.None,
"save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES,
"expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
null => None,
"save_variable_devices" => SAVE_VARIABLE_DEVICES,
"expand_distributed_variables" => EXPAND_DISTRIBUTED_VARIABLES,
_ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
};
}
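
A small sketch of the string-to-policy mapping implemented by from_obj (matching is case-insensitive because the key is lower-cased first):

var p1 = VariablePolicy.from_obj(null);                           // None
var p2 = VariablePolicy.from_obj("save_variable_devices");        // SAVE_VARIABLE_DEVICES
var p3 = VariablePolicy.from_obj("EXPAND_DISTRIBUTED_VARIABLES"); // EXPAND_DISTRIBUTED_VARIABLES
// Any other value throws ValueError("Received invalid VariablePolicy value: ...").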

src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs (+0, -1)

@@ -5,7 +5,6 @@ using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;


src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs (+601, -0)

@@ -0,0 +1,601 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using Tensorflow.Checkpoint;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using System.Runtime.CompilerServices;
using Tensorflow.Variables;

namespace Tensorflow
{
/// <summary>
/// Helper class to load an object-based SavedModel.
/// </summary>
public partial class Loader
{
private pbc::RepeatedField<global::Tensorflow.AssetFileDef> _asset_file_def;
private Dictionary<string, pbc::MapField<string, AttrValue>> _operation_attributes;
private SavedObjectGraph _proto;
private string _export_dir;
private CheckpointOptions _checkpoint_options;
private LoadOptions _save_options;
private IDictionary<string, (Trackable, Action<object, object, object>)> _node_filters;
private Dictionary<string, int>? _node_path_to_id;
private List<int>? _filtered_nodes;
private List<int> _ordered_node_ids;
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes;
private List<Trackable> _nodes;
private Dictionary<int, Action<object, object, object>> _node_setters;
public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto, string export_dir,
CheckpointOptions ckpt_options, LoadOptions save_options, IDictionary<string, (Trackable, Action<object, object, object>)> filters)
{
var meta_graph = saved_model_proto.MetaGraphs[0];
_asset_file_def = meta_graph.AssetFileDef;
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
_proto = object_graph_proto;
_export_dir = export_dir;
// TODO: `this._concrete_functions` and `this._restored_concrete_functions`
_checkpoint_options = ckpt_options;
_save_options = save_options;

// TODO: `this._pretty_printer`

_node_filters = filters;
_node_path_to_id = _convert_node_paths_to_ints();
_loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
foreach(var filter in filters)
{
_loaded_nodes[_node_path_to_id[filter.Key]] = filter.Value;
}

_filtered_nodes = _retrieve_all_filtered_nodes();

_ordered_node_ids = _generate_ordered_node_ids();

_load_all();


if (!save_options.experimental_skip_checkpoint)
{
// TODO: implement it.
}
foreach(var node in _nodes)
{
// skip the process of `CapturableResource`.
}
}

/// <summary>
/// Maps all string node paths in node_filters to the int node ids.
/// </summary>
/// <returns></returns>
private Dictionary<string, int>? _convert_node_paths_to_ints()
{
if( _node_filters is null)
{
return null;
}
Dictionary<string, int> path_to_int = new();
foreach(var node_id in _node_filters.Keys)
{
int int_node_id;
var node_path = node_id.Split('.');
if (node_path[0] != "root")
{
throw new ValueError($"When passing string identifiers to node_filters, the first name" +
$" must be root. Received {node_path[0]}.");
}
int_node_id = 0;
for(int i = 0; i < node_path.Length - 1; i++)
{
var name = node_path[i + 1];
int_node_id = _find_node_child(int_node_id, name, String.Join(".", node_path.Take(i + 1)));
}
path_to_int[node_id] = int_node_id;
}
return path_to_int;
}

private int _find_node_child(int node_id, string child_name, string path)
{
foreach(var refer in _proto.Nodes[node_id].Children)
{
if(refer.LocalName == child_name)
{
return refer.NodeId;
}
}
throw new ValueError($"Unable to find node {path}.");
}

private List<int>? _retrieve_all_filtered_nodes()
{
if(_node_filters is null)
{
return null;
}

HashSet<int> all_filtered_nodes = new();
Queue<string> nodes_to_visit = new Queue<string>(_node_filters.Keys);

while(nodes_to_visit.Count > 0)
{
var node_path = nodes_to_visit.Dequeue();
var node_id = _node_path_to_id[node_path];
if (all_filtered_nodes.Contains(node_id))
{
continue;
}
all_filtered_nodes.Add(node_id);
Trackable node = null;
Action<object, object, object> setter = null;
if(_loaded_nodes.TryGetValue(node_id, out var res))
{
(node, setter) = res;
}
if(node is not null)
{
node._maybe_initialize_trackable();
}

foreach(var refer in _proto.Nodes[node_id].Children)
{
Trackable children_object = null;
if(_loaded_nodes.TryGetValue(refer.NodeId, out var result))
{
children_object = result.Item1;
}
// See if node already tracks the child reference, in which case add the child to the loaded_nodes dict.
if(children_object is null && node is not null)
{
children_object = node._lookup_dependency(refer.LocalName);
if(children_object is TrackableDataStructure)
{
// TODO: set setter as lambda.

_loaded_nodes[refer.NodeId] = (children_object, setter);
}
}
string child_path = $"{node_path}.{refer.LocalName}";
_node_path_to_id[child_path] = refer.NodeId;
nodes_to_visit.Enqueue(child_path);
}
}

if (all_filtered_nodes.Contains(0))
{
return null;
}
return all_filtered_nodes.ToList();
}

/// <summary>
/// Orders the node ids so that dependencies appear first.
/// </summary>
/// <returns></returns>
private List<int> _generate_ordered_node_ids()
{
List<int> unordered_ids;
if(_filtered_nodes is null)
{
unordered_ids = Enumerable.Range(0, _proto.Nodes.Count).ToList();
}
else
{
unordered_ids = new List<int>(_filtered_nodes);
}

Dictionary<int, List<int>> dependency_map = new();
foreach(var node_id in unordered_ids)
{
var deps = dependency_map.SetDefault(node_id, new List<int>());
if (_loaded_nodes.ContainsKey(node_id))
{
continue;
}
var proto = _proto.Nodes[node_id];
foreach(var dep in _get_node_dependencies(proto).Values.Distinct())
{
deps.Add(dep);
if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
{
// TODO: add info with `_pretty_printer`.
throw new ValueError($"Unable to partially load SavedModel since the specified filter " +
$"does not include all required objects for loading (e.g. " +
$"variables used in functions or deserialization dependencies). " +
$"Please include this path in the filter: {dep}");
}
}
int? prev_slot = null;
foreach(var slot_variable_proto in proto.SlotVariables)
{
var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
// The optimizer and original variable must be created before the slot
// variable, since the slot variable is generated using the Optimizer's
// add_slot API.
var slot_deps = dependency_map[slot_variable_node_id];
slot_deps.Add(node_id);
slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);

if(prev_slot is not null)
{
slot_deps.Add(prev_slot.Value);
}
prev_slot = slot_variable_node_id;
}
}
try
{
return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
}
catch (TrackableUtils.CyclicDependencyError ex)
{
throw new ValueError("Encountered a cycle in the deserialization dependencies" +
"in the SavedModel. This is extremely unexpected, please" +
"file a bug and make sure you are not manually modifying the SavedModel.");
}
}

/// <summary>
/// Returns a dictionary of all dependencies of an object.
/// </summary>
/// <param name="proto"></param>
/// <returns></returns>
private Dictionary<Maybe<string, int>, int> _get_node_dependencies(SavedObject proto)
{
Dictionary<Maybe<string, int>, int> dependencies = new();
foreach(var refer in proto.Dependencies)
{
dependencies[refer.LocalName] = refer.NodeId;
}
if(proto.KindCase == SavedObject.KindOneofCase.Function)
{
var concrete_functions = proto.Function.ConcreteFunctions;
foreach(var fn_name in concrete_functions)
{
foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
{
dependencies[bound_input] = bound_input;
}
}
}
else if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction)
{
var fn_name = proto.BareConcreteFunction.ConcreteFunctionName;
foreach(var bound_input in _proto.ConcreteFunctions[fn_name].BoundInputs)
{
dependencies[bound_input] = bound_input;
}
}
else if(proto.KindCase == SavedObject.KindOneofCase.Resource)
{
foreach(var child in proto.Children)
{
if(child.LocalName == "_create_resource")
{
dependencies["_create_resource"] = child.NodeId;
}
}
}
return dependencies;
}

/// <summary>
/// Loads all nodes and functions from the SavedModel and their edges.
/// </summary>
private void _load_all()
{
_load_nodes();
_load_edges();

_setup_remaining_functions();
_load_checkpoint_save_and_restore_functions();
}

/// <summary>
/// Restores the checkpoint-related save/restore functions to all nodes.
/// </summary>
private void _load_checkpoint_save_and_restore_functions()
{
foreach(var (node_id, proto) in _iter_all_nodes())
{
var node = get(node_id);
if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{
// Restore Trackable serialize- and restore-from-tensor functions.
Debug.Assert(proto.SaveableObjects.Count == 1);
var saveable_object_proto = proto.SaveableObjects.Values.First();
var save_fn_id = saveable_object_proto.SaveFunction;
var restore_fn_id = saveable_object_proto.RestoreFunction;

throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
else
{
// Restore legacy SaveableObject functions.
Dictionary<string, (Trackable, Trackable)> saveable_fn_by_name = new();
foreach(var item in proto.SaveableObjects)
{
var name = item.Key;
var saveable_object_proto = item.Value;
var save_fn_id = saveable_object_proto.SaveFunction;
var restore_fn_id = saveable_object_proto.RestoreFunction;
saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id));
}
node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
}
}
}

/// <summary>
/// Load all saved objects.
/// </summary>
private void _load_nodes()
{
// `nodes` maps from node ids to recreated objects
// `node_setters` maps from node ids to setter functions
// (same signature as setattr) for setting children.
var (nodes, node_setters) = _initialize_loaded_nodes();

Dictionary<int, (int, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference)>
slot_variable_node_ids = new();

foreach(var (node_id, proto) in _iter_all_nodes())
{
foreach(var slot_variable_proto in proto.SlotVariables)
{
var slot_variable_node_id = slot_variable_proto.SlotVariableNodeId;
slot_variable_node_ids[slot_variable_node_id] = (node_id, slot_variable_proto);
}
}

// Re-create everything.
foreach (var (node_id, proto) in _iter_all_nodes())
{
if (nodes.ContainsKey(node_id))
{
continue;
}
else if (slot_variable_node_ids.ContainsKey(node_id))
{
// Use the public Optimizer interface when creating slot variables.
var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
var optimizer_object = nodes[optimizer_node_id];
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];

// TODO: implement it.
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}
else
{
var (node, setter) = _recreate(proto, node_id, nodes);
nodes[node_id] = node;
node_setters[node_id] = setter;
}
}

if (!nodes.ContainsKey(0))
{
nodes[0] = _recreate_base_user_object().Item1;
}
_nodes = new List<Trackable>();
for(int i = 0; i < _proto.Nodes.Count; i++)
{
_nodes.Add(nodes[i]);
}
_node_setters = node_setters;
}

/// <summary>
/// Load state from checkpoint into the deserialized objects.
/// </summary>
private void _restore_checkpoint()
{
var variables_path = SavedModelUtils.get_variables_dir(_export_dir);
var saver = new TrackableSaver(new ObjectGraphView(get(0)));
tf.device("CPU");
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
if (_save_options.allow_partial_checkpoint)
{

}
}

/// <summary>
/// Adds edges from objects to other objects and functions.
/// </summary>
private void _load_edges()
{
foreach(var (node_id, object_proto) in _iter_all_nodes())
{
_add_object_graph_edges(object_proto, node_id);
}

if(_filtered_nodes is not null && _filtered_nodes.Contains(0))
{
var root = get(0);
foreach(var node_path in _node_filters.Keys)
{
var loaded_node = _nodes[_node_path_to_id[node_path]];

var path = node_path.Split('.');
var current_node = root;
foreach(var name in path.Skip(1).Take(path.Length - 2))
{
// `hasattr` and `setattr` are used here
throw new NotImplementedException();
}
// `hasattr` and `setattr` are used here
throw new NotImplementedException();
}
}
}

private void _setup_remaining_functions()
{
// TODO: implement it with concrete functions.
}

public Trackable get(int node_id)
{
return _nodes[node_id];
}

public Trackable get(string node_id)
{
return get(_node_path_to_id[node_id]);
}

/// <summary>
/// Adds edges from an object to its children.
/// </summary>
/// <param name="proto"></param>
/// <param name="node_id"></param>
private void _add_object_graph_edges(SavedObject proto, int node_id)
{
var obj = _nodes[node_id];
var setter = _node_setters[node_id];

foreach(var refer in proto.Children)
{
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
// skip the process of "__call__"
}
}

private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{
Dictionary<int, Trackable> nodes = new();
Dictionary<int, Action<object, object, object>> node_setters = new();
foreach(var item in _loaded_nodes)
{
var node_id = item.Key;
var (node, setter) = item.Value;
nodes[node_id] = node;
node_setters[node_id] = setter;
}
return (nodes, node_setters);
}

private IEnumerable<(int, SavedObject)> _iter_all_nodes()
{
foreach(var node_id in _ordered_node_ids)
{
yield return (node_id, _proto.Nodes[node_id]);
}
}

private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
{
// skip the registered classes.

Dictionary<Maybe<string, int>, Trackable> dependencies = new();
foreach(var item in _get_node_dependencies(proto))
{
dependencies[item.Key] = nodes[item.Value];
}

return _recreate_default(proto, node_id, dependencies);
}

/// <summary>
/// Creates a Trackable object from a SavedObject protocol buffer.
/// </summary>
/// <param name="proto"></param>
/// <param name="node_id"></param>
/// <param name="dependencies"></param>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<Maybe<string, int>, Trackable> dependencies)
{
return proto.KindCase switch
{
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
SavedObject.KindOneofCase.Function => throw new NotImplementedException(),
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException()
};
}

private (Trackable, Action<object, object, object>) _recreate_user_object(SavedUserObject? proto, int node_id)
{
// skip the check of proto identifier because of lack of property.

var looked_up = RevivedTypes.deserialize(proto);
if(looked_up is null)
{
return _recreate_base_user_object(proto, node_id);
}
return (looked_up.Item1, looked_up.Item2);
}

private (Trackable, Action<object, object, object>) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null)
{
return (new _UserObject(), setattr);
}

private (BaseResourceVariable, Action<object, object, object>) _recreate_variable(SavedVariable proto)
{
string name = proto.Name;
string dbg_name = !string.IsNullOrEmpty(name) ? name : "<variable loaded from saved model>";

// TODO(Rinne): `validate_synchronization_aggregation_trainable`

var (synchronization, aggregation, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable(
proto.Synchronization, proto.Aggregation, proto.Trainable, dbg_name);

var saved_device = proto.Device;
var load_with_device = _save_options.experimental_variable_policy.save_variable_devices() && !string.IsNullOrEmpty(saved_device);

if (load_with_device)
{
tf.device(saved_device);
return (new UninitializedVariable(
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
dtype: (TF_DataType)proto.Dtype,
name: name,
trainable: trainable,
aggregation: aggregation
), setattr);
}
else
{
return (new UninitializedVariable(
shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
dtype: (TF_DataType)proto.Dtype,
name: name,
trainable: trainable,
aggregation: aggregation
), setattr);
}
}

// TODO: move this to a common class.
public static Action<object, object, object> setattr = (x, y, z) =>
{
Debug.Assert(y is string);
var properties = x.GetType().GetProperties();
foreach(var p in properties)
{
if((string)y == p.Name)
{
p.SetValue(x, z);
return;
}
}
// TODO(Rinne): check if the property has been set successfully.
//throw new ValueError($"Cannot find the property {y} of {x}.");
};

public class _UserObject: AutoTrackable
{

}
}
}
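
To make the edge wiring above concrete, a hedged sketch of what the reflection-based setattr fallback does; the Foo type is purely illustrative and not part of this diff:

using Tensorflow;
using Tensorflow.Train;

class Foo : AutoTrackable
{
    public Trackable Child { get; set; }
}

var parent = new Foo();
var child = new Loader._UserObject();
// Resolved by property name at runtime, equivalent to `parent.Child = child`;
// unknown property names are currently ignored rather than throwing.
Loader.setattr(parent, "Child", child);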

src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.static.cs (+122, -0)

@@ -0,0 +1,122 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow.Checkpoint;
using Tensorflow.Operations;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow
{
public partial class Loader
{
public static SavedModel parse_saved_model(string export_dir)
{
var path_to_pbtxt = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PBTXT);
var path_to_pb = tf.io.gfile.join(export_dir, Constants.SAVED_MODEL_FILENAME_PB);

SavedModel saved_model = new SavedModel();
if (File.Exists(path_to_pb))
{
byte[] file_content;
using(var f = new FileStream(path_to_pb, FileMode.Open, FileAccess.Read))
{
file_content = new byte[f.Length];
Debug.Assert(f.Length <= int.MaxValue);
f.Read(file_content, 0, (int)f.Length);
}
// TODO: change to stream mode.
saved_model.MergeFrom(file_content);
return saved_model;
}
else if (File.Exists(path_to_pbtxt))
{
throw new NotImplementedException();
}
else
{
throw new IOException($"SavedModel file does not exist at: {export_dir}{Path.PathSeparator}" +
$"{{{Constants.SAVED_MODEL_FILENAME_PBTXT}|{Constants.SAVED_MODEL_FILENAME_PB}}}");
}
}

// TODO: revise the type of `tags`
public static Trackable load(string export_dir, object? tags = null, LoadOptions? options = null)
{
return load_partial(export_dir, null, tags, options)["root"];
}

public static IDictionary<string, Trackable> load_partial(string export_dir, IDictionary<string, (Trackable, Action<object, object, object>)>? filters, object? tags = null, LoadOptions? options = null)
{
if (options is null)
{
options = new LoadOptions();
}
if (tags is not null)
{
throw new NotImplementedException();
}
var (saved_model_proto, debug_info) = Loader.parse_saved_model_with_debug_info(export_dir);

Trackable root = null;
Loader loader = null;
if (saved_model_proto.MetaGraphs.Count == 1 && saved_model_proto.MetaGraphs[0].ObjectGraphDef is not null)
{
// skip python code: `metrics.IncrementReadApi(_LOAD_V2_LABEL)`
var meta_graph_def = saved_model_proto.MetaGraphs[0];
if (!BitConverter.IsLittleEndian)
{
SavedModelUtils.swap_function_tensor_content(meta_graph_def);
}

var object_graph_proto = meta_graph_def.ObjectGraphDef;
var ckpt_options = new CheckpointOptions(options.experimental_io_device);
tf_with(ops.init_scope(), x =>
{
loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
root = loader.get(0);
// skip the assignment of `graph_debug_info`.
});
// skip the assignment of `tensorflow_version`
// skip the assignment of `tensorflow_git_version`
// skip the process of `metrics`.
}
else
{
if(filters is not null && filters.Count > 0)
{
throw new ValueError("SavedModels saved from Tensorflow 1.x or Estimator (any"
+ " version) cannot be loaded with node filters.");
}
tf_with(ops.init_scope(), x =>
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
});
}
if(filters != null && filters.Count > 0)
{
return filters.Keys.ToDictionary(x => x, x => loader.get(x));
}
else
{
var res = new Dictionary<string, Trackable>();
res["root"] = root;
return res;
}
}

public static (SavedModel, object?) parse_saved_model_with_debug_info(string export_dir)
{
var saved_model = parse_saved_model(export_dir);

// TODO: implement debug info.

return (saved_model, null);
}

}
}
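
End to end, the static entry points are meant to be used roughly as follows (a sketch assuming the export directory contains a saved_model.pb produced by the existing save path):

// Returns the root Trackable, i.e. the "root" entry of load_partial's result.
Trackable root = Loader.load(@"path/to/saved_model");

// Equivalent long form with explicit options and no node filters.
var loaded = Loader.load_partial(@"path/to/saved_model", null, tags: null, options: new LoadOptions());
Trackable same_root = loaded["root"];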

src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs (+0, -1)

@@ -6,7 +6,6 @@ using System.Text;
using Google.Protobuf;
using Tensorflow.Checkpoint;
using Tensorflow.Functions;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;


src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs (+0, -1)

@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.ModelSaving;

namespace Tensorflow.Training.Saving.SavedModel
{


src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs (+15, -0)

@@ -333,6 +333,21 @@ namespace Tensorflow
return restored_ops;
};
}

/// <summary>
/// Returns a dict of SaveableObject factories generated from loaded fns.
/// </summary>
/// <param name="saveable_fn_by_name"></param>
/// <param name="temp_session"></param>
public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> recreate_saveable_objects(
IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{
if (saveable_fn_by_name.Count > 0)
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
return new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
}
}

public class SaveableCompatibilityConverter: Trackable


src/TensorFlowNET.Core/Training/Trackable.cs (+14, -1)

@@ -20,8 +20,8 @@ using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving;
using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using static Tensorflow.Binding;

namespace Tensorflow.Train
@@ -71,6 +71,17 @@ namespace Tensorflow.Train
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; }
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; }
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; }
public IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> SelfSaveableObjectFactories
{
get
{
return _self_saveable_object_factories;
}
set
{
_self_saveable_object_factories = value;
}
}

/// <summary>
/// Restore-on-create for a variable be saved with this `Checkpointable`.
@@ -259,4 +270,6 @@ namespace Tensorflow.Train
}

public record class TrackableReference(string Name, Trackable Refer);

public record class SlotVariableRestoration(int OptimizerId, int SlotVariableId, string SlotName);
}

src/TensorFlowNET.Core/Training/TrackableUtils.cs (+1, -0)

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Exceptions;
using Tensorflow.Train;



src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs (+6, -2)

@@ -5,9 +5,9 @@ using Tensorflow.Variables;
using Tensorflow.Train;
using static Tensorflow.Binding;
using System.Collections.Generic;
using Tensorflow.ModelSaving;
using System.Diagnostics;
using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow
{
@@ -19,7 +19,11 @@ namespace Tensorflow
protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;
protected string _handle_name;
protected string handle_name => _handle_name;
public string handle_name
{
get { return _handle_name; }
set { _handle_name = value; }
}

protected string _unique_id;
public string UniqueId => _unique_id;


src/TensorFlowNET.Core/Variables/ResourceVariable.cs (+18, -0)

@@ -238,5 +238,23 @@ namespace Tensorflow
{
return _graph_element.eval(session);
}

public static (VariableSynchronization, VariableAggregation, bool) validate_synchronization_aggregation_trainable(
VariableSynchronization? synchronization, VariableAggregation? aggregation, bool? trainable, string name)
{
if(aggregation is null)
{
aggregation = VariableAggregation.None;
}
if(synchronization is null)
{
synchronization = VariableSynchronization.Auto;
}
if (trainable is null)
{
trainable = synchronization != VariableSynchronization.OnRead;
}
return (synchronization.Value, aggregation.Value, trainable.Value);
}
}
}
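
The defaulting rules can be read off directly from the helper; a short sketch:

// All-null inputs resolve to (Auto, None, trainable: true).
var (sync, agg, trainable) = ResourceVariable.validate_synchronization_aggregation_trainable(
    null, null, null, "my_variable");

// With synchronization == OnRead and trainable left null, trainable defaults to false.
var (s2, a2, t2) = ResourceVariable.validate_synchronization_aggregation_trainable(
    VariableSynchronization.OnRead, null, null, "metric_accumulator");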

src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs (+1, -1)

@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="config"></param>
/// <returns></returns>
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
{
// Layer instances created during the graph reconstruction process.
var created_layers = new Dictionary<string, ILayer>();


src/TensorFlowNET.Keras/Engine/Functional.cs (+5, -0)

@@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Engine
Inputs = inputs,
Outputs = outputs
})
{
Initialize(inputs, outputs, name);
}

internal void Initialize(Tensors inputs, Tensors outputs, string name = null)
{
_input_layers = new List<ILayer>();
_output_layers = new List<ILayer>();


src/TensorFlowNET.Keras/Engine/Layer.cs (+19, -1)

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -75,7 +76,17 @@ namespace Tensorflow.Keras.Engine
public int Id => id;
protected string name;
protected string base_name;
public string Name => name;
public string Name
{
get
{
return name;
}
set
{
name = value;
}
}

protected bool computePreviousMask;
protected List<Operation> updates;
@@ -89,6 +100,8 @@ namespace Tensorflow.Keras.Engine
List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes;

public JObject SerializedAttributes { get; set; }

ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
public CallContext CallContext => callContext.Value;
public Tensor[] input
@@ -117,6 +130,11 @@ namespace Tensorflow.Keras.Engine
protected List<ILayer> _self_tracked_trackables;

public Layer(LayerArgs args)
{
Initialize(args);
}

internal virtual void Initialize(LayerArgs args)
{
this.args = args;
// A stateful layer is a layer whose updates are run during inference too,


src/TensorFlowNET.Keras/Engine/Model.Save.cs (+1, -1)

@@ -33,7 +33,7 @@ namespace Tensorflow.Keras.Engine
{
using (SharedObjectSavingScope.Enter())
{
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
KerasSavedModelUtils.save_model(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
}
}
}
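
For completeness, the public call path that reaches save_model stays the same; a usage sketch assuming a built `model` and the default optional arguments:

// Writes a TensorFlow SavedModel (saved_model.pb plus the variables/ directory).
model.save("my_saved_model");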


src/TensorFlowNET.Keras/Engine/Model.cs (+8, -0)

@@ -36,6 +36,8 @@ namespace Tensorflow.Keras.Engine
IVariableV1 _predict_counter;
bool _base_model_initialized;
bool stop_training;

public bool IsGraphNetwork => _is_graph_network;
public OptimizerV2 Optimizer
{
@@ -49,6 +51,12 @@ namespace Tensorflow.Keras.Engine
_init_batch_counters();
}

internal override void Initialize(LayerArgs args)
{
_init_batch_counters();
base.Initialize(args);
}

void _configure_steps_per_execution(int steps_per_execution)
{
_steps_per_execution = tf.Variable(steps_per_execution,


src/TensorFlowNET.Keras/Engine/Sequential.cs (+10, -5)

@@ -44,8 +44,6 @@ namespace Tensorflow.Keras.Engine
: base(args.Inputs, args.Outputs, name: args.Name)
{
this.args = args;
if (args.Layers == null)
args.Layers = new List<ILayer>();
// SupportsMasking = true;
_compute_output_and_mask_jointly = true;
_auto_track_sub_layers = false;
@@ -54,10 +52,17 @@ namespace Tensorflow.Keras.Engine
_created_nodes = new List<INode>();

// Add to the model any layers passed to the constructor.
if (args.Layers != null)
if (args.Layers is not null)
{
foreach (var layer in args.Layers)
add(layer);
InitLayers(args.Layers);
}
}

public void InitLayers(IEnumerable<ILayer> layers)
{
foreach(var layer in layers)
{
add(layer);
}
}
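
A hedged sketch of the new InitLayers path; the keras.Sequential and keras.layers factory calls are assumed from the existing public API surface and are not part of this diff:

using System.Collections.Generic;
using Tensorflow.Keras;
using static Tensorflow.KerasApi;

var model = keras.Sequential();
// Adds each layer in order through the existing add(layer) logic.
model.InitLayers(new List<ILayer>
{
    keras.layers.Dense(10),
});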



src/TensorFlowNET.Keras/Layers/Activation/ELU.cs (+1, -2)

@@ -25,8 +25,7 @@ namespace Tensorflow.Keras.Layers {
{
throw new ValueError("Alpha must be a number greater than 0.");
}
_buildInputShape = input_shape;
built = true;
base.build(input_shape);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs (+1, -2)

@@ -14,8 +14,7 @@ namespace Tensorflow.Keras.Layers {
}
public override void build(Shape input_shape)
{
_buildInputShape = input_shape;
built = true;
base.build(input_shape);
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{


src/TensorFlowNET.Keras/Layers/Activation/SELU.cs (+1, -2)

@@ -19,8 +19,7 @@ namespace Tensorflow.Keras.Layers {
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}
_buildInputShape = input_shape;
built = true;
base.build(input_shape);
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor output = inputs;


src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs (+521, -33)

@@ -1,12 +1,24 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;
using Tensorflow.Training;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -14,17 +26,29 @@ namespace Tensorflow.Keras.Saving
{
public class KerasObjectLoader
{
SavedMetadata _metadata;
SavedObjectGraph _proto;
Dictionary<int, string> _node_paths = new Dictionary<int, string>();
Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>();
List<int> _traversed_nodes_from_config = new List<int>();
private static readonly IDictionary<string, Trackable> PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects;
private SavedMetadata _metadata;
private SavedObjectGraph _proto;
private Dictionary<int, string> _node_paths = new Dictionary<int, string>();
private Dictionary<int, (Model, int[])> model_layer_ids_dependencies = new Dictionary<int, (Model, int[])>();
private Dictionary<int, (Model, Layer[])> model_layer_dependencies = new Dictionary<int, (Model, Layer[])>();
private List<int> _traversed_nodes_from_config = new List<int>();
private Dictionary<int, (Trackable, Action<object, object, object>)> loaded_nodes;
private List<int> _models_to_reconstruct;
public Dictionary<int, (Trackable, Action<object, object, object>)> LoadedNodes => loaded_nodes;

static KerasObjectLoader()
{
PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null;
}

public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
{
_metadata = metadata;
_proto = object_graph_def;
_metadata.Nodes.ToList().ForEach(x => _node_paths[x.NodeId] = x.NodePath);
_models_to_reconstruct = new List<int>();
loaded_nodes = new Dictionary<int, (Trackable, Action<object, object, object>)>();
}

/// <summary>
@@ -42,15 +66,251 @@ namespace Tensorflow.Keras.Saving
continue;
}

_load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata);
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata);
}
foreach(var node_metadata in metric_list)
{
try
{
loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier,
node_metadata.Metadata);
}
catch(ValueError)
{
if (compile)
{
throw;
}
// TODO: add logging.warning.
}
}
}

public string get_path(int node_id)
{
return _node_paths[node_id];
}

/// <summary>
/// Finish setting up Keras objects.
///
/// This function is executed after all objects and functions have been created.
/// Call functions and losses are attached to each layer, and once all layers
/// have been fully set up, graph networks are initialized.
///
/// Subclassed models that are revived from the SavedModel are treated like
/// layers, and have their call/loss functions attached here.
/// </summary>
public void finalize_objects()
{
List<Layer> layers_revived_from_config = new();
List<Layer> layers_revived_from_saved_model = new();
foreach(var item in loaded_nodes)
{
var node_id = item.Key;
var node = item.Value.Item1;
if(node is not Layer || model_layer_ids_dependencies.ContainsKey(node_id))
{
continue;
}

_unblock_model_reconstruction(node_id, node as Layer);

if(node is InputLayer or Metric)
{
continue;
}

// TODO: deal with `RevivedLayer` and `RevivedInputLayer`.
layers_revived_from_config.Add(node as Layer);
}

_finalize_saved_model_layers(layers_revived_from_saved_model);
_finalize_config_layers(layers_revived_from_config);

_reconstruct_all_models();
}

private void _reconstruct_all_models()
{
HashSet<int> all_initialized_models = new();
for(int i = _models_to_reconstruct.Count - 1; i >= 0; i--)
{
int model_id = _models_to_reconstruct[i];
all_initialized_models.Add(model_id);
var (model, layers) = model_layer_dependencies[model_id];
_reconstruct_model(model_id, model, layers.ToList());
_finalize_config_layers(new List<Layer>() { model });
}

Debug.Assert(all_initialized_models.SequenceEqual(model_layer_dependencies.Keys));
}

private void _reconstruct_model(int model_id, Model model, List<Layer> layers)
{
var config = JsonConvert.DeserializeObject<JObject>(_metadata.Nodes[model_id].Metadata)["config"];

if(model.input is not null && model.input.Length > 0)
{

}
else if(model is Sequential s)
{
if(layers is null || layers.Count == 0 || layers[0] is not InputLayer)
{
if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
{
layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
}
else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
{
// TODO: implement it
}
}
// `model.__init__(layers, config["name"])`
s.InitLayers(layers);
s.Name = config["name"].ToObject<string>();
if(s.input is null || s.input.Length == 0)
{
var first_layer = _get_child_layer_node_ids(model_id)[0];
var input_specs = _infer_inputs(first_layer);
var input_shapes = _infer_inputs(first_layer, true);
// `model._set_inputs(input_specs)`

// skip the check of input_specs is Dictionary
if (!s.Built)
{
s.build(input_shapes);
}
}
}
else
{
// skip the parameter `created_layers`.
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>());
// skip the `model.__init__`
(model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
(model as Functional).connect_ancillary_layers(created_layers);
}

_set_network_attributes_from_metadata(model);
_unblock_model_reconstruction(model_id, model);
}

void _load_layer(int node_id, string identifier, string metadata_json)
private void _set_network_attributes_from_metadata(Model revived_object)
{
// TODO: implement it.
}

/// <summary>
/// Runs the final steps of loading Keras Layers from config.
/// </summary>
/// <param name="layers"></param>
private void _finalize_config_layers(List<Layer> layers)
{
foreach(var layer in layers)
{
if (_is_graph_network(layer))
{
_restore_layer_unconditional_losses(layer);
}
_restore_layer_activation_loss(layer);
_restore_layer_metrics(layer);

// TODO(Rinne): deal with RNN.
}
}

/// <summary>
/// Runs the final steps of loading Keras Layers from SavedModel.
/// </summary>
/// <param name="layers"></param>
private void _finalize_saved_model_layers(List<Layer> layers)
{
foreach(var layer in layers)
{
// TODO(Rinne): deal with `RevivedNetwork`.
_restore_layer_unconditional_losses(layer);
_restore_layer_activation_loss(layer);
_restore_layer_metrics(layer);
}
}

private void _restore_layer_unconditional_losses(Layer layer)
{
// TODO(Rinne): implement it.
}

private void _restore_layer_activation_loss(Layer layer)
{
// TODO(Rinne): implement it.
}

private void _restore_layer_metrics(Layer layer)
{
// TODO(Rinne): implement it.
}

/// <summary>
/// Removes layer from blocking model reconstruction.
/// </summary>
/// <param name="layer_id"></param>
/// <param name="layer"></param>
private void _unblock_model_reconstruction(int layer_id, Layer layer)
{
foreach(var dependency in model_layer_ids_dependencies)
{
var layer_ids = dependency.Value.Item2;
var layers = model_layer_dependencies.SetDefault(dependency.Key,
(dependency.Value.Item1, new Layer[dependency.Value.Item2.Length])).Item2;
if (!layer_ids.Contains(layer_id))
{
continue;
}
layers[Array.IndexOf(layer_ids, layer_id)] = layer;
if (layers.All(x => x is not null))
{
_models_to_reconstruct.Add(dependency.Key);
}
}
}

private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
_revive_from_config(identifier, metadata, node_id);

if (loaded_nodes.ContainsKey(node_id))
{
var (node, setter) = loaded_nodes[node_id];

_maybe_add_serialized_attributes(node as Layer, metadata);
var config = metadata.Config;
if(_is_graph_network(node as Layer) && generic_utils.validate_config(config))
{
Debug.Assert(node is Model);
var child_nodes = _get_child_layer_node_ids(node_id);
model_layer_ids_dependencies[node_id] = (node as Model, child_nodes);
if(child_nodes is null || child_nodes.Length == 0)
{
_models_to_reconstruct.Add(node_id);
}
}
return (node, setter);
}
else
{
var (obj, setter) = _revive_from_config(identifier, metadata, node_id);
if (obj is null)
{
(obj, setter) = _revive_custom_object(identifier, metadata);
}
Debug.Assert(obj is Layer);
_maybe_add_serialized_attributes(obj as Layer, metadata);
return (obj, setter);
}
}

/// <summary>
@@ -59,11 +319,32 @@ namespace Tensorflow.Keras.Saving
/// <param name="identifier"></param>
/// <param name="metadata"></param>
/// <param name="node_id"></param>
void _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
private (Trackable, Action<object, object, object>) _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
{
var obj = _revive_graph_network(identifier, metadata, node_id);
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
Trackable obj;
if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER)
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}
else
{
obj = _revive_graph_network(identifier, metadata, node_id);
obj = obj ?? _revive_layer_or_model_from_config(metadata, node_id);
}

if(obj is null)
{
return (null, null);
}
var setter = _config_node_setter(_revive_setter);
_add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id);
return (obj, setter);
}

private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata)
{
// TODO: implement it.
throw new NotImplementedException();
}

Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id)
@@ -71,6 +352,12 @@ namespace Tensorflow.Keras.Saving
var config = metadata.Config;
var class_name = metadata.ClassName;
Model model = null;

if(!metadata.IsGraphNetwork && class_name != "Sequential" && class_name != "Functional")
{
return null;
}

if (class_name == "Sequential")
{
model = new Sequential(new SequentialArgs
@@ -78,34 +365,83 @@ namespace Tensorflow.Keras.Saving
Name = config.GetValue("name").ToString()
});
}
else if (class_name == "Functional")
else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER)
{
throw new NotImplementedException("");
model = new Sequential(new SequentialArgs
{
Name = class_name
});
}
else
{
// TODO: implement it.
throw new NotImplementedException("Not implemented");
}

if (!metadata.IsGraphNetwork)
return null;

// Record this model and its layers. This will later be used to reconstruct
// the model.
var layers = _get_child_layer_node_ids(node_id);
model_layer_dependencies[node_id] = (model, layers);
model_layer_ids_dependencies[node_id] = (model, layers);
if(layers is null || layers.Length == 0)
{
_models_to_reconstruct.Add(node_id);
}
return model;
}

Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
Layer _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
{
var config = metadata.Config;
var class_name = metadata.ClassName;
var shared_object_id = metadata.SharedObjectId;
var must_restore_from_config = metadata.MustRestoreFromConfig;
var obj = class_name switch
{
"Resizing" => Resizing.from_config(config),
_ => throw new NotImplementedException("")
};

var obj = generic_utils.deserialize_keras_object(class_name, config);

obj.Name = metadata.Name;
// TODO(Rinne): add `trainable`, `dtype`, `stateful` and `save_spec`

var built = _try_build_layer(obj, node_id, metadata.BuildInputShape);
return null;
if (!built)
{
return null;
}
return obj;
}

private void _revive_setter(object layer, object name, object value)
{
Debug.Assert(name is string);
Debug.Assert(layer is Layer);
if(PUBLIC_ATTRIBUTES.ContainsKey(name as string))
{
if(value is Trackable)
{
(layer as Layer)._track_trackable(value as Trackable, name as string);
}
if((layer as Layer).SerializedAttributes is null)
{
(layer as Layer).SerializedAttributes = new JObject();
}
(layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value);
}
else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success)
{
(layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true);
}
else
{
var properties = layer.GetType().GetProperties();
foreach(var p in properties)
{
if(p.Name == name as string && p.GetValue(layer) is not null)
{
return;
}
}
Loader.setattr(layer, name, value);
}
}

/// <summary>
@@ -143,34 +479,186 @@ namespace Tensorflow.Keras.Saving
/// <param name="obj"></param>
/// <param name="proto"></param>
/// <param name="node_id"></param>
void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id)
void _add_children_recreated_from_config(Trackable obj, SavedObject proto, int node_id)
{
if (_traversed_nodes_from_config.Contains(node_id))
return;
var parent_path = _node_paths[node_id];
_traversed_nodes_from_config.Add(node_id);
if (!obj.Built)
obj._maybe_initialize_trackable();

if(obj is Layer layer && !layer.Built)
{
var metadata_json = proto.UserObject.Metadata.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);
_try_build_layer(obj, node_id, metadata.BuildInputShape);
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(_metadata.Nodes[node_id].Metadata);
_try_build_layer(layer, node_id, metadata.BuildInputShape);
}


List<(Trackable, int, string)> children = new();
foreach(var refer in proto.Children)
{
var obj_child = obj._lookup_dependency(refer.LocalName);
children.Add((obj_child, refer.NodeId, refer.LocalName));
}

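// Layer metrics are not direct children of the layer node; they are stored under
// "<KERAS_ATTR>.layer_metrics", so match them to the model's metrics by name.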
var metric_list_node_id = _search_for_child_node(node_id, new string[] {
Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics"
});
if(metric_list_node_id is not null && obj is Model model && model.metrics is not null)
{
var obj_metrics = model.metrics.ToDictionary(x => x.Name, x => x);
foreach(var refer in _proto.Nodes[metric_list_node_id.Value].Children)
{
if (obj_metrics.TryGetValue(refer.LocalName, out var metric))
{
var metric_path = $"{Keras.Saving.SavedModel.Constants.KERAS_ATTR}.layer_metrics.{refer.LocalName}";
children.Add((metric, refer.NodeId, metric_path));
}
}
}

foreach(var (obj_child, child_id, child_name) in children)
{
if(obj_child is null)
{
continue;
}
var child_proto = _proto.Nodes[child_id];

// skip the check for registered identifier

Action<object, object, object> setter;
if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier))
{
setter = _revive_setter;
}
else
{
setter = Loader.setattr;
}

if (loaded_nodes.ContainsKey(child_id))
{
// skip the logging.warning
continue;
}

if(child_proto.KindCase == SavedObject.KindOneofCase.Variable && !string.IsNullOrEmpty(child_proto.Variable.Name))
{
(obj_child as BaseResourceVariable).handle_name = child_proto.Variable.Name + ":0";
}

if(obj_child is TrackableDataStructure)
{
setter = (x, y, z) => { };
}

var child_path = $"{parent_path}.{child_name}";
_node_paths[child_id] = child_path;
_add_children_recreated_from_config(obj_child, child_proto, child_id);
loaded_nodes[child_id] = (obj_child, setter);
}
}

bool _try_build_layer(Model obj, int node_id, Shape build_input_shape)
private bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape)
{
if (obj.Built)
return true;

if(build_input_shape is null)
{
build_input_shape = _infer_inputs(node_id, convert_to_shapes: true);
}

if(build_input_shape is not null)
{
obj.build(build_input_shape);
// In the Python implementation this is `base_layer.Layer.build(obj, build_input_shape)`.
// C# does not allow invoking a specific base class's implementation on an arbitrary instance,
// and currently every class derived from Layer either calls `Layer.build` or moves the
// implementation of `Layer.build` into its own `build` method, so we simply call `obj.build` here.
// However, this is still risky: if a future subclass of `Layer` does not call `Layer.build`,
// this assumption breaks.
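// Illustrative sketch (not part of this commit) of the pattern the comment above assumes:
// every concrete layer performs the base bookkeeping in its own `build`, e.g.
//
//   public class MyCustomLayer : Layer   // hypothetical subclass
//   {
//       public override void build(Shape input_shape)
//       {
//           // ... create the layer's weights here ...
//           built = true;   // the bookkeeping `Layer.build` would otherwise perform
//       }
//   }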

return true;
}

return false;
}

bool _try_build_layer(Layer obj, int node_id, Shape build_input_shape)
/// <summary>
/// Infers input shape of layer from SavedModel functions.
/// </summary>
/// <param name="layer_node_id"></param>
/// <param name="convert_to_shapes"></param>
/// <returns></returns>
private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false)
{
if (obj.Built)
return true;
var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" });
if(call_fn_id is null)
{
return null;
}

var concrete_functions = _proto.Nodes[call_fn_id.Value].Function.ConcreteFunctions;
if(concrete_functions is null)
{
return null;
}
var call_fn_name = concrete_functions[0];
var call_fn_proto = _proto.ConcreteFunctions[call_fn_name];
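// Remaining step (not ported yet): the Python implementation decodes the structured input
// signature stored on `call_fn_proto` into a TensorSpec and returns its shape, converting
// specs to shapes when `convert_to_shapes` is requested.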
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}

private int? _search_for_child_node(int parent_id, IEnumerable<string> path_to_child)
{
if(path_to_child is null || path_to_child.Count() == 0)
{
return parent_id;
}

foreach(var child in _proto.Nodes[parent_id].Children)
{
if(child.LocalName == path_to_child.First())
{
return _search_for_child_node(child.NodeId, path_to_child.Skip(1));
}
}
return null;
}

private bool _is_graph_network(Layer layer)
{
// TODO: deal with `RevivedLayer`
if(layer is Functional)
{
return (layer as Functional).IsGraphNetwork || layer is Sequential;
}
return false;
}

private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata)
{
// TODO: deal with `RevivedLayer`
}

/// <summary>
/// Creates edges for nodes that are recreated from config.
/// </summary>
/// <returns></returns>
private Action<object, object, object> _config_node_setter(Action<object, object, object> setter)
{
void setattr_wrapper(object obj, object name, object value)
{
Debug.Assert(obj is Trackable);
Debug.Assert(name is string);
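// Only assign when the object does not already track a dependency under this name;
// attributes of objects recreated from the config must not be overwritten.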
if((obj as Trackable)._lookup_dependency(name as string) is null)
{
setter(obj, name, value);
}
}
return setattr_wrapper;
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;

public partial class KerasSavedModelUtils
{
public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
public static void save_model(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
SaveOptions? options, bool save_traces = true)
{
if (!overwrite && File.Exists(filepath))


+ 96
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/load.cs View File

@@ -0,0 +1,96 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class KerasLoadModelUtils
{
/// <summary>
/// Corresponds to `load_model` in keras/saving/save.py.
/// </summary>
/// <param name="filepath"></param>
/// <param name="custom_objects"></param>
/// <param name="compile"></param>
/// <param name="options"></param>
/// <returns></returns>
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
bool compile = true, LoadOptions? options = null)
{
using (SharedObjectSavingScope.Enter())
{
using (LoadContext.load_context(options))
{
if (!File.Exists(filepath) && !Directory.Exists(filepath))
{
throw new IOException($"No file or directory found at {filepath}.");
}
if (Directory.Exists(filepath))
{
return load(filepath, compile, options);
}
else
{
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
}
}
}
}
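// Example usage (illustrative only; the path below is hypothetical):
//
//   var loaded = KerasLoadModelUtils.load_model("path/to/saved_model_dir");
//   if (loaded is Model model)
//   {
//       model.summary();
//   }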

public static Trackable load(string path, bool compile = true, LoadOptions? options = null)
{
SavedMetadata metadata = new SavedMetadata();
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
if (File.Exists(path_to_metadata_pb))
{
using (var metadata_stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read))
{
metadata.MergeFrom(metadata_stream);
}
}
else
{
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
}

if (metadata.Nodes is null || metadata.Nodes.Count == 0)
{
return Loader.load(path, options: options) as Model;
}

var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);

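// The Keras loader revives every node it recognizes from the Keras metadata; the
// remaining nodes (including "root") are filled in by the generic SavedModel loader
// through Loader.load_partial below.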
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
nodes_to_load["root"] = (null, null);
foreach(var item in keras_loader.LoadedNodes)
{
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
}
var loaded = Loader.load_partial(path, nodes_to_load, options);

keras_loader.finalize_objects();
// keras_loader.del_tracking();

var model = loaded["root"];

if(model is Model && compile)
{
// TODO: implement it.
}

if (!tf.Context.executing_eagerly())
{
// TODO: implement it.
}

return model;
}
}
}

+ 69
- 0
src/TensorFlowNET.Keras/Saving/SavedModel/load_context.cs View File

@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow.Keras.Saving.SavedModel
{
// TODO: move this class to a common project.
public class ContextHandler: IDisposable
{
public Action<bool> DisposeCallBack { get; set; }
public void Dispose()
{
DisposeCallBack.Invoke(true);
}
}
public class LoadContext
{
private bool _entered_load_context;
private LoadOptions? _load_options;
private static ThreadLocal<LoadContext> _load_context = new();
private LoadContext()
{
_entered_load_context = false;
_load_options = null;
}

public void set_load_options(LoadOptions load_options)
{
_load_options = load_options;
_entered_load_context = true;
}

private void clear_load_options()
{
_load_options = null;
_entered_load_context = false;
}

private LoadOptions? load_options()
{
return _load_options;
}

public static ContextHandler load_context(LoadOptions? load_options)
{
if(_load_context.Value is null)
{
_load_context.Value = new LoadContext();
}
_load_context.Value.set_load_options(load_options);
return new ContextHandler()
{
DisposeCallBack = _ => _load_context.Value.clear_load_options()
};
}
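// Typical usage (as in KerasLoadModelUtils.load_model):
//
//   using (LoadContext.load_context(options))
//   {
//       // loading code that may query LoadContext.get_load_option()
//   }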

public static LoadOptions? get_load_option()
{
return _load_context.Value?.load_options();
}

public static bool in_load_context()
{
return _load_context.Value?._entered_load_context ?? false;
}
}
}

+ 29
- 0
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -22,12 +22,16 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Train;

namespace Tensorflow.Keras.Utils
{
public class generic_utils
{
private static readonly string _LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config";
/// <summary>
/// This method has no corresponding method in Python; it is closest to `serialize_keras_object`.
/// </summary>
@@ -51,6 +55,21 @@ namespace Tensorflow.Keras.Utils
return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance);
}

public static Layer deserialize_keras_object(string class_name, JObject config)
{
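// A minimal hand-maintained registry: the saved class name is mapped to a constructor
// fed with the layer's JSON config; unknown names fail loudly so missing layer support
// is easy to spot.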
return class_name switch
{
"Sequential" => new Sequential(config.ToObject<SequentialArgs>()),
"InputLayer" => new InputLayer(config.ToObject<InputLayerArgs>()),
"Flatten" => new Flatten(config.ToObject<FlattenArgs>()),
"ELU" => new ELU(config.ToObject<ELUArgs>()),
"Dense" => new Dense(config.ToObject<DenseArgs>()),
"Softmax" => new Softmax(config.ToObject<SoftmaxArgs>()),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
}

public static string to_snake_case(string name)
{
return string.Concat(name.Select((x, i) =>
@@ -60,5 +79,15 @@ namespace Tensorflow.Keras.Utils
x.ToString();
})).ToLower();
}

/// <summary>
/// Determines whether config appears to be a valid layer config.
/// </summary>
/// <param name="config"></param>
/// <returns></returns>
public static bool validate_config(JObject config)
{
return !config.ContainsKey(_LAYER_UNDEFINED_CONFIG_KEY);
}
}
}

+ 45
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -0,0 +1,45 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelLoad
{
[TestMethod]
public void SimpleModelFromSequential()
{
var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential");
Debug.Assert(model is Model);
var m = model as Model;

m.summary();

m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 50000,
}).Result;

m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}
}

test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs → test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs View File

@@ -63,6 +63,8 @@ public class SequentialModelTest
keras.layers.Softmax(1)
});

model.summary();

model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
