Browse Source

Add checkpoint reading for SavedModel format loading.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
8f7a594145
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
17 changed files with 628 additions and 66 deletions
  1. +9
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  2. +4
    -3
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs
  5. +168
    -14
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  7. +257
    -6
      src/TensorFlowNET.Core/Checkpoint/restore.cs
  8. +5
    -1
      src/TensorFlowNET.Core/Eager/execute.cs
  9. +42
    -1
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  10. +1
    -0
      src/TensorFlowNET.Core/Operations/io_ops.cs
  11. +18
    -0
      src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
  12. +16
    -2
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  13. +66
    -17
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  14. +30
    -10
      src/TensorFlowNET.Core/Training/Trackable.cs
  15. +3
    -3
      src/TensorFlowNET.Core/Training/TrackableUtils.cs
  16. +3
    -3
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  17. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 9
- 0
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -158,4 +158,13 @@ public static class CheckPointUtils
{ {
return objects_ids_and_slot_variables_and_paths(graph_view).Item1; return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
} }

internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
return full_list.TakeWhile(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;
});
}
} }

+ 4
- 3
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -1,12 +1,13 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
{ {
internal class CheckpointReader : IDisposable
public class CheckpointReader : IDisposable
{ {
private IntPtr _reader; private IntPtr _reader;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
@@ -61,14 +62,14 @@ namespace Tensorflow.Checkpoint
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
} }


public Tensor GetTensor(string name)
public unsafe Tensor GetTensor(string name)
{ {
Status status = new Status(); Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
status.Check(true); status.Check(true);
var shape = GetVariableShape(name); var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name); var dtype = GetVariableDataType(name);
return new Tensor(tensor, shape, dtype);
return new Tensor(c_api.TF_TensorData(tensor), shape, dtype);
} }


private void ReadAllShapeAndType() private void ReadAllShapeAndType()


+ 3
- 3
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -175,9 +175,9 @@ public static class SaveUtilV1
{ {
var name = factory_data.name; var name = factory_data.name;
var key = factory_data.checkpoint_key; var key = factory_data.checkpoint_key;
var maybe_saveable = factory_data.factory;
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory);


// TODO: oneflow python has a process with callable `saveable_factory`.
// TODO: tensorflow python has a process with callable `saveable_factory`.
List<MySaveableObject> saveables = new(); List<MySaveableObject> saveables = new();
if (maybe_saveable.TryGet<MySaveableObject>(out var s)) if (maybe_saveable.TryGet<MySaveableObject>(out var s))
{ {
@@ -217,7 +217,7 @@ public static class SaveUtilV1


public record class CheckpointFactoryData public record class CheckpointFactoryData
( (
Maybe<BaseResourceVariable, MySaveableObject> factory,
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
string name, string name,
string checkpoint_key string checkpoint_key
); );

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs View File

@@ -24,6 +24,6 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern IntPtr TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status);
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status);
} }
} }

+ 168
- 14
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -10,6 +10,8 @@ using Tensorflow.Exceptions;
using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types; using static Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.Operations; using Tensorflow.Operations;
using Newtonsoft.Json;
using Tensorflow.Training;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


@@ -259,11 +261,48 @@ public class TrackableSaver
saveables_cache: null saveables_cache: null
); );


throw new NotImplementedException();
new CheckpointPosition(checkpoint, 0).restore(_graph_view.Root);

if(_graph_view.AttachedDependencies is not null)
{
foreach(var refer in _graph_view.AttachedDependencies)
{
if(refer.Name == "root")
{
continue;
}
int? proto_id = null;
// Find proto ID of attached dependency (if it is in the proto).
foreach (var proto_refer in object_graph_proto.Nodes[0].Children)
{
if(proto_refer.LocalName == refer.Name)
{
proto_id = proto_refer.NodeId;
break;
}
}

if (proto_id is null)
{
continue;
}

// Object has already been restored. This can happen when there's an
// indirect connection from the attached object to the root.
if (checkpoint.ObjectByProtoId.ContainsKey(proto_id.Value))
{
continue;
}

new CheckpointPosition(checkpoint, proto_id.Value).restore(refer.Refer);
}
}

return new CheckpointLoadStatus(checkpoint, file_prefix_feed_dict, _graph_view);
} }
} }


internal class CheckpointRestoreCoordinator
public class CheckpointRestoreCoordinator
{ {
private CheckpointOptions _options; private CheckpointOptions _options;
private TrackableObjectGraph _object_graph_proto; private TrackableObjectGraph _object_graph_proto;
@@ -280,6 +319,9 @@ internal class CheckpointRestoreCoordinator
private List<Operation> _restore_ops; private List<Operation> _restore_ops;
private List<Trackable> _all_trackables; private List<Trackable> _all_trackables;
private Dictionary<int, Trackable> _object_by_proto_id; private Dictionary<int, Trackable> _object_by_proto_id;
private Dictionary<string, Operation> _restore_ops_by_name;
private Dictionary<int, IList<DeferredSlotVariableRestoration>> _deferred_slot_restorations;
private Dictionary<int, IList<string>> _unused_attributes;


public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor, public CheckpointRestoreCoordinator(TrackableObjectGraph object_graph_proto, string save_path, Tensor save_path_tensor,
CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache) CheckpointReader reader, object? restore_op_cache, ObjectGraphView graph_view, CheckpointOptions options, object? saveables_cache)
@@ -299,10 +341,12 @@ internal class CheckpointRestoreCoordinator
_shape_map = _reader.VariableToShapeMap; _shape_map = _reader.VariableToShapeMap;
_graph_view = graph_view; _graph_view = graph_view;
_restore_ops = new List<Operation>(); _restore_ops = new List<Operation>();
_restore_ops_by_name = new Dictionary<string, Operation>();
_all_trackables = new List<Trackable>(); _all_trackables = new List<Trackable>();
_matched_proto_ids = new HashSet<int>(); _matched_proto_ids = new HashSet<int>();
_object_by_proto_id = new Dictionary<int, Trackable>(); _object_by_proto_id = new Dictionary<int, Trackable>();
_slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>(); _slot_restorations = new Dictionary<int, IList<SlotVariableRestoration>>();
_deferred_slot_restorations = new Dictionary<int, IList<DeferredSlotVariableRestoration>>();


_expect_partial_attr = false; _expect_partial_attr = false;
for(int i = 0; i < _object_graph_proto.Nodes.Count; i++) for(int i = 0; i < _object_graph_proto.Nodes.Count; i++)
@@ -330,10 +374,18 @@ internal class CheckpointRestoreCoordinator
} }
} }


/// <summary>
/// Corresponding to `all_python_objects` of tensorflow python
/// </summary>
public List<Trackable> AllTrackables => _all_trackables; public List<Trackable> AllTrackables => _all_trackables;
public HashSet<int> MatchedProtoIds => _matched_proto_ids; public HashSet<int> MatchedProtoIds => _matched_proto_ids;
public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id; public Dictionary<int, Trackable> ObjectByProtoId => _object_by_proto_id;
public int RestoreUid => _restore_uid; public int RestoreUid => _restore_uid;
public TrackableObjectGraph ObjectGraphProto => _object_graph_proto;
public Dictionary<int, IList<SlotVariableRestoration>> SlotRestorations => _slot_restorations;
public Dictionary<int, IList<DeferredSlotVariableRestoration>> DeferredSlotRestorations => _deferred_slot_restorations;
public Dictionary<string, Operation> RestoreOpsByName => _restore_ops_by_name;
public Dictionary<int, IList<string>> UnusedAttributes => _unused_attributes;


public void new_restore_ops(IEnumerable<Operation> new_ops) public void new_restore_ops(IEnumerable<Operation> new_ops)
{ {
@@ -341,18 +393,52 @@ internal class CheckpointRestoreCoordinator
// skip the callback. // skip the callback.
} }


public List<Operation> restore_saveables(MySaveableObject tensor_saveables, object? python_positions = null, object? registered_savers = null)
public List<Operation> restore_saveables(Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables, List<CheckpointPosition> positions, object? registered_savers = null)
{ {
throw new NotImplementedException();
List<Operation> restore_ops = new();
foreach(var position in positions)
{
var key = position.ObjectProto.Attributes[0].CheckpointKey;
throw new NotImplementedException();
}

Dictionary<string, BaseResourceVariable> variable_dict = new();
foreach(var item in tensor_saveables)
{
if(item.Value.TryGet<BaseResourceVariable>(out var variable))
{
variable_dict[item.Key] = variable;
}
else
{
throw new TypeError();
}
}

if (tensor_saveables is not null && tensor_saveables.Count > 0)
{
var flat_saveables = saveable_object_util.validate_and_slice_inputs(variable_dict);
var new_restore_ops = MultiDeviceSaver.from_saveables(flat_saveables).restore(_save_path_tensor, _options);
if (!tf.Context.executing_eagerly())
{
foreach(var item in new_restore_ops)
{
restore_ops.Add(item.Value);
Debug.Assert(!_restore_ops_by_name.ContainsKey(item.Key));
_restore_ops_by_name[item.Key] = item.Value;
}
}
}
return restore_ops;
} }
} }


public abstract class LoadStatus public abstract class LoadStatus
{ {
public abstract void assert_consumed();
public abstract void assert_existing_objects_matched();
public abstract void assert_nontrivial_match();
public abstract void run_restore_ops(Session? session = null);
public abstract LoadStatus assert_consumed();
public abstract LoadStatus assert_existing_objects_matched();
public abstract LoadStatus assert_nontrivial_match();
public abstract LoadStatus run_restore_ops(Session? session = null);
public abstract void initialize_or_restore(Session? session = null); public abstract void initialize_or_restore(Session? session = null);
public virtual LoadStatus expect_partial() public virtual LoadStatus expect_partial()
{ {
@@ -371,19 +457,19 @@ public class InitializationOnlyStatus: LoadStatus
_object_graph_view = object_graph_view; _object_graph_view = object_graph_view;
_root = object_graph_view.Root; _root = object_graph_view.Root;
} }
public override void assert_consumed()
public override LoadStatus assert_consumed()
{ {
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
} }
public override void assert_existing_objects_matched()
public override LoadStatus assert_existing_objects_matched()
{ {
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
} }
public override void assert_nontrivial_match()
public override LoadStatus assert_nontrivial_match()
{ {
throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored."); throw new AssertionError("No checkpoint specified (save_path=None); nothing is being restored.");
} }
public override void run_restore_ops(Session? session = null)
public override LoadStatus run_restore_ops(Session? session = null)
{ {
throw new AssertionError("No checkpoint specified, so no restore ops are available " throw new AssertionError("No checkpoint specified, so no restore ops are available "
+ "(save_path=None to Saver.restore)."); + "(save_path=None to Saver.restore).");
@@ -403,10 +489,78 @@ public class InitializationOnlyStatus: LoadStatus
} }
} }


public class CheckpointLoadStatus
internal class CheckpointLoadStatus: LoadStatus
{ {
public CheckpointLoadStatus()
private CheckpointRestoreCoordinator _checkpoint;
private Dictionary<Tensor, string> _feed_dict;
private ObjectGraphView _object_graph_view;
private Trackable _root;
public CheckpointLoadStatus(CheckpointRestoreCoordinator checkpoint, Dictionary<Tensor, string> feed_dict, ObjectGraphView graph_view):base()
{
_checkpoint = checkpoint;
_feed_dict = feed_dict;
_object_graph_view = graph_view;
_root = graph_view.Root;
}

public CheckpointRestoreCoordinator Checkpoint => _checkpoint;

public override LoadStatus assert_consumed()
{
throw new NotImplementedException();
}

public override LoadStatus assert_existing_objects_matched()
{
for(int i = 0; i < _checkpoint.ObjectGraphProto.Nodes.Count; i++)
{
var node = _checkpoint.ObjectGraphProto.Nodes[i];
if(_checkpoint.ObjectByProtoId.TryGetValue(i, out var trackable) &&
trackable.UpdateUid < _checkpoint.RestoreUid)
{
throw new AssertionError($"Object {node} not assigned a value from checkpoint.");
}
}
foreach(var trackable_object in CheckPointUtils.list_objects(_object_graph_view))
{
if(trackable_object is TrackableDataStructure && trackable_object._trackable_children().Count == 0)
{
continue;
}
_checkpoint.AllTrackables.Add(trackable_object);
}
var unused_trackables = CheckPointUtils._objects_with_attributes(_checkpoint.AllTrackables)
.Except(_checkpoint.ObjectByProtoId.Values);
if (unused_trackables.Any())
{
var num_unused_trackables = unused_trackables.Count();
var num_variables_to_show = Math.Min(10, num_unused_trackables);
throw new AssertionError($"Found {num_unused_trackables} Python objects that were " +
$"not bound to checkpointed values, likely due to changes in the " +
$"Python program. Showing {num_variables_to_show} of " +
$"{num_unused_trackables} unmatched objects: " +
$"{{list(unused_python_objects)[:num_variables_to_show]}}");
}
return this;
}

public override LoadStatus assert_nontrivial_match()
{
throw new NotImplementedException();
}

public override LoadStatus expect_partial()
{ {
throw new NotImplementedException();
}


public override void initialize_or_restore(Session? session = null)
{
throw new NotImplementedException();
}

public override LoadStatus run_restore_ops(Session? session = null)
{
throw new NotImplementedException();
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -213,7 +213,7 @@ namespace Tensorflow.Checkpoint


// tf python has code `with ops.device(restore_device):` here. // tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky. tf.device(restore_device); // may be risky.
var restored_tensors = tf.io.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());


Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0; int idx = 0;


+ 257
- 6
src/TensorFlowNET.Core/Checkpoint/restore.cs View File

@@ -1,11 +1,15 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training;
using static Tensorflow.Binding;


namespace Tensorflow.Checkpoint; namespace Tensorflow.Checkpoint;


internal class CheckpointPosition
public class CheckpointPosition
{ {
private CheckpointRestoreCoordinator _checkpoint; private CheckpointRestoreCoordinator _checkpoint;
private int _proto_id; private int _proto_id;
@@ -18,6 +22,8 @@ internal class CheckpointPosition
} }


public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id]; public Trackable Trackable => _checkpoint.ObjectByProtoId[_proto_id];
public CheckpointRestoreCoordinator Checkpoint => _checkpoint;
public TrackableObjectGraph.Types.TrackableObject ObjectProto => _checkpoint.ObjectGraphProto.Nodes[_proto_id];


public void restore(Trackable trackable) public void restore(Trackable trackable)
{ {
@@ -25,7 +31,11 @@ internal class CheckpointPosition
{ {
if (bind_project(trackable)) if (bind_project(trackable))
{ {

var restore_ops = _restore_descendants();
if(restore_ops is not null && restore_ops.Count > 0)
{
_checkpoint.new_restore_ops(restore_ops);
}
} }
} }
} }
@@ -51,30 +61,271 @@ internal class CheckpointPosition
} }
} }


public void gather_ops_or_named_saveables()
public (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) gather_ops_or_named_saveables()
{ {
// skip the registered_saver // skip the registered_saver


if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0)
{
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
new List<CheckpointPosition>(), null);
}

var saveable_factories = saveable_object_util.saveable_objects_from_trackable(this.Trackable);


List<Operation> existing_restore_ops;
List<CheckpointPosition> positions = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables;
if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{
(existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories);
}
else if(saveable_factories.Count > 0)
{
(existing_restore_ops, named_saveables) = _create_saveables_by_attribute_name(saveable_factories);
}
else
{
throw new NotImplementedException();
}
return (existing_restore_ops, named_saveables, positions, null);
}

public CheckpointPosition create_child_position(int node_id)
{
return new CheckpointPosition(_checkpoint, node_id);
}

public (CheckpointPosition, BaseResourceVariable) create_slot_variable_position(Optimizer optimizer_object, BaseResourceVariable variable,
int slot_variable_id, string slot_name)
{
//CheckpointPosition slot_variable_position = new(Checkpoint, slot_variable_id);

// TODO(Rinne): implement it.
return (null, null);
}

/// <summary>
/// Creates a saveable using the _serialize_to_tensor method.
/// </summary>
/// <param name="saveable_factories"></param>
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_serialize_to_tensor_saveable(
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
string suffix = SaveableCompat.get_saveable_name(this.Trackable);
suffix = suffix ?? "";
var saveable_name = _extract_saveable_name(ObjectProto.Attributes[0].CheckpointKey) + suffix;

if (!tf.Context.executing_eagerly())
{
throw new NotImplementedException("The restore under graph mode has not been implemented. " +
"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}

var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name);
// skip the cache.
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> dict = new();
dict[saveable_name] = saveable;
return (new List<Operation>(), dict);
}

private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>) _create_saveables_by_attribute_name(
IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories)
{
// TODO(Rinne): implement it.
if(ObjectProto.Attributes is null)
{
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>());
}

List<Operation> existing_restore_ops = new();
HashSet<string> created_compat_names = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> named_saveables = new();
foreach (var serialized_tensor in ObjectProto.Attributes)
{
Operation existing_op;
if (tf.Context.executing_eagerly() || !_checkpoint.RestoreOpsByName.ContainsKey(serialized_tensor.CheckpointKey))
{
existing_op = null;
}
else
{
existing_op = _checkpoint.RestoreOpsByName[serialized_tensor.CheckpointKey];
}

if(existing_op is not null)
{
existing_restore_ops.Add(existing_op);
continue;
}

if(created_compat_names.Any(x => serialized_tensor.Name.StartsWith(x)))
{
continue;
}

// TODO(Rinne): deal with cache.

var saveable = _get_saveable_from_factory(saveable_factories, serialized_tensor, created_compat_names);
if(saveable is null)
{
_checkpoint.UnusedAttributes.SetDefault(_proto_id, new List<string>()).Add(serialized_tensor.Name);
continue;
}
named_saveables[serialized_tensor.CheckpointKey] = saveable;
}
return (existing_restore_ops, named_saveables);
}

private Maybe<BaseResourceVariable, MySaveableObject> _get_saveable_from_factory(IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_factories,
TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet<string> created_compat_names)
{
var expected_factory_name = serialized_tensor.Name;
var factory_input_name = serialized_tensor.CheckpointKey;

if (!saveable_factories.TryGetValue(expected_factory_name, out var matched_factory))
{
foreach(var item in saveable_factories)
{
var factory_name = item.Key;
var factory = item.Value;
if (expected_factory_name.StartsWith(factory_name))
{
if(matched_factory is not null)
{
throw new ValueError($"Forward compatibility load error: Unable to load " +
"checkpoint saved in future version of TensorFlow. " +
"Please update your version of TensorFlow to the " +
"version in which the checkpoint was saved.");
}
}
matched_factory = factory;
factory_input_name = _extract_saveable_name(serialized_tensor.CheckpointKey) + factory_name;
created_compat_names.Add(factory_name);
}
}
return matched_factory(factory_input_name);
}

private string _extract_saveable_name(string checkpoint_key)
{
var search_key = TrackableUtils.OBJECT_ATTRIBUTES_NAME + "/";
return checkpoint_key.Substring(0, checkpoint_key.IndexOf(search_key) + search_key.Length);
} }


/// <summary> /// <summary>
/// Restore the bound Trackable and dependencies (may be deferred). /// Restore the bound Trackable and dependencies (may be deferred).
/// </summary> /// </summary>
private void _restore_descendants()
private List<Operation> _restore_descendants()
{ {
Queue<(CheckpointPosition, Trackable)> visit_queue = new(); Queue<(CheckpointPosition, Trackable)> visit_queue = new();
visit_queue.Enqueue((this, this.Trackable)); visit_queue.Enqueue((this, this.Trackable));
List<Operation> restore_ops = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> tensor_saveables = new();
List<CheckpointPosition> positions = new();

CheckpointPosition current_position = null;
while (visit_queue.Count > 0)
{
current_position = visit_queue.Dequeue().Item1;
var (new_restore_ops, new_tensor_saveables, new_positions, new_registered_savers) = current_position._single_restore();
restore_ops.AddRange(new_restore_ops);
foreach(var item in new_tensor_saveables)
{
tensor_saveables.Add(item.Key, item.Value);
}
positions.AddRange(new_positions);
_queue_children_for_restoration(current_position, visit_queue);
_queue_slot_variables(current_position, visit_queue);
}
restore_ops.AddRange(current_position.Checkpoint.restore_saveables(tensor_saveables, positions, null));
return restore_ops;
}

private void _queue_children_for_restoration(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
var trackable = checkpoint_position.Trackable;
foreach(var child in checkpoint_position.ObjectProto.Children)
{
var child_position = checkpoint_position.create_child_position(child.NodeId);
var local_object = trackable._lookup_dependency(child.LocalName);
var child_proto = child_position.ObjectProto;
if(local_object is null)
{
if(child_proto.Children.Any() || child_proto.Attributes.Any() || child_proto.SlotVariables.Any())
{
trackable.DeferredDependencies.SetDefault(child.LocalName, new List<CheckpointPosition>()).Add(child_position);
}
}
else
{
if (child_position.bind_project(local_object))
{
visit_queue.Enqueue((child_position, local_object));
}
}
}
}


private void _queue_slot_variables(CheckpointPosition checkpoint_position, Queue<(CheckpointPosition, Trackable)> visit_queue)
{
var trackable = checkpoint_position.Trackable;
var checkpoint = checkpoint_position.Checkpoint;
if(checkpoint.DeferredSlotRestorations.TryGetValue(checkpoint_position._proto_id, out var positions))
{
checkpoint.DeferredSlotRestorations.Remove(checkpoint_position._proto_id);
foreach (var deferred_slot_restoration in positions)
{
var (slot_variable_position, slot_variable) = checkpoint_position.create_slot_variable_position(
trackable as Optimizer, deferred_slot_restoration.OriginalVariable, deferred_slot_restoration.SlotVariableId,
deferred_slot_restoration.SlotName
);
if(slot_variable_position is not null)
{
visit_queue.Enqueue((slot_variable_position, slot_variable));
}
}
}
if (checkpoint.SlotRestorations.TryGetValue(checkpoint_position._proto_id, out var restorations))
{
checkpoint.SlotRestorations.Remove(checkpoint_position._proto_id);
foreach (var slot_restoration in restorations)
{
if(Checkpoint.ObjectByProtoId.TryGetValue(slot_restoration.OptimizerId, out var optimizer_object))
{
throw new NotImplementedException();
// TODO(Rinne); implement it.
}
else
{
Debug.Assert(trackable is BaseResourceVariable);
Checkpoint.DeferredSlotRestorations.SetDefault(slot_restoration.OptimizerId, new List<DeferredSlotVariableRestoration>())
.Add(new DeferredSlotVariableRestoration(trackable as BaseResourceVariable, slot_restoration.SlotVariableId, slot_restoration.SlotName));
}
}
}
} }


private void _single_restore()
private (List<Operation>, Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>, List<CheckpointPosition>, object?) _single_restore()
{ {
var trackable = this.Trackable; var trackable = this.Trackable;
trackable._maybe_initialize_trackable(); trackable._maybe_initialize_trackable();
if(_checkpoint.RestoreUid > trackable.UpdateUid) if(_checkpoint.RestoreUid > trackable.UpdateUid)
{ {

var (restore_ops, tensor_saveables, positions, registered_savers) = gather_ops_or_named_saveables();
trackable.UpdateUid = _checkpoint.RestoreUid;
return (restore_ops, tensor_saveables, positions, registered_savers);
}
else
{
return (new List<Operation>(), new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(),
new List<CheckpointPosition>(), null);
} }
} }
} }

public record class DeferredSlotVariableRestoration(
BaseResourceVariable OriginalVariable,
int SlotVariableId,
string SlotName
);

+ 5
- 1
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -10,7 +10,7 @@ using static Tensorflow.Binding;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
internal class execute
internal static class execute
{ {
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
{ {
@@ -27,5 +27,9 @@ namespace Tensorflow.Eager


return tensors; return tensors;
} }
public static bool must_record_gradient()
{
return false;
}
} }
} }

+ 42
- 1
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -27189,8 +27189,33 @@ namespace Tensorflow.Operations
/// ///
/// Callers must ensure all the named tensors are indeed stored in the checkpoint. /// Callers must ensure all the named tensors are indeed stored in the checkpoint.
/// </remarks> /// </remarks>
public static Tensor[] restore_v2(Tensor prefix, Tensor tensor_names, Tensor shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2")
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "RestoreV2")
{ {
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
Dictionary<string, object> attrs = new();
attrs["dtypes"] = dtypes;
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(
"RestoreV2", name, prefix, tensor_names, shape_and_slices
)
{ attrs = attrs });
return result;
}
catch (Exception)
{
try
{
return restore_v2_eager_fallback(prefix, tensor_names, shape_and_slices, dtypes, name, ctx);
}
catch (Exception)
{

}
}
}
var dict = new Dictionary<string, object>(); var dict = new Dictionary<string, object>();
dict["prefix"] = prefix; dict["prefix"] = prefix;
dict["tensor_names"] = tensor_names; dict["tensor_names"] = tensor_names;
@@ -27202,6 +27227,22 @@ namespace Tensorflow.Operations
return (tensors); return (tensors);
} }


public static Tensor[] restore_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name, Context ctx)
{
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);
object[] attrs = new object[] { "dtypes", dtypes };
Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor };
var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name);

if (execute.must_record_gradient())
{
// TODO(Rinne); record the gradient
}
return result;
}

/// <summary> /// <summary>
/// Reverses specific dimensions of a tensor. /// Reverses specific dimensions of a tensor.
/// </summary> /// </summary>


+ 1
- 0
src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -62,6 +62,7 @@ namespace Tensorflow


public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{ {
// Note: this implementation is not correct in many cases, please consider using `gen_ops.restore_v2`.
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });


return _op.outputs; return _op.outputs;


+ 18
- 0
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -39,6 +39,24 @@ namespace Tensorflow
_op = value; _op = value;
} }
} }
public BaseResourceVariable variable
{
get
{
if (_op.TryGet<BaseResourceVariable>(out var v))
{
return v;
}
else
{
throw new TypeError("The _op is not a variable.");
}
}
set
{
_op = value;
}
}
public SaveSpec[] specs; public SaveSpec[] specs;
public string name; public string name;
public string device; public string device;


+ 16
- 2
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -63,7 +63,7 @@ namespace Tensorflow


if (!save_options.experimental_skip_checkpoint) if (!save_options.experimental_skip_checkpoint)
{ {
// TODO: implement it.
_restore_checkpoint();
} }
foreach(var node in _nodes) foreach(var node in _nodes)
{ {
@@ -398,13 +398,27 @@ namespace Tensorflow
/// </summary> /// </summary>
private void _restore_checkpoint() private void _restore_checkpoint()
{ {
var variables_path = SavedModelUtils.get_variables_dir(_export_dir);
var variables_path = SavedModelUtils.get_variables_path(_export_dir);
var saver = new TrackableSaver(new ObjectGraphView(get(0))); var saver = new TrackableSaver(new ObjectGraphView(get(0)));
tf.device("CPU"); tf.device("CPU");
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
LoadStatus load_status;
if (_save_options.allow_partial_checkpoint) if (_save_options.allow_partial_checkpoint)
{ {
load_status = saver.restore(variables_path, _checkpoint_options).expect_partial();
load_status.assert_nontrivial_match();
}
else
{
load_status = saver.restore(variables_path, _checkpoint_options);
load_status.assert_existing_objects_matched();
}
var ckpt = (load_status as CheckpointLoadStatus).Checkpoint;


if (!tf.Context.executing_eagerly())
{
throw new NotImplementedException("The checkpoint restore has not supported graph mode. " +
"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
} }
} }




+ 66
- 17
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -68,6 +68,34 @@ namespace Tensorflow
return saveables.ToArray(); return saveables.ToArray();
} }


public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, Tensor> names_to_saveables)
{
var saveables = new List<MySaveableObject>();
var seen_ops = new List<Tensor>();

foreach (var (name, op) in enumerate(names_to_saveables))
{
foreach (var converted_saveable_object in saveable_objects_for_op(op, name))
_add_saveable(saveables, seen_ops, converted_saveable_object);
}
return saveables.ToArray();
}

public static MySaveableObject[] validate_and_slice_inputs(Dictionary<string, BaseResourceVariable> names_to_saveables)
{
var saveables = new List<MySaveableObject>();
var seen_ops = new List<BaseResourceVariable>();

foreach(var item in names_to_saveables.OrderBy(x => x.Key))
{
foreach(var converted_saveable_object in saveable_objects_for_op(item.Value, item.Key))
{
_add_saveable(saveables, seen_ops, converted_saveable_object);
}
}
return saveables.ToArray();
}

private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject private static void _add_saveable<T>(List<T> saveables, List<Tensor> seen_ops, T saveable) where T : MySaveableObject
{ {
if (seen_ops.Contains(saveable.op)) if (seen_ops.Contains(saveable.op))
@@ -77,6 +105,15 @@ namespace Tensorflow
seen_ops.Add(saveable.op); seen_ops.Add(saveable.op);
} }


private static void _add_saveable(List<MySaveableObject> saveables, List<BaseResourceVariable> seen_ops, MySaveableObject saveable)
{
if (seen_ops.Contains(saveable.variable))
throw new ValueError($"The same saveable will be restored with two names: {saveable.op.OriginalVar.Name}");

saveables.Add(saveable);
seen_ops.Add(saveable.variable);
}

/// <summary> /// <summary>
/// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`. /// Create `SaveableObject`s from an operation. Note that the `op` should not be implicitly converted from `Variable`.
/// </summary> /// </summary>
@@ -136,19 +173,20 @@ namespace Tensorflow
{ {
full_name = name + "_" + attr; full_name = name + "_" + attr;
} }
if(factory.TryGet<BaseResourceVariable>(out var variable))
var op = factory(full_name);
if(op.TryGet<BaseResourceVariable>(out var variable))
{ {
foreach (var op in saveable_objects_for_op(variable as Trackable, variable.Name))
foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name))
{ {
yield return op;
yield return v;
} }
} }
else else
{ {
var saveable = factory.GetValue<MySaveableObject>();
foreach (var op in saveable_objects_for_op(saveable, saveable.name))
var saveable = op.GetValue<MySaveableObject>();
foreach (var v in saveable_objects_for_op(saveable, saveable.name))
{ {
yield return op;
yield return v;
} }
} }
} }
@@ -214,20 +252,19 @@ namespace Tensorflow
return names_to_saveables; return names_to_saveables;
} }


public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> saveable_objects_from_trackable(Trackable obj)
{ {
// skip the process of type `PythonState` // skip the process of type `PythonState`


if (trackable_has_serialize_to_tensor(obj))
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
{ {
var name = TrackableUtils.SERIALIZE_TO_TENSORS_NAME;
// skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`.
var tensor_dict = obj.serialize_to_tensors(); var tensor_dict = obj.serialize_to_tensors();


List<SaveSpec> specs = new(); List<SaveSpec> specs = new();
List<string> local_names = new(); List<string> local_names = new();
string prefix = SaveableCompat.get_saveable_name(obj) ?? ""; string prefix = SaveableCompat.get_saveable_name(obj) ?? "";
foreach(var pair in tensor_dict)
foreach (var pair in tensor_dict)
{ {
var tensor_name = pair.Key; var tensor_name = pair.Key;
var maybe_tensor = pair.Value; var maybe_tensor = pair.Value;
@@ -235,9 +272,9 @@ namespace Tensorflow
string spec_name = name + TrackableUtils.escape_local_name(tensor_name); string spec_name = name + TrackableUtils.escape_local_name(tensor_name);


IDictionary<string, Tensor> internal_dict; IDictionary<string, Tensor> internal_dict;
if(maybe_tensor.TryGet<Tensor>(out var tensor))
if (maybe_tensor.TryGet<Tensor>(out var tensor))
{ {
internal_dict= new Dictionary<string, Tensor>();
internal_dict = new Dictionary<string, Tensor>();
internal_dict[""] = tensor; internal_dict[""] = tensor;
} }
else else
@@ -245,13 +282,18 @@ namespace Tensorflow
internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>(); internal_dict = maybe_tensor.GetValue<IDictionary<string, Tensor>>();
} }


foreach(var item in internal_dict)
foreach (var item in internal_dict)
{ {
specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
} }
} }
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new();
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
return new TrackableSaveable(obj, specs, name, local_names, prefix);
}

if (trackable_has_serialize_to_tensor(obj))
{
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new();
res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable;
return res; return res;
} }
else else
@@ -339,14 +381,21 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="saveable_fn_by_name"></param> /// <param name="saveable_fn_by_name"></param>
/// <param name="temp_session"></param> /// <param name="temp_session"></param>
public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> recreate_saveable_objects(
public static IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> recreate_saveable_objects(
IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session) IDictionary<string, (Trackable, Trackable)> saveable_fn_by_name, IEnumerable<object>? temp_session)
{ {
if (saveable_fn_by_name.Count > 0) if (saveable_fn_by_name.Count > 0)
{ {
throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
} }
return new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
return res;
}

public static Maybe<BaseResourceVariable, MySaveableObject> create_saveable_object(string name, string key, Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
bool call_with_mapped_captures = false)
{
return factory(key);
} }
} }




+ 30
- 10
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -41,9 +41,10 @@ namespace Tensorflow.Train
protected IDictionary<string, Trackable> _unconditional_dependency_names; protected IDictionary<string, Trackable> _unconditional_dependency_names;


protected IList<TrackableReference> _unconditional_checkpoint_dependencies; protected IList<TrackableReference> _unconditional_checkpoint_dependencies;
protected Dictionary<string, IList<CheckpointPosition>> _unconditional_deferred_dependencies;


protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
protected IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> _self_saveable_object_factories =
new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
private bool _manual_tracking = true; private bool _manual_tracking = true;


private static Trackable _none = new AutoTrackable(); private static Trackable _none = new AutoTrackable();
@@ -71,7 +72,8 @@ namespace Tensorflow.Train
public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; } public IList<TrackableReference> UnconditionalCheckpointDependencies { get => _unconditional_checkpoint_dependencies; }
public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IDictionary<string, Trackable> UnconditionalDependencyNames { get => _unconditional_dependency_names; }
public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; } public IList<TrackableReference> CheckpointDependencies { get => UnconditionalCheckpointDependencies; }
public IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> SelfSaveableObjectFactories
public Dictionary<string, IList<CheckpointPosition>> DeferredDependencies => _unconditional_deferred_dependencies;
public IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> SelfSaveableObjectFactories
{ {
get get
{ {
@@ -147,9 +149,11 @@ namespace Tensorflow.Train
_self_update_uid = -1; _self_update_uid = -1;
_unconditional_checkpoint_dependencies = new List<TrackableReference>(); _unconditional_checkpoint_dependencies = new List<TrackableReference>();
_unconditional_dependency_names = new Dictionary<string, Trackable>(); _unconditional_dependency_names = new Dictionary<string, Trackable>();
_unconditional_deferred_dependencies = new Dictionary<string, IList<CheckpointPosition>>();
} }


public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache)
public virtual IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT,
IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {
_maybe_initialize_trackable(); _maybe_initialize_trackable();
return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer); return _unconditional_checkpoint_dependencies.ToDictionary(x => x.Name, x => x.Refer);
@@ -185,10 +189,19 @@ namespace Tensorflow.Train
/// <param name="trackable"></param> /// <param name="trackable"></param>
public virtual void _handle_deferred_dependencies(string name, Trackable trackable) public virtual void _handle_deferred_dependencies(string name, Trackable trackable)
{ {
//_maybe_initialize_trackable();
//trackable._maybe_initialize_trackable();
// TODO: complete the implementation.
_maybe_initialize_trackable();
trackable._maybe_initialize_trackable();

if(_unconditional_deferred_dependencies.TryGetValue(name, out var dependencies))
{
_unconditional_deferred_dependencies.Remove(name);
foreach(var checkpoint_position in dependencies.OrderByDescending(x => x.Checkpoint.RestoreUid))
{
checkpoint_position.restore(trackable);
}
}

// TODO(Rinne): deal with `_self_name_based_restores`
} }


public virtual Trackable? _lookup_dependency(string name) public virtual Trackable? _lookup_dependency(string name)
@@ -236,12 +249,19 @@ namespace Tensorflow.Train
return self_tensor_map.Keys.ToList(); return self_tensor_map.Keys.ToList();
} }


public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
public virtual IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{ {
Maybe<BaseResourceVariable, MySaveableObject> create_saveable(string name = "")
{
throw new NotImplementedException();
//return new TrackableSaveable(this, null, name, null, null);
}
if (saveable_object_util.trackable_has_serialize_to_tensor(this)) if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{ {
// TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`).
throw new NotImplementedException();
Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> res = new();
res[""] = create_saveable;
return res;
} }
else else
{ {


+ 3
- 3
src/TensorFlowNET.Core/Training/TrackableUtils.cs View File

@@ -21,9 +21,9 @@ public static class TrackableUtils
LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable()); LeftOverDependencyMap = leftover_dependency_map.ToDictionary(x => x.Key, x => x.Value.AsEnumerable());
} }
} }
private static string _ESCAPE_CHAR = ".";
private static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
private static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
internal static string _ESCAPE_CHAR = ".";
internal static string _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT";
internal static string OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES";
internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS"; internal static string SERIALIZE_TO_TENSORS_NAME = _ESCAPE_CHAR + "TENSORS";
public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr) public static string object_path_to_string(IEnumerable<TrackableReference> node_path_arr)
{ {


+ 3
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -293,10 +293,10 @@ namespace Tensorflow
resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);
} }


public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
public override IDictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>> gather_saveables_for_checkpoint()
{ {
var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this;
var res = new Dictionary<string, Func<string, Maybe<BaseResourceVariable, MySaveableObject>>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this;
return res; return res;
} }




+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -21,7 +21,7 @@ public class SequentialModelLoad
[TestMethod] [TestMethod]
public void SimpleModelFromSequential() public void SimpleModelFromSequential()
{ {
var model = KerasLoadModelUtils.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential");
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential");
Debug.Assert(model is Model); Debug.Assert(model is Model);
var m = model as Model; var m = model as Model;




Loading…
Cancel
Save