Browse Source

Add more implementations to the pb model save.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
a479e53f3a
30 changed files with 775 additions and 160 deletions
  1. +6
    -4
      src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
  2. +4
    -3
      src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
  3. +9
    -9
      src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
  4. +12
    -11
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  5. +39
    -21
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  6. +31
    -0
      src/TensorFlowNET.Core/Eager/execute.cs
  7. +23
    -0
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  8. +37
    -1
      src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs
  9. +55
    -4
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  10. +32
    -0
      src/TensorFlowNET.Core/Operations/io_ops.cs
  11. +54
    -0
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  12. +28
    -0
      src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs
  13. +22
    -2
      src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs
  14. +7
    -7
      src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs
  15. +48
    -35
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs
  16. +53
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs
  17. +5
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs
  18. +83
    -26
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  19. +9
    -4
      src/TensorFlowNET.Core/Training/Trackable.cs
  20. +64
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  21. +1
    -7
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  22. +70
    -0
      src/TensorFlowNET.Core/Variables/UninitializedVariable.cs
  23. +1
    -1
      src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
  24. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  25. +3
    -0
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  26. +3
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  27. +12
    -12
      src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs
  28. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs
  29. +54
    -9
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  30. +8
    -1
      src/TensorFlowNET.Keras/Utils/generic_utils.cs

+ 6
- 4
src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Train;
@@ -85,17 +86,18 @@ public static class CheckPointUtils
}
}

public static string get_full_name(Trackable var)
public static string get_full_name(Trackable variable)
{
// TODO: This state is not correct, the whole framework need to be updated in the future.
if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var)))
if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
{
return "";
}
// skip the check of attribute `_save_slice_info` .
// TODO: Need to be revised!!!
return ((ResourceVariable)(object)var).Name;
Debug.Assert(variable is BaseResourceVariable);
return ((BaseResourceVariable)variable).Name;
}

public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)


+ 4
- 3
src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint
);
public static class SaveUtil
{
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -39,7 +39,7 @@ namespace Tensorflow.Checkpoint
var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto);
var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto);

Dictionary<Tensor, string> feed_additions;
Dictionary<Tensor, object> feed_additions;
if(cache is null)
{
feed_additions = null;
@@ -125,7 +125,7 @@ namespace Tensorflow.Checkpoint
{
// TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
var trackable = td.object_to_save;
Trackable trackable = null;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{
@@ -134,6 +134,7 @@ namespace Tensorflow.Checkpoint
else
{
tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto);
trackable = td.object_to_save;
}
if(trackable is not null)
{


+ 9
- 9
src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs View File

@@ -44,19 +44,19 @@ public static class SaveUtilV1
return (checkpoint_factory_map, null);
}

public static (List<MySaveableObject>, object?) frozen_saveables_and_savers(ObjectGraphView graph_view,
public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures,
object? saveables_cache = null)
{
if (to_graph is not null)
{
to_graph.as_default();
var g = to_graph.as_default();
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
// tensorflow python: `with ops.device("/cpu:0")`
var serialized = graph_proto.ToByteString().ToString();
var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
g.Exit();
return (named_saveable_objects, registered_savers);
}
else
@@ -65,7 +65,7 @@ public static class SaveUtilV1
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
// tensorflow python: `with ops.device("/cpu:0")`
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
@@ -73,7 +73,7 @@ public static class SaveUtilV1
}
}

public static (List<MySaveableObject>, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view,
public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view,
IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
@@ -129,7 +129,7 @@ public static class SaveUtilV1
return object_graph_proto;
}

private static (List<MySaveableObject>, object?, object?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects,
TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids,
IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map,
bool call_with_mapped_captures, object? saveables_cache = null)
@@ -216,7 +216,7 @@ public static class SaveUtilV1

public record class CheckpointFactoryData
(
Maybe<ResourceVariable, MySaveableObject> factory,
Maybe<BaseResourceVariable, MySaveableObject> factory,
string name,
string checkpoint_key
);

+ 12
- 11
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -33,7 +33,7 @@ public class TrackableSaver
}

private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
gather_serialized_tensors(Tensor? object_graph_tensor = null)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache);
@@ -42,26 +42,27 @@ public class TrackableSaver

if(object_graph_tensor is null)
{
// tensorflow python: `with ops.device("/cpu:0"):`
object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
tf.device("/cpu:0");
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
}
else
{
feed_additions[object_graph_tensor] = graph_proto.ToString();
feed_additions[object_graph_tensor] = graph_proto.ToByteArray();
}
Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
if (serialized_tensors.ContainsKey(Trackable.None))
if (!serialized_tensors.ContainsKey(Trackable.None))
{
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>();
}
serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor;
return (serialized_tensors, feed_additions, registered_savers, graph_proto);
}

private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

Func<(Tensor, IDictionary<Tensor, string>)> run_save = () =>
Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
{
if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
{
@@ -86,11 +87,11 @@ public class TrackableSaver
return run_save();
}

private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options)
{
var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor);

Func<(Tensor, IDictionary<Tensor, string>)> run_save = () =>
Func<(Tensor, IDictionary<Tensor, object>)> run_save = () =>
{
if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function())
{
@@ -124,7 +125,7 @@ public class TrackableSaver
options = new CheckpointOptions();
}

Dictionary<Tensor, string> feed_dict = new();
Dictionary<Tensor, object> feed_dict = new();
bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function());
if (checkpoint_number is not null)
{


+ 39
- 21
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -12,6 +12,8 @@ using System.Linq;
using Tensorflow.Operations;
using Tensorflow.Training;
using Tensorflow.Graphs;
using System.Xml.Linq;
using System.Diagnostics;

namespace Tensorflow.Checkpoint
{
@@ -31,6 +33,10 @@ namespace Tensorflow.Checkpoint
{
return Func.DynamicInvoke(args);
}
/// <summary>
/// Calls the wrapped parameterless delegate directly and returns its strongly typed result.
/// </summary>
/// <returns>The value produced by <c>Func</c>.</returns>
public TR Invoke() => Func.Invoke();
}
internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder
{
@@ -164,7 +170,6 @@ namespace Tensorflow.Checkpoint
{
var slice_spec = slice.Key;
var maybe_tensor = slice.Value;
// TODO: deal with other types. Currently only `SaveSpec` is allowed.
if(maybe_tensor.DataType == typeof(SaveSpec))
{
var spec = maybe_tensor.GetValueB();
@@ -284,14 +289,16 @@ namespace Tensorflow.Checkpoint
var obj = pair.Key;
var tensor_dict = pair.Value;
IFunctionHolder restore_fn;
if(obj is null)
if(obj == Trackable.None)
{
restore_fn = new FunctionHolder<object?>(() => null);
}
else
{
restore_fn = null;
// TODO: implement obj._restore_from_tensors
restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x =>
{
return obj._restore_from_tensors(x);
});
}

foreach(var item in tensor_dict)
@@ -343,7 +350,7 @@ namespace Tensorflow.Checkpoint
}
}

public Operation save(string file_prefix, CheckpointOptions? options= null)
public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
{
if(options is null)
{
@@ -351,9 +358,9 @@ namespace Tensorflow.Checkpoint
}

tf.device("CPU"); // may be risky.
// TODO: optimize the implementation with new APIs adding to `string_ops`.
string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part";
var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix);
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
constant_op.constant(".part"), constant_op.constant("_temp/part"));
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));

Operation save_fn()
@@ -385,7 +392,7 @@ namespace Tensorflow.Checkpoint
{
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
tf.device(merge_device);
return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true);
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
}
}

@@ -400,9 +407,9 @@ namespace Tensorflow.Checkpoint
}
}

public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options);
public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options);

public IDictionary<string, Operation> restore(string file_prefix, CheckpointOptions? options = null)
public IDictionary<string, Operation> restore(Tensor file_prefix, CheckpointOptions? options = null)
{
if(options is null)
{
@@ -496,8 +503,10 @@ namespace Tensorflow.Checkpoint
public SaverDef to_proto()
{
var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename");
var save_tensor = _traced_save(filename_tensor);
var restore_op = _traced_restore(filename_tensor).op;
var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING);
var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING);
var save_tensor = traced_save_func(filename_tensor);
var restore_op = traced_restore_func(filename_tensor).op;
return new SaverDef()
{
FilenameTensorName = filename_tensor.name,
@@ -507,10 +516,9 @@ namespace Tensorflow.Checkpoint
};
}

[AutoGraph]
private Tensor _traced_save(Tensor file_prefix)
{
var save_op = save(file_prefix.StringData()[0]);
var save_op = save(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(new object[]{ save_op }))
{
@@ -518,24 +526,34 @@ namespace Tensorflow.Checkpoint
}
}

[AutoGraph]
private Tensor _traced_restore(Tensor file_prefix)
{
var restore_op = restore(file_prefix.StringData()[0]);
var restore_op = restore(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(new object[] { restore_op }))
using (ops.control_dependencies(restore_op.Values.ToArray()))
{
return array_ops.identity(file_prefix);
}
}

private static Tensor registered_saver_filename(string filename, string saver_name)
/// <summary>
/// Builds a <c>MultiDeviceSaver</c> from a sequence of saveable objects.
/// Each saveable is wrapped in its own <c>SaveableCompatibilityConverter</c>, and the tensors
/// returned by that converter's <c>serialize_to_tensors()</c> are stored under the converter.
/// </summary>
/// <param name="saveables">The saveable objects to wrap; each one yields one dictionary entry.</param>
/// <param name="registered_savers">Passed through unchanged to the <c>MultiDeviceSaver</c> constructor.</param>
/// <param name="call_with_mapped_captures">Passed through unchanged to the <c>MultiDeviceSaver</c> constructor.</param>
/// <returns>A new <c>MultiDeviceSaver</c> over the collected tensors.</returns>
public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
{
    var tensors_by_converter = new Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>();
    foreach (var item in saveables)
    {
        // One converter per saveable, holding only that saveable.
        var single = new List<MySaveableObject>() { item };
        var converter = new SaveableCompatibilityConverter(item, single);
        tensors_by_converter[converter] = converter.serialize_to_tensors();
    }
    return new MultiDeviceSaver(tensors_by_converter, registered_savers, call_with_mapped_captures);
}

private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name)
{
return tf.constant($"{filename}-{saver_name}");
return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") });
}
private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards)
{
return filename_tensor;
return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards);
}
}
}

+ 31
- 0
src/TensorFlowNET.Core/Eager/execute.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Xml.Linq;
using Tensorflow.Contexts;
using static Tensorflow.ApiDef.Types;
using static Tensorflow.CostGraphDef.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
/// <summary>
/// Helpers for executing a single op through the eager runner.
/// </summary>
internal class execute
{
    /// <summary>
    /// Converts every element of <paramref name="values"/> with <c>ops.convert_to_tensor</c>
    /// and reports the datatype enum of each converted tensor.
    /// </summary>
    /// <param name="values">Tensors to convert.</param>
    /// <param name="ctx">Context forwarded to <c>ops.convert_to_tensor</c>.</param>
    /// <returns>The datatype enums and the converted tensors, in the same order as the input.</returns>
    public static (DataType[], Tensor[]) convert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
    {
        // Materialize the conversion exactly once. A lazy Select would be re-run for each
        // enumeration, converting every tensor twice and deriving the dtypes from different
        // instances than the ones that are returned.
        var converted = values.Select(t => ops.convert_to_tensor(t, ctx: ctx)).ToArray();
        var types = converted.Select(t => t.dtype.as_datatype_enum()).ToArray();
        return (types, converted);
    }

    /// <summary>
    /// Misspelled name kept so existing callers keep compiling;
    /// forwards to <see cref="convert_to_mixed_eager_tensors"/>.
    /// </summary>
    public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx)
        => convert_to_mixed_eager_tensors(values, ctx);

    /// <summary>
    /// Executes <paramref name="op_name"/> through <c>tf.Runner.TFE_Execute</c>.
    /// </summary>
    /// <param name="op_name">Name of the op to run.</param>
    /// <param name="num_outputs">Number of outputs requested from the runner.</param>
    /// <param name="inputs">Flat list of input tensors.</param>
    /// <param name="attrs">Attribute list forwarded to the runner as-is.</param>
    /// <param name="ctx">Context to execute in; its device name is used for the op.</param>
    /// <param name="name">Currently unused by this method.</param>
    /// <returns>The tensors returned by the runner.</returns>
    public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
    {
        // The device name is read before ensure_initialized(), as in the original ordering.
        string device_name = ctx.DeviceName;

        ctx.ensure_initialized();
        return tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs);
    }
}
}

+ 23
- 0
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -406,5 +406,28 @@ namespace Tensorflow

meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true;
}

/// <summary>
/// Extract the Op name from a Tensor name.
/// </summary>
/// <param name="tensor_name">
/// Tensor name, e.g. <c>"scope/op:0"</c>. A single leading <c>'^'</c> is dropped
/// (presumably the control-input marker; confirm against callers).
/// </param>
/// <returns>
/// The part of the name before the first <c>':'</c>, or the whole name (minus a leading
/// <c>'^'</c>) when it contains no <c>':'</c>.
/// </returns>
/// <exception cref="ValueError">Thrown when <paramref name="tensor_name"/> is null or empty.</exception>
public static string op_name(string tensor_name)
{
    if (string.IsNullOrEmpty(tensor_name))
    {
        throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}.");
    }

    // Compare the first char directly: StartsWith(string) is culture-sensitive and can
    // match across ignorable characters, after which Substring(1) strips the wrong one.
    if (tensor_name[0] == '^')
    {
        tensor_name = tensor_name.Substring(1);
    }

    // IndexOf(char) is ordinal and avoids allocating the array produced by Split.
    var colon = tensor_name.IndexOf(':');
    return colon >= 0 ? tensor_name.Substring(0, colon) : tensor_name;
}
}
}

+ 37
- 1
src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs View File

@@ -14,11 +14,47 @@ namespace Tensorflow.ModelSaving
public IDictionary<string, object>? function_aliases { get; set; } = null;
public string? experimental_io_device { get; set; } = null;
// TODO: experimental
public Object? experimental_variable_polict { get; set; } = null;
public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None;
public bool experimental_custom_gradients { get; set; } = true;
public SaveOptions(bool save_debug_info = false)
{
this.save_debug_info = save_debug_info;
}
}

/// <summary>
/// Closed set of variable-saving policies. Instances are only created by this class and
/// are compared by reference, so the static members below act as enum-like sentinels.
/// </summary>
public class VariablePolicy
{
    /// <summary>String key of the policy; <c>null</c> for <see cref="None"/>.</summary>
    public string Policy { get; }

    private VariablePolicy(string policy)
    {
        Policy = policy;
    }

    // readonly: these are identity sentinels (see save_variable_devices); reassigning
    // one of them would silently break every reference comparison against it.
    public static readonly VariablePolicy None = new(null);
    public static readonly VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices");
    public static readonly VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables");

    /// <summary>
    /// Returns true for every policy other than <see cref="None"/>.
    /// </summary>
    public bool save_variable_devices()
    {
        return this != VariablePolicy.None;
    }

    /// <summary>
    /// Tries to convert `obj` to a VariablePolicy instance.
    /// </summary>
    /// <param name="obj">
    /// <c>null</c>, an existing <see cref="VariablePolicy"/>, or any object whose
    /// <c>ToString()</c> equals one of the policy keys, ignoring case.
    /// </param>
    /// <returns>The matching policy; <see cref="None"/> for <c>null</c>.</returns>
    /// <exception cref="ValueError">Thrown when the value matches no known policy.</exception>
    public static VariablePolicy from_obj(object obj)
    {
        if (obj is null)
        {
            return VariablePolicy.None;
        }
        if (obj is VariablePolicy policy)
        {
            return policy;
        }

        // Invariant lowering: culture-sensitive ToLower() maps 'I' to a dotless 'ı' under
        // e.g. tr-TR, so "SAVE_VARIABLE_DEVICES" would never match its key.
        // The null-conditional makes the `null` arm below reachable when ToString() returns null.
        var key = obj.ToString()?.ToLowerInvariant();
        return key switch
        {
            null => VariablePolicy.None,
            "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES,
            "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
            _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.")
        };
    }
}
}

+ 55
- 4
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -1,6 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
@@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations
/// path in the input checkpoint_prefixes. This is useful when those paths are non
/// user-facing temporary locations.
/// </remarks>
public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints")
{
public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files));
result = null;
return null;
//try
//{
// var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name,
// new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }));
// result = null;
// return null;
//}
//catch (System.Exception)
//{
// return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs,
// allow_missing_files: allow_missing_files, name: name, ctx: ctx);
//}
}
var dict = new Dictionary<string, object>();
dict["checkpoint_prefixes"] = checkpoint_prefixes;
dict["destination_prefix"] = destination_prefix;
if (delete_old_dirs.HasValue)
dict["delete_old_dirs"] = delete_old_dirs.Value;
dict["delete_old_dirs"] = delete_old_dirs;
var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict);
return op;
}

//public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx)
//{
// checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING);
// destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING);
// var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix };
// var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files };
// var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name);
// result = null;
// return null;
//}

/// <summary>
/// Transforms a spectrogram into a form that's useful for speech recognition.
/// </summary>
@@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern));
return result[0];
}
var dict = new Dictionary<string, object>();
dict["input"] = input;
dict["pattern"] = pattern;
@@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards));
return result[0];
}
var dict = new Dictionary<string, object>();
dict["basename"] = basename;
dict["shard"] = shard;
@@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin")
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator));
return result[0];
}
var dict = new Dictionary<string, object>();
dict["inputs"] = inputs;
if (separator != null)


+ 32
- 0
src/TensorFlowNET.Core/Operations/io_ops.cs View File

@@ -14,7 +14,9 @@
limitations under the License.
******************************************************************************/

using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -23,11 +25,41 @@ namespace Tensorflow
{
/// <summary>
/// Runs or builds the "SaveV2" op.
/// When the current context executes eagerly, the op goes through the fast-path runner;
/// if that throws, <c>save_v2_eager_fallback</c> is used instead. Both eager paths return null.
/// Otherwise the op is built through <c>tf.OpDefLib._apply_op_helper</c> and returned.
/// </summary>
/// <param name="prefix">Checkpoint prefix tensor.</param>
/// <param name="tensor_names">Names of the tensors to save.</param>
/// <param name="shape_and_slices">Shape-and-slice specs forwarded to the op.</param>
/// <param name="tensors">Tensors to save.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The created operation when not executing eagerly; null when executing eagerly.</returns>
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
{
    var ctx = tf.Context;
    if (!ctx.executing_eagerly())
    {
        return tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });
    }

    try
    {
        // SaveV2 produces no outputs, so the runner's return value is discarded.
        tf.Runner.TFE_FastPathExecute(
            new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors }));
        return null;
    }
    catch (System.Exception)
    {
        // Any fast-path failure is retried through the slower fallback path.
        return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx);
    }
}

/// <summary>
/// Eager fallback for "SaveV2": converts all arguments to tensors and runs the op
/// through <c>execute.quick_execute</c>.
/// </summary>
/// <param name="prefix">Checkpoint prefix; converted to a string tensor.</param>
/// <param name="tensor_names">Names of the tensors to save; converted to a string tensor.</param>
/// <param name="shape_and_slices">Shape-and-slice specs; converted to a string tensor.</param>
/// <param name="tensors">Tensors to save; their dtypes become the op's "dtypes" attribute.</param>
/// <param name="name">Optional op name forwarded to <c>quick_execute</c>.</param>
/// <param name="ctx">Context to execute in.</param>
/// <returns>Always null; the op is requested with zero outputs.</returns>
public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx)
{
    DataType[] attr_dtypes;
    (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx);
    prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING);
    var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING);
    var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING);

    // The flat inputs must follow the SaveV2 op signature:
    // prefix, tensor_names, shape_and_slices, then the tensors to save.
    // Putting the tensors first shifts every input into the wrong slot.
    var inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }
        .Concat(tensors)
        .ToArray();
    var attrs = new object[] { "dtypes", attr_dtypes };

    // Zero outputs are requested, so the returned array is not needed.
    execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name);
    return null;
}

public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });


+ 54
- 0
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -17,7 +17,9 @@
using System;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Variables;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
@@ -177,5 +179,57 @@ namespace Tensorflow
return HandleData.Parser.ParseFrom(handle.BufferToArray());
}
}

/// <summary>
/// Creates an <c>UninitializedVariable</c> that mirrors an existing resource variable
/// (trainable flag, shape, dtype, shared name, aggregation) without any initializer.
/// </summary>
/// <param name="variable">The variable whose properties are copied.</param>
/// <returns>The new variable, after <c>_maybe_initialize_trackable()</c> has been called on it.</returns>
public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable)
{
    UninitializedVariable copy = new(
        trainable: variable.Trainable,
        shape: variable.shape,
        dtype: variable.dtype,
        name: variable.SharedName,
        aggregation: variable.Aggregation,
        extra_handle_data: null);

    copy._maybe_initialize_trackable();
    return copy;
}

/// <summary>
/// Writes additional information of the variable into the SavedObject proto.
/// </summary>
/// <param name="resource_variable">Variable whose name, trainable flag, dtype, aggregation, shape and (optionally) device are recorded.</param>
/// <param name="proto">Target proto; its <c>Variable</c> message is created when missing.</param>
/// <param name="options">Save options; the device is written only when its variable policy asks for it.</param>
/// <param name="enforcing_naming">When true, a variable name that does not end in ":0" is rejected.</param>
/// <exception cref="ValueError">Thrown when <paramref name="enforcing_naming"/> is true and the name lacks the ":0" suffix.</exception>
public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true)
{
    // lack of API: `proto.Variable.SetInParent()`.
    // Ordinal comparison: the suffix is a machine identifier, not linguistic text, and a
    // culture-sensitive EndsWith can match across ignorable characters.
    if (enforcing_naming && !resource_variable.Name.EndsWith(":0", StringComparison.Ordinal))
    {
        throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " +
            $"unexpected suffix in the name (expected ':0') which won't be restored.");
    }

    proto.Variable ??= new SavedVariable();

    proto.Variable.Name = meta_graph.op_name(resource_variable.Name);
    proto.Variable.Trainable = resource_variable.Trainable;
    proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum();
    // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`.
    proto.Variable.Aggregation = resource_variable.Aggregation;
    proto.Variable.Shape = resource_variable.shape.as_proto();

    // The device is recorded only when the policy requests it and the variable has one.
    if (options.experimental_variable_policy.save_variable_devices()
        && !string.IsNullOrEmpty(resource_variable.Device))
    {
        proto.Variable.Device = resource_variable.Device;
    }
}
}
}

+ 28
- 0
src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using static Tensorflow.Binding;

namespace Tensorflow
{
public class ResourceVariableSaveable : MySaveableObject
@@ -35,6 +37,32 @@ namespace Tensorflow
this.name = name;
}

/// <summary>
/// Builds a saveable for a <c>BaseResourceVariable</c>. The variable's value is read
/// immediately during construction and wrapped in a single <c>SaveSpec</c>.
/// </summary>
/// <param name="var">Variable to save; also stored as this saveable's underlying op.</param>
/// <param name="slice_spec">Slice specification forwarded to the <c>SaveSpec</c>.</param>
/// <param name="name">Name used for both the <c>SaveSpec</c> and this saveable.</param>
public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name)
{
    _var_device = var.Device;
    _var_shape = var.shape;
    this.handle_op = var.Handle;

    Tensor value = read_value_on_cpu(var);
    SaveSpec save_spec = new SaveSpec(value, slice_spec, name, dtype: var.dtype);

    _op = var;
    specs = new SaveSpec[] { save_spec };
    this.name = name;

    // Reads the variable's value and returns an identity of it, or null when executing
    // eagerly and the variable is not initialized.
    // NOTE(review): the tf.device(...) calls appear to stand in for Python's
    // `with ops.device(...)` blocks; confirm how long each device setting stays in effect.
    Tensor read_value_on_cpu(BaseResourceVariable source)
    {
        tf.device(source.Device);
        if (tf.Context.executing_eagerly())
        {
            bool initialized = (bool)source.is_initialized().numpy();
            if (!initialized)
            {
                return null;
            }
        }
        var current = source.read_value_no_copy();
        tf.device("/device:CPU:0");
        return array_ops.identity(current);
    }
}

public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];


+ 22
- 2
src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs View File

@@ -14,11 +14,31 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Checkpoint;

namespace Tensorflow
{
public class MySaveableObject
{
public Tensor op;
protected Maybe<Tensor, BaseResourceVariable> _op;
/// <summary>
/// Tensor view of the underlying <c>_op</c>.
/// The getter throws a <c>TypeError</c> when <c>_op</c> does not hold a <c>Tensor</c>;
/// the setter stores the given tensor into <c>_op</c>.
/// </summary>
public Tensor op
{
    get
    {
        // Guard first: only the Tensor alternative of the Maybe can be returned here.
        if (_op.DataType != typeof(Tensor))
        {
            throw new TypeError("The _op is not a tensor.");
        }
        return _op.GetValueA();
    }
    set => _op = value;
}
public SaveSpec[] specs;
public string name;
public string device;
@@ -35,7 +55,7 @@ namespace Tensorflow

public MySaveableObject(Tensor op, SaveSpec[] specs, string name)
{
this.op = op;
this._op = op;
this.specs = specs;
this.name = name;
}


+ 7
- 7
src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs View File

@@ -10,6 +10,7 @@ using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;
using static Tensorflow.Binding;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow;

@@ -75,7 +76,7 @@ public class SaveableView
private void initialize_save_and_restore_functions()
{
// TODO: deal with the return value of `get_checkpoint_factories_and_keys`.
SaveUtilV1.get_checkpoint_factories_and_keys(_object_names);
var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names);
// skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver.
_obj_to_registered_saver = new();
_saveable_objects_map = new();
@@ -191,7 +192,7 @@ public class SaveableView
/// </summary>
/// <param name="asset_index"></param>
/// <returns></returns>
public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index, SaveOptions options)
public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index)
{
SavedObjectGraph proto = new();
fill_object_graph_proto(proto);
@@ -203,21 +204,20 @@ public class SaveableView
{
var obj = _nodes[i];
var obj_proto = proto.Nodes[i];
write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x),
options);
write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x));
}

return proto;
}

private static void write_object_proto(Trackable obj, SavedObject proto,
IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn, SaveOptions options)
IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn)
{
// skip the process of type Asset
if (resource_variable_ops.is_resource_variable(obj))
{
// TODO: complete it.
throw new NotImplementedException();
var options = SaveContext.get_save_options();
(obj as BaseResourceVariable).write_object_proto(proto, options);
}
else if (obj is Function)
{


+ 48
- 35
src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs View File

@@ -10,6 +10,7 @@ using Tensorflow.ModelSaving;
using Tensorflow.Train;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow;

@@ -43,7 +44,7 @@ public static partial class SavedModelUtils
{
SavedModelUtils.get_or_create_variables_dir(export_dir);
CheckpointOptions ckpt_options = new(options.experimental_io_device);
object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options);
object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options);
}
BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir);

@@ -68,6 +69,7 @@ public static partial class SavedModelUtils

var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB));
File.WriteAllBytes(path, saved_model.ToByteArray());
//File.WriteAllText(path, saved_model.ToString());

if (options.save_debug_info)
{
@@ -83,45 +85,48 @@ public static partial class SavedModelUtils
Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj,
ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null)
{
if (ops.inside_function())
using (SaveContext.save_context(options))
{
throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " +
"Move the call to the outer eagerly-executed context.");
}
if (ops.inside_function())
{
throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " +
"Move the call to the outer eagerly-executed context.");
}

if (meta_graph_def is null)
{
meta_graph_def = new MetaGraphDef();
}
if (meta_graph_def is null)
{
meta_graph_def = new MetaGraphDef();
}

AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
if (signatures is null)
{
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view);
}
// TODO: process of aignatures and wrapped_functions
AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj);
if (signatures is null)
{
signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view);
}

SaveableView saveable_view = new SaveableView(augmented_graph_view, options);
TrackableSaver object_saver = new TrackableSaver(augmented_graph_view);
var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures,
options.namespace_white_list, options.experimental_custom_gradients);
if (options.function_aliases is not null)
{
var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases;
foreach (var pair in options.function_aliases)
// TODO: process of aignatures and wrapped_functions

SaveableView saveable_view = new SaveableView(augmented_graph_view, options);
TrackableSaver object_saver = new TrackableSaver(augmented_graph_view);
var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures,
options.namespace_white_list, options.experimental_custom_gradients);
if (options.function_aliases is not null)
{
var alias = pair.Key;
var func = pair.Value;
// TODO: complete it.
throw new NotImplementedException();
var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases;
foreach (var pair in options.function_aliases)
{
var alias = pair.Key;
var func = pair.Value;
// TODO: complete it.
throw new NotImplementedException();
}
}
}

var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options);
meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto);
var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index);
meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto);

return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths);
return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths);
}
}

private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view,
@@ -134,7 +139,7 @@ public static partial class SavedModelUtils
Dictionary<Trackable, Trackable> object_map;
Dictionary<Tensor, Tensor> tensor_map;
AssetInfo asset_info;
exported_graph.as_default();
var g = exported_graph.as_default();
(object_map, tensor_map, asset_info) = saveable_view.map_resources();
// TODO: deal with signatures.
if (save_custom_gradients)
@@ -161,15 +166,23 @@ public static partial class SavedModelUtils
// Lack `CopyFrom` API
// meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY]

g.Exit();

foreach (var obj in object_map.Values)
{
obj._maybe_initialize_trackable();
}

// TODO: add the implementation of `call_with_mapped_functions`.
var (named_saveable_objects, registered_savers) =
SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false);
// TODO: complete the save of checkpoints with `MultiDeviceSaver`.
var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false);

var eg = exported_graph.as_default();
var saver_def = saver.to_proto();
meta_graph_def.SaverDef = saver_def;
eg.Exit();


saveable_view.dependency_sorted_node_ids();



+ 53
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs View File

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.ModelSaving;

namespace Tensorflow.Training.Saving.SavedModel
{
/// <summary>
/// A context for building a graph of SavedModel.
/// The state is held in static fields, so at most one context can be active at a time.
/// </summary>
public static class SaveContext
{
    // TODO: make it thread safe.
    private static bool _in_save_context = false;
    private static SaveOptions _save_options = null;

    /// <summary>
    /// Tells whether a save context is currently active.
    /// </summary>
    public static bool in_save_context()
    {
        return _in_save_context;
    }

    /// <summary>
    /// Returns the options of the active save context.
    /// Throws a <c>ValueError</c> when no context is active.
    /// </summary>
    public static SaveOptions get_save_options()
    {
        if (_in_save_context)
        {
            return _save_options;
        }
        throw new ValueError("Not in a SaveContext.");
    }

    /// <summary>
    /// Enters a save context with the given options. Dispose the returned handler
    /// (typically via <c>using</c>) to leave the context again.
    /// </summary>
    public static SaveContextHandler save_context(SaveOptions options) => new SaveContextHandler(options);

    /// <summary>
    /// Activates the save context on construction and restores the previous state on disposal.
    /// </summary>
    public class SaveContextHandler: IDisposable
    {
        private bool _old_in_save_context;
        private SaveOptions _old_save_options;

        /// <summary>
        /// Enters the context. Throws a <c>ValueError</c> if a context is already active.
        /// </summary>
        public SaveContextHandler(SaveOptions options)
        {
            if (SaveContext.in_save_context())
            {
                throw new ValueError("Already in a SaveContext.");
            }

            // Remember the previous state first, then switch to the new one.
            _old_in_save_context = SaveContext._in_save_context;
            _old_save_options = SaveContext._save_options;
            SaveContext._in_save_context = true;
            SaveContext._save_options = options;
        }

        /// <summary>
        /// Restores the state captured when this handler was created.
        /// </summary>
        public void Dispose()
        {
            SaveContext._save_options = _old_save_options;
            SaveContext._in_save_context = _old_in_save_context;
        }
    }
}
}

+ 5
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs View File

@@ -28,6 +28,11 @@ public static partial class SavedModelUtils
return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY));
}

/// <summary>
/// Returns the variables directory of <paramref name="export_dir"/> combined with
/// <c>Constants.VARIABLES_FILENAME</c>.
/// </summary>
/// <param name="export_dir">Root directory of the exported model.</param>
public static string get_variables_path(string export_dir)
{
    var variables_dir = tf.compat.as_text(get_variables_dir(export_dir));
    var variables_filename = tf.compat.as_text(Constants.VARIABLES_FILENAME);
    return Path.Combine(variables_dir, variables_filename);
}

/// <summary>
/// Return assets sub-directory, or create one if it doesn't exist.
/// </summary>


+ 83
- 26
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -19,6 +19,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using Tensorflow.Training;
using static Tensorflow.Binding;
@@ -117,8 +118,7 @@ namespace Tensorflow
}
else
{
Debug.Assert(variable is ResourceVariable);
yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name);
yield return new ResourceVariableSaveable(variable, "", name);
}
}
else
@@ -215,7 +215,7 @@ namespace Tensorflow
return names_to_saveables;
}

public static IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj)
{
// skip the process of type `PythonState`

@@ -251,7 +251,7 @@ namespace Tensorflow
specs.Add(new SaveSpec(item.Value, item.Key, spec_name));
}
}
Dictionary<string, Maybe<ResourceVariable, MySaveableObject>> res = new();
Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new();
res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix);
return res;
}
@@ -270,25 +270,6 @@ namespace Tensorflow
{
return tf.compat.as_str(x);
}
}

public class SaveableCompatibilityConverter: Trackable
{
private Trackable _obj;
private IList<MySaveableObject> _saveables;
public SaveableCompatibilityConverter(Trackable obj, IList<MySaveableObject> saveables)
{
_obj= obj;
_saveables= saveables;
}

public Trackable Obj => _obj;
public IList<MySaveableObject> mySaveables=> _saveables;

public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
{
return saveable_object_to_tensor_dict(_saveables);
}

/// <summary>
/// Converts a list of SaveableObjects to a tensor dictionary.
@@ -299,11 +280,11 @@ namespace Tensorflow
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
foreach (var saveable in saveables)
{
foreach(var spec in saveable.specs)
foreach (var spec in saveable.specs)
{
// skip the check that if `spec` is callable.
var name = saveable_object_util.convert_to_string(spec.name);
var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec);
var name = convert_to_string(spec.name);
var slice_spec = convert_to_string(spec.slice_spec);
if (!string.IsNullOrEmpty(slice_spec))
{
tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor;
@@ -316,5 +297,81 @@ namespace Tensorflow
}
return tensor_dict;
}

/// <summary>
/// Generates `Trackable._restore_from_tensors` from SaveableObjects.
/// </summary>
/// <param name="saveables">The saveables whose <c>restore</c> methods the returned function invokes.</param>
/// <returns>
/// A function that maps restored tensors (keyed by local spec name) to the restore
/// operation of each saveable (keyed by the saveable's name).
/// </returns>
public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables)
{
    return (restored_tensors) =>
    {
        var ops_by_saveable = new Dictionary<string, Operation>();

        foreach (var saveable in saveables)
        {
            var tensors_for_saveable = new List<Tensor>();
            foreach (var spec in saveable.specs)
            {
                var local_name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name));
                var slice_key = saveable_object_util.convert_to_string(spec.slice_spec);

                var entry = restored_tensors[local_name];
                // A bare tensor is treated as a one-entry map under the empty slice key,
                // so both shapes of entry can be indexed the same way below.
                IDictionary<string, Tensor> tensors_by_slice = entry.DataType == typeof(Tensor)
                    ? new Dictionary<string, Tensor> { [""] = entry.GetValueA() }
                    : entry.GetValueB();
                tensors_for_saveable.Add(tensors_by_slice[slice_key]);
            }
            ops_by_saveable[saveable.name] = saveable.restore(tensors_for_saveable.ToArray(), null);
        }
        return ops_by_saveable;
    };
}
}

/// <summary>
/// Wraps a list of saveables so that they can be serialized and restored
/// through the <c>Trackable</c> tensor-based API.
/// </summary>
public class SaveableCompatibilityConverter: Trackable
{
    private object _obj;
    private IList<MySaveableObject> _saveables;

    /// <summary>
    /// Creates the converter.
    /// </summary>
    /// <param name="obj">The object the saveables belong to; only used in error messages here.</param>
    /// <param name="saveables">The saveables to serialize and restore.</param>
    public SaveableCompatibilityConverter(object obj, IList<MySaveableObject> saveables)
    {
        _obj = obj;
        _saveables = saveables;
    }

    public object Obj => _obj;
    public IList<MySaveableObject> mySaveables => _saveables;

    /// <summary>
    /// Converts the wrapped saveables to a tensor dictionary.
    /// </summary>
    public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors()
    {
        return saveable_object_util.saveable_object_to_tensor_dict(_saveables);
    }

    /// <summary>
    /// Returns the restore ops defined in the Saveables.
    /// </summary>
    /// <param name="restored_tensors">Restored tensors keyed by the local name of each spec.</param>
    /// <returns>The restore operation of each saveable, keyed by the saveable's name.</returns>
    /// <remarks>
    /// Throws a <c>ValueError</c> when the set of keys in <paramref name="restored_tensors"/>
    /// differs from the set of local names expected by the saveables.
    /// </remarks>
    public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
    {
        // Collect the expected keys as a set: the comparison below must not depend on
        // the enumeration order of either side, only on which keys are present.
        HashSet<string> expected_keys = new();
        foreach (var saveable in _saveables)
        {
            expected_keys.UnionWith(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name))));
        }
        if (!expected_keys.SetEquals(restored_tensors.Keys))
        {
            // Join the keys explicitly; interpolating a collection would only print its type name.
            throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." +
                $"\n\tExpected: {string.Join(", ", expected_keys)} \n\tGot: {string.Join(", ", restored_tensors.Keys)}");
        }
        return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors);
    }
}
}

+ 9
- 4
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -42,11 +42,11 @@ namespace Tensorflow.Train

protected IList<TrackableReference> _unconditional_checkpoint_dependencies;

protected IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>();
protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories =
new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>();
private bool _manual_tracking = true;

private static Trackable _none = new Function();
private static Trackable _none = new AutoTrackable();
/// <summary>
/// This is a trick for that CSharp does not allow the key of `Dictionary` to be null.
/// The `None` can be any object that inherits `Trackable`.
@@ -225,7 +225,7 @@ namespace Tensorflow.Train
return self_tensor_map.Keys.ToList();
}

public virtual IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
if (saveable_object_util.trackable_has_serialize_to_tensor(this))
{
@@ -251,6 +251,11 @@ namespace Tensorflow.Train
{
throw new NotImplementedException();
}

/// <summary>
/// Restores this object from the given tensors. The base implementation is not
/// available and always throws <see cref="NotImplementedException"/>.
/// </summary>
/// <param name="restored_tensors">Restored tensors keyed by name.</param>
public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors)
    => throw new NotImplementedException();
}

public record class TrackableReference(string Name, Trackable Refer);


+ 64
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -6,6 +6,8 @@ using Tensorflow.Train;
using static Tensorflow.Binding;
using System.Collections.Generic;
using Tensorflow.ModelSaving;
using System.Diagnostics;
using Tensorflow.Checkpoint;

namespace Tensorflow
{
@@ -13,6 +15,7 @@ namespace Tensorflow
{
protected string _name;
public virtual string Name => _handle_name;
public virtual string SharedName => _name;
protected TF_DataType _dtype;
public TF_DataType dtype => _dtype;
protected string _handle_name;
@@ -50,6 +53,7 @@ namespace Tensorflow
public Graph Graph => handle.graph;
public string Device => handle.Device;
EagerResourceDeleter eager_resource_deleter;
public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None;

public BaseResourceVariable()
{
@@ -77,6 +81,11 @@ namespace Tensorflow
_handle = handle.EagerTensorHandle.DangerousGetHandle();
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);
}
else if(handle is null)
{
// TODO: fix this dangerous change.
_handle = IntPtr.Zero;
}
else
{
_handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle();
@@ -247,5 +256,60 @@ namespace Tensorflow
else
return value();
}

/// <summary>
/// Creates an uninitialized copy of this variable via
/// <c>resource_variable_ops.copy_to_graph_uninitialized</c> and returns the mappings
/// from this variable to the copy and from this handle to the copy's handle.
/// </summary>
/// <param name="save_options">Options whose variable policy decides whether the device is set first.</param>
/// <returns>The object map and the resource (handle) map, each holding a single entry.</returns>
public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options)
{
    // Only the device handling differs between the two policies; the copy itself is the same.
    if (save_options.experimental_variable_policy.save_variable_devices())
    {
        tf.device(this.Device);
        Debug.Assert(this is ResourceVariable);
    }
    // NOTE(review): the cast fails for variables that are not `ResourceVariable` — confirm callers.
    BaseResourceVariable copied_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);

    var object_mapping = new Dictionary<Trackable, Trackable> { [this] = copied_variable };
    var handle_mapping = new Dictionary<Tensor, Tensor> { [this.handle] = copied_variable.handle };
    return (object_mapping, handle_mapping);
}

/// <summary>
/// Writes additional information of the variable into the SavedObject proto.
/// Subclasses of ResourceVariables could choose to override this method to
/// customize extra information to provide when saving a SavedModel.
/// </summary>
/// <param name="proto">The SavedObject proto to fill in.</param>
/// <param name="options">The save options forwarded to the writer.</param>
public virtual void write_object_proto(SavedObject proto, SaveOptions options)
    => resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options);

/// <summary>
/// Returns a single-entry dictionary that maps
/// <c>Trackable.Constants.VARIABLE_VALUE_KEY</c> to this variable.
/// </summary>
public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
    return new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>
    {
        [Trackable.Constants.VARIABLE_VALUE_KEY] = this
    };
}

/// <summary>
/// Returns the result of <c>var_is_initialized_op</c> applied to this variable's handle.
/// </summary>
/// <param name="name">Optional name passed to the op.</param>
public Tensor is_initialized(string name = null)
    => gen_resource_variable_ops.var_is_initialized_op(this.handle, name);

/// <summary>
/// Reads the variable inside a "Read" name scope and returns the identity of the value read.
/// </summary>
public Tensor read_value_no_copy()
{
    Tensor read_result = null;
    tf_with(ops.name_scope("Read"), _ =>
    {
        // TODO: `no_copy = true`.
        read_result = _read_variable_op();
    });
    return array_ops.identity(read_result);
}
}
}

+ 1
- 7
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -41,6 +41,7 @@ namespace Tensorflow
VariableAggregation aggregation = VariableAggregation.None,
Shape shape = null)
{
Aggregation = aggregation;
if (variable_def != null)
{
if (initial_value != null)
@@ -237,12 +238,5 @@ namespace Tensorflow
{
return _graph_element.eval(session);
}

public override IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint()
{
var res = new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>();
res[Trackable.Constants.VARIABLE_VALUE_KEY] = this;
return res;
}
}
}

+ 70
- 0
src/TensorFlowNET.Core/Variables/UninitializedVariable.cs View File

@@ -0,0 +1,70 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Gradients;
using static Tensorflow.Binding;

namespace Tensorflow.Variables
{
/// <summary>
/// A variable with no initializer.
/// </summary>
public sealed class UninitializedVariable: BaseResourceVariable
{
    // TODO: complete the arg list.
    /// <summary>
    /// Creates the variable handle from <paramref name="shape"/> and <paramref name="dtype"/>
    /// without assigning any initial value, then initializes the base variable with it.
    /// </summary>
    /// <param name="trainable">Forwarded to the base initializer.</param>
    /// <param name="caching_device">Currently unused by this constructor.</param>
    /// <param name="name">Name used to open the variable's name scope.</param>
    /// <param name="dtype">Data type of the variable.</param>
    /// <param name="aggregation">Currently unused by this constructor.</param>
    /// <param name="shape">Shape of the variable.</param>
    /// <param name="extra_handle_data">Forwarded to the handle creation.</param>
    public UninitializedVariable(
        bool trainable = true,
        string caching_device = "",
        string name = null,
        TF_DataType dtype = TF_DataType.DtInvalid,
        VariableAggregation aggregation = VariableAggregation.None,
        Shape shape = null,
        Tensor extra_handle_data = null)
    {
        string unique_id = "";
        string handle_name = "";
        // The handle is created inside the nested scopes below. It has to be captured in a
        // local declared here, because outside the lambdas the bare name `handle` refers to
        // the inherited (still unset) member rather than to the freshly created handle.
        Tensor created_handle = null;
        tf_with(ops.init_scope(), (x) =>
        {
            _in_graph_mode = !tf.Context.executing_eagerly();
            tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), scope =>
            {
                handle_name = ops.name_from_scope_name(scope);
                string? shared_name;
                if (_in_graph_mode)
                {
                    shared_name = handle_name;
                    unique_id = shared_name;
                }
                else
                {
                    unique_id = $"{handle_name}-{ops.uid()}";
                    shared_name = null;
                }
                created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
                    shape, dtype, shared_name, scope, _in_graph_mode, extra_handle_data);
                // skip the assignment of `handle._parent_trackable` because of lack of API.
                // skip the assignment of `handle._name` and `handle._unique_id` because of accessibility.

                if (_in_graph_mode)
                {
                    tf_with(ops.name_scope("Read"), _ =>
                    {
                        tf.device(created_handle.Device);
                        var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
                        // _maybe_set_handle_data(dtype, handle, value)
                        _graph_element = value;
                    });
                    ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
                }
                else
                {
                    _graph_element = null;
                }
            });
        });
        _shape = shape;
        _dtype = dtype;
        base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
    }
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Engine
}
}

var layer_config = generic_utils.serialize_keras_object(layer);
var layer_config = generic_utils.serialize_layer_to_config(layer);
layer_config.Name = layer.Name;
layer_config.InboundNodes = filtered_inbound_nodes;
layer_configs.Add(layer_config);


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine;

public abstract partial class Layer
{
public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);
public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this);

public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;



+ 3
- 0
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -18,6 +18,7 @@ using System.Linq;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving.SavedModel;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -105,5 +106,7 @@ namespace Tensorflow.Keras.Layers
{
return new InputLayer(args as InputLayerArgs);
}

public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this);
}
}

+ 3
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -55,6 +55,7 @@ public partial class KerasSavedModelUtils

var metadata = generate_keras_metadata(saved_nodes, node_paths);
File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray());
//File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString());

if (!include_optimizer)
{
@@ -100,7 +101,8 @@ public partial class KerasSavedModelUtils
Identifier = layer.ObjectIdentifier,
Metadata = layer.TrackingMetadata
};

metadata.Nodes.Add(saved_object);
}

return metadata;


+ 12
- 12
src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs View File

@@ -24,26 +24,26 @@ public partial class KerasSavedModelUtils
// TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs.

// TODO: change the inherits of `Variable` and revise the implmentation.
var variables = layer.Variables.Select(x =>
var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
});
var trainable_variables = layer.TrainableVariables.Select(x =>
}));
var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
});
var non_trainable_variables = layer.non_trainable_variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
});
}));
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x =>
{
if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
}));

Dictionary<string, Trackable> res = new();
res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables);
res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables);
res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables);
res["variables"] = variables;
res["trainable_variables"] = trainable_variables;
res["non_trainable_variables"] = non_trainable_variables;
res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable()));

return res;


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Saving.SavedModel;

public abstract class SavedModelSaver
{
private Trackable _obj;
protected Trackable _obj;
public SavedModelSaver(Trackable obj)
{
_obj = obj;


+ 54
- 9
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -2,6 +2,7 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;

@@ -9,10 +10,11 @@ namespace Tensorflow.Keras.Saving.SavedModel;

public class LayerSavedModelSaver: SavedModelSaver
{
private Layer _obj;
private Layer _layer;
public LayerSavedModelSaver(Layer obj): base(obj)
{
_obj = obj;
_layer = obj;
}
public override string ObjectIdentifier
{
@@ -68,8 +70,8 @@ public class LayerSavedModelSaver: SavedModelSaver
/// <param name="serialization_cache"></param>
private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
{
var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache);
var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache);
var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache);
var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache);

functions["_default_save_signature"] = null;

@@ -81,17 +83,18 @@ public class LayerSavedModelSaver: SavedModelSaver
get
{
JObject metadata = new JObject();
metadata["name"] = _obj.Name;
metadata["trainable"] = _obj.Trainable;
metadata["name"] = _layer.Name;
metadata["trainable"] = _layer.Trainable;
// metadata["expects_training_arg"] = _obj._expects_training_arg;
// metadata["dtype"] = policy.serialize(_obj._dtype_policy)
metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape);
metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape);
// metadata["stateful"] = _obj.stateful;
// metadata["must_restore_from_config"] = _obj.must_restore_from_config;
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config;
metadata["autocast"] = _obj.AutoCast;
metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings
metadata["autocast"] = _layer.AutoCast;

var temp = JObject.FromObject(get_serialized(_layer));
metadata.Merge(temp, new JsonMergeSettings
{
// Handle conflicts by using values from obj2
MergeArrayHandling = MergeArrayHandling.Merge
@@ -108,4 +111,46 @@ public class LayerSavedModelSaver: SavedModelSaver
return new Dictionary<string, object>();
//return generic_utils.serialize_keras_object(obj);
}
}

/// <summary>
/// SavedModel saver for input layers. It exposes no functions or child objects to
/// serialize; its tracking metadata is a JSON description of the layer.
/// </summary>
public class InputLayerSavedModelSaver: SavedModelSaver
{
    public InputLayerSavedModelSaver(Layer obj) : base(obj)
    {
    }

    public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER;

    /// <summary>
    /// Always returns an empty dictionary.
    /// </summary>
    public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
        => new Dictionary<string, Trackable>();

    /// <summary>
    /// Always returns an empty dictionary.
    /// </summary>
    public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache)
        => new Dictionary<string, Trackable>();

    /// <summary>
    /// JSON string with the layer's class name, name, dtype, batch input shape and config.
    /// Throws a <c>TypeError</c> when the wrapped object is not a <c>Layer</c>.
    /// </summary>
    public override string TrackingMetadata
    {
        get
        {
            if (_obj is not Layer input_layer)
            {
                throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer.");
            }

            // The member names below become the JSON keys, so they must stay as they are.
            var description = new
            {
                class_name = input_layer.GetType().Name,
                name = input_layer.Name,
                dtype = input_layer.DType,
                //sparse = layer.sparse,
                //ragged = layer.ragged,
                batch_input_shape = input_layer.BatchInputShape,
                config = input_layer.get_config()
            };
            return JsonConvert.SerializeObject(description);
        }
    }
}

+ 8
- 1
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Saving;

@@ -22,7 +24,12 @@ namespace Tensorflow.Keras.Utils
{
public class generic_utils
{
public static LayerConfig serialize_keras_object(ILayer instance)
/// <summary>
/// This method has no direct counterpart in Python; the closest equivalent is `serialize_keras_object`.
/// </summary>
/// <param name="instance"></param>
/// <returns></returns>
public static LayerConfig serialize_layer_to_config(ILayer instance)
{
var config = instance.get_config();
return new LayerConfig


Loading…
Cancel
Save