diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 70d77155..cd37703b 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using Tensorflow.Train; @@ -85,17 +86,18 @@ public static class CheckPointUtils } } - public static string get_full_name(Trackable var) + public static string get_full_name(Trackable variable) { // TODO: This state is not correct, the whole framework need to be updated in the future. - if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) + if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable))) { return ""; } // skip the check of attribute `_save_slice_info` . - + // TODO: Need to be revised!!! - return ((ResourceVariable)(object)var).Name; + Debug.Assert(variable is BaseResourceVariable); + return ((BaseResourceVariable)variable).Name; } public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index e646f1f0..84e0ca4e 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint ); public static class SaveUtil { - public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -39,7 +39,7 @@ namespace Tensorflow.Checkpoint var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto); var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto); - Dictionary feed_additions; + Dictionary feed_additions; if(cache is null) { feed_additions = null; @@ -125,7 +125,7 @@ namespace Tensorflow.Checkpoint { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; - var trackable = td.object_to_save; + Trackable trackable = null; IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { @@ -134,6 +134,7 @@ namespace Tensorflow.Checkpoint else { tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); + trackable = td.object_to_save; } if(trackable is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index d8e251ec..4f1d04d2 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -44,19 +44,19 @@ public static class SaveUtilV1 return (checkpoint_factory_map, null); } - public static (List, object?) frozen_saveables_and_savers(ObjectGraphView graph_view, + public static (List, IDictionary>?) frozen_saveables_and_savers(ObjectGraphView graph_view, IDictionary object_map, Graph? to_graph, bool call_with_mapped_captures, object? 
saveables_cache = null) { if (to_graph is not null) { - to_graph.as_default(); + var g = to_graph.as_default(); var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` - var serialized = graph_proto.ToByteString().ToString(); - var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); + tf.device("/cpu:0"); + var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); + g.Exit(); return (named_saveable_objects, registered_savers); } else @@ -65,7 +65,7 @@ public static class SaveUtilV1 { var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - // tensorflow python: `with ops.device("/cpu:0")` + tf.device("/cpu:0"); var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); return (named_saveable_objects, registered_savers); @@ -73,7 +73,7 @@ public static class SaveUtilV1 } } - public static (List, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, + public static (List, TrackableObjectGraph, object?, IDictionary>?) serialize_gathered_objects(ObjectGraphView graph_view, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); @@ -129,7 +129,7 @@ public static class SaveUtilV1 return object_graph_proto; } - private static (List, object?, object?) add_attributes_to_object_graph(IList trackable_objects, + private static (List, object?, IDictionary>?) add_attributes_to_object_graph(IList trackable_objects, TrackableObjectGraph object_graph_proto, IDictionary node_ids, IDictionary object_names, IDictionary object_map, bool call_with_mapped_captures, object? saveables_cache = null) @@ -216,7 +216,7 @@ public static class SaveUtilV1 public record class CheckpointFactoryData ( - Maybe factory, + Maybe factory, string name, string checkpoint_key ); \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index c9bee0db..0c2862da 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -33,7 +33,7 @@ public class TrackableSaver } - private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? 
object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -42,26 +42,27 @@ public class TrackableSaver if(object_graph_tensor is null) { - // tensorflow python: `with ops.device("/cpu:0"):` - object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + tf.device("/cpu:0"); + object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); } else { - feed_additions[object_graph_tensor] = graph_proto.ToString(); + feed_additions[object_graph_tensor] = graph_proto.ToByteArray(); } Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); - if (serialized_tensors.ContainsKey(Trackable.None)) + if (!serialized_tensors.ContainsKey(Trackable.None)) { - serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + serialized_tensors[Trackable.None] = new Dictionary>>(); } + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; return (serialized_tensors, feed_additions, registered_savers, graph_proto); } - private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + private (Tensor, IDictionary) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); - Func<(Tensor, IDictionary)> run_save = () => + Func<(Tensor, IDictionary)> run_save = () => { if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) { @@ -86,11 +87,11 @@ public class TrackableSaver return run_save(); } - private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) + private (Tensor, IDictionary) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); - Func<(Tensor, IDictionary)> run_save = () => + Func<(Tensor, IDictionary)> run_save = () => { if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) { @@ -124,7 +125,7 @@ public class TrackableSaver options = new CheckpointOptions(); } - Dictionary feed_dict = new(); + Dictionary feed_dict = new(); bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); if (checkpoint_number is not null) { diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index c4a03985..90bbccf0 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -12,6 +12,8 @@ using System.Linq; using Tensorflow.Operations; using Tensorflow.Training; using Tensorflow.Graphs; +using System.Xml.Linq; +using System.Diagnostics; namespace Tensorflow.Checkpoint { @@ -31,6 +33,10 @@ namespace Tensorflow.Checkpoint { return Func.DynamicInvoke(args); } + public TR Invoke() + { + return Func.Invoke(); + } } internal record class FunctionHolder(Func Func) : IFunctionHolder { @@ -164,7 +170,6 @@ namespace Tensorflow.Checkpoint { var 
slice_spec = slice.Key; var maybe_tensor = slice.Value; - // TODO: deal with other types. Currently only `SaveSpec` is allowed. if(maybe_tensor.DataType == typeof(SaveSpec)) { var spec = maybe_tensor.GetValueB(); @@ -284,14 +289,16 @@ namespace Tensorflow.Checkpoint var obj = pair.Key; var tensor_dict = pair.Value; IFunctionHolder restore_fn; - if(obj is null) + if(obj == Trackable.None) { restore_fn = new FunctionHolder(() => null); } else { - restore_fn = null; - // TODO: implement obj._restore_from_tensors + restore_fn = new FunctionHolder>>, IDictionary>(x => + { + return obj._restore_from_tensors(x); + }); } foreach(var item in tensor_dict) @@ -343,7 +350,7 @@ namespace Tensorflow.Checkpoint } } - public Operation save(string file_prefix, CheckpointOptions? options= null) + public Operation save(Tensor file_prefix, CheckpointOptions? options= null) { if(options is null) { @@ -351,9 +358,9 @@ namespace Tensorflow.Checkpoint } tf.device("CPU"); // may be risky. - // TODO: optimize the implementation with new APIs adding to `string_ops`. - string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; - var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); + var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), + constant_op.constant(".part"), constant_op.constant("_temp/part")); + var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); Operation save_fn() @@ -385,7 +392,7 @@ namespace Tensorflow.Checkpoint { string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; tf.device(merge_device); - return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); } } @@ -400,9 +407,9 @@ namespace Tensorflow.Checkpoint } } - public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); + public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); - public IDictionary restore(string file_prefix, CheckpointOptions? options = null) + public IDictionary restore(Tensor file_prefix, CheckpointOptions? 
options = null) { if(options is null) { @@ -496,8 +503,10 @@ namespace Tensorflow.Checkpoint public SaverDef to_proto() { var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); - var save_tensor = _traced_save(filename_tensor); - var restore_op = _traced_restore(filename_tensor).op; + var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); + var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); + var save_tensor = traced_save_func(filename_tensor); + var restore_op = traced_restore_func(filename_tensor).op; return new SaverDef() { FilenameTensorName = filename_tensor.name, @@ -507,10 +516,9 @@ namespace Tensorflow.Checkpoint }; } - [AutoGraph] private Tensor _traced_save(Tensor file_prefix) { - var save_op = save(file_prefix.StringData()[0]); + var save_op = save(file_prefix); tf.device("cpu:0"); using (ops.control_dependencies(new object[]{ save_op })) { @@ -518,24 +526,34 @@ namespace Tensorflow.Checkpoint } } - [AutoGraph] private Tensor _traced_restore(Tensor file_prefix) { - var restore_op = restore(file_prefix.StringData()[0]); + var restore_op = restore(file_prefix); tf.device("cpu:0"); - using (ops.control_dependencies(new object[] { restore_op })) + using (ops.control_dependencies(restore_op.Values.ToArray())) { return array_ops.identity(file_prefix); } } - private static Tensor registered_saver_filename(string filename, string saver_name) + public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? registered_savers = null, bool call_with_mapped_captures = false) + { + Dictionary>>> serialized_tensors = new(); + foreach (var saveable in saveables) + { + var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); + serialized_tensors[trackable] = trackable.serialize_to_tensors(); + } + return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); + } + + private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) { - return tf.constant($"{filename}-{saver_name}"); + return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); } private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) { - return filename_tensor; + return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); } } } diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs new file mode 100644 index 00000000..cb3ea4d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Contexts; +using static Tensorflow.ApiDef.Types; +using static Tensorflow.CostGraphDef.Types; +using static Tensorflow.Binding; + +namespace Tensorflow.Eager +{ + internal class execute + { + public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) + { + var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); + var types = v.Select(t => t.dtype.as_datatype_enum()); + return (types.ToArray(), v.ToArray()); + } + public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + string device_name = ctx.DeviceName; + + ctx.ensure_initialized(); + var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, 
num_outputs); + + return tensors; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index cce13b55..c3616faf 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -406,5 +406,28 @@ namespace Tensorflow meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; } + + /// + /// Extract the Op name from a Tensor name. + /// + /// + /// + public static string op_name(string tensor_name) + { + if (string.IsNullOrEmpty(tensor_name)) + { + throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); + } + + if (tensor_name.StartsWith("^")) + { + tensor_name = tensor_name.Substring(1); + } + if (tensor_name.Contains(":")) + { + return tensor_name.Split(':')[0]; + } + return tensor_name; + } } } diff --git a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs index fce42850..45ebd884 100644 --- a/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs +++ b/src/TensorFlowNET.Core/ModelSaving/SaveOptions.cs @@ -14,11 +14,47 @@ namespace Tensorflow.ModelSaving public IDictionary? function_aliases { get; set; } = null; public string? experimental_io_device { get; set; } = null; // TODO: experimental - public Object? experimental_variable_polict { get; set; } = null; + public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; public bool experimental_custom_gradients { get; set; } = true; public SaveOptions(bool save_debug_info = false) { this.save_debug_info = save_debug_info; } } + + public class VariablePolicy + { + public string Policy { get; } + private VariablePolicy(string policy) + { + Policy = policy; + } + public static VariablePolicy None = new(null); + public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); + public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); + + public bool save_variable_devices() + { + return this != VariablePolicy.None; + } + + /// + /// Tries to convert `obj` to a VariablePolicy instance. + /// + /// + /// + public static VariablePolicy from_obj(object obj) + { + if (obj is null) return VariablePolicy.None; + if (obj is VariablePolicy) return (VariablePolicy)obj; + var key = obj.ToString().ToLower(); + return key switch + { + null => VariablePolicy.None, + "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, + "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, + _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") + }; + } + } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 11cb6de8..956be96b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -1,6 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Xml.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow.Operations @@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations /// path in the input checkpoint_prefixes. This is useful when those paths are non /// user-facing temporary locations. /// - public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? 
delete_old_dirs = null, string name = "MergeV2Checkpoints") - { + public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); + result = null; + return null; + //try + //{ + // var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, + // new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); + // result = null; + // return null; + //} + //catch (System.Exception) + //{ + // return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, + // allow_missing_files: allow_missing_files, name: name, ctx: ctx); + //} + } var dict = new Dictionary(); dict["checkpoint_prefixes"] = checkpoint_prefixes; dict["destination_prefix"] = destination_prefix; - if (delete_old_dirs.HasValue) - dict["delete_old_dirs"] = delete_old_dirs.Value; + dict["delete_old_dirs"] = delete_old_dirs; var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); return op; } + //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) + //{ + // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); + // destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); + // var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; + // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; + // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); + // result = null; + // return null; + //} + /// /// Transforms a spectrogram into a form that's useful for speech recognition. 
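// The gen_ops edits in this file all follow the same eager dispatch pattern: when the
// context executes eagerly, run the op through the fast-path runner and return its first
// output; otherwise fall back to building a graph op via _apply_op_helper. A condensed
// sketch of that pattern ("SomeOp", input1 and input2 are placeholders, not a real op):
//
//     var ctx = tf.Context;
//     if (ctx.executing_eagerly())
//     {
//         var result = tf.Runner.TFE_FastPathExecute(
//             new FastPathOpExecInfo("SomeOp", name, input1, input2));
//         return result[0];
//     }
//     var dict = new Dictionary<string, object>();
//     dict["input1"] = input1;
//     dict["input2"] = input2;
//     var op = tf.OpDefLib._apply_op_helper("SomeOp", name: name, keywords: dict);
//     return op.outputs[0];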
/// @@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations /// public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); + return result[0]; + } var dict = new Dictionary(); dict["input"] = input; dict["pattern"] = pattern; @@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations /// public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); + return result[0]; + } var dict = new Dictionary(); dict["basename"] = basename; dict["shard"] = shard; @@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations /// public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); + return result[0]; + } var dict = new Dictionary(); dict["inputs"] = inputs; if (separator != null) diff --git a/src/TensorFlowNET.Core/Operations/io_ops.cs b/src/TensorFlowNET.Core/Operations/io_ops.cs index 4f276e36..35c5877f 100644 --- a/src/TensorFlowNET.Core/Operations/io_ops.cs +++ b/src/TensorFlowNET.Core/Operations/io_ops.cs @@ -14,7 +14,9 @@ limitations under the License. ******************************************************************************/ +using System.Linq; using Tensorflow.Contexts; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow @@ -23,11 +25,41 @@ namespace Tensorflow { public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute( + new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); + result = null; + return null; + } + catch (System.Exception) + { + return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); + } + } var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); return _op; } + public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) + { + DataType[] attr_dtypes; + (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx); + prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); + var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); + var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); + var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); + var attrs = new object[] { "dtypes", attr_dtypes }; + + var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); + result = null; + return null; + } + public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) { var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: 
name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index d5a32c10..1b1fa003 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -17,7 +17,9 @@ using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.ModelSaving; using Tensorflow.Train; +using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow @@ -177,5 +179,57 @@ namespace Tensorflow return HandleData.Parser.ParseFrom(handle.BufferToArray()); } } + + /// + /// Copies an existing variable to a new graph, with no initializer. + /// + /// + public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) + { + var new_variable = new UninitializedVariable( + trainable: variable.Trainable, + shape: variable.shape, + dtype: variable.dtype, + name: variable.SharedName, + aggregation: variable.Aggregation, + extra_handle_data: null); + new_variable._maybe_initialize_trackable(); + return new_variable; + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// + /// + /// + /// + /// + public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) + { + // lack of API: `proto.Variable.SetInParent()`. + if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) + { + throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + + $"unexpected suffix in the name (expected ':0') which won't be restored."); + } + if(proto.Variable is null) + { + proto.Variable = new SavedVariable(); + } + proto.Variable.Name = meta_graph.op_name(resource_variable.Name); + proto.Variable.Trainable = resource_variable.Trainable; + proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); + // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. + proto.Variable.Aggregation = resource_variable.Aggregation; + proto.Variable.Shape = resource_variable.shape.as_proto(); + + if (options.experimental_variable_policy.save_variable_devices()) + { + if (!string.IsNullOrEmpty(resource_variable.Device)) + { + proto.Variable.Device = resource_variable.Device; + } + } + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 167c635a..2d23a325 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -14,6 +14,8 @@ limitations under the License. 
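// A minimal usage sketch of write_object_proto_for_resource_variable above; `my_variable`
// is an assumed ResourceVariable whose name is "dense/kernel:0" (any ":0"-suffixed name
// passes the enforcing_naming check):
//
//     var proto = new SavedObject();
//     var options = new SaveOptions();   // default policy: VariablePolicy.None
//     resource_variable_ops.write_object_proto_for_resource_variable(my_variable, proto, options);
//     // proto.Variable.Name      == "dense/kernel"   (meta_graph.op_name strips ":0")
//     // proto.Variable.Trainable == my_variable.Trainable
//     // proto.Variable.Dtype     == my_variable.dtype.as_datatype_enum()
//     // proto.Variable.Device stays unset because save_variable_devices() is false here.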
******************************************************************************/ +using static Tensorflow.Binding; + namespace Tensorflow { public class ResourceVariableSaveable : MySaveableObject @@ -35,6 +37,32 @@ namespace Tensorflow this.name = name; } + public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) + { + _var_device = var.Device; + _var_shape = var.shape; + + Tensor _read_variable_closure(BaseResourceVariable v) + { + tf.device(v.Device); + if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + tf.device("/device:CPU:0"); + return array_ops.identity(x); + } + + this.handle_op = var.Handle; + var tensor = _read_variable_closure(var); + + var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); + _op = var; + specs = new SaveSpec[] { spec }; + this.name = name; + } + public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) { var restored_tensor = restored_tensors[0]; diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 6239030b..43d36dba 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -14,11 +14,31 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Checkpoint; + namespace Tensorflow { public class MySaveableObject { - public Tensor op; + protected Maybe _op; + public Tensor op + { + get + { + if(_op.DataType == typeof(Tensor)) + { + return _op.GetValueA(); + } + else + { + throw new TypeError("The _op is not a tensor."); + } + } + set + { + _op = value; + } + } public SaveSpec[] specs; public string name; public string device; @@ -35,7 +55,7 @@ namespace Tensorflow public MySaveableObject(Tensor op, SaveSpec[] specs, string name) { - this.op = op; + this._op = op; this.specs = specs; this.name = name; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index 6700e277..6132e025 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -10,6 +10,7 @@ using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow; @@ -75,7 +76,7 @@ public class SaveableView private void initialize_save_and_restore_functions() { // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. - SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); + var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. 
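// Note on the Maybe<,> union type used by the reworked checkpointing code (e.g. the `_op`
// field in SaveableObject.cs above and the slice values in functional_saver.cs): callers
// branch on DataType before unwrapping. A small sketch, where `maybe_tensor` is a
// placeholder Maybe<Tensor, SaveSpec> value:
//
//     if (maybe_tensor.DataType == typeof(SaveSpec))
//     {
//         var spec = maybe_tensor.GetValueB();    // SaveSpec branch
//         var tensor = spec.tensor;
//     }
//     else
//     {
//         var tensor = maybe_tensor.GetValueA();  // plain Tensor branch
//     }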
_obj_to_registered_saver = new(); _saveable_objects_map = new(); @@ -191,7 +192,7 @@ public class SaveableView /// /// /// - public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index, SaveOptions options) + public SavedObjectGraph serialize_object_graph(IDictionary asset_file_def_index) { SavedObjectGraph proto = new(); fill_object_graph_proto(proto); @@ -203,21 +204,20 @@ public class SaveableView { var obj = _nodes[i]; var obj_proto = proto.Nodes[i]; - write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), - options); + write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x)); } return proto; } private static void write_object_proto(Trackable obj, SavedObject proto, - IDictionary asset_file_def_index, Func> list_children_fn, SaveOptions options) + IDictionary asset_file_def_index, Func> list_children_fn) { // skip the process of type Asset if (resource_variable_ops.is_resource_variable(obj)) { - // TODO: complete it. - throw new NotImplementedException(); + var options = SaveContext.get_save_options(); + (obj as BaseResourceVariable).write_object_proto(proto, options); } else if (obj is Function) { diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs index f3f273b8..d82d49d8 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs @@ -10,6 +10,7 @@ using Tensorflow.ModelSaving; using Tensorflow.Train; using Tensorflow.Exceptions; using static Tensorflow.Binding; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow; @@ -43,7 +44,7 @@ public static partial class SavedModelUtils { SavedModelUtils.get_or_create_variables_dir(export_dir); CheckpointOptions ckpt_options = new(options.experimental_io_device); - object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); + object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options); } BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); @@ -68,6 +69,7 @@ public static partial class SavedModelUtils var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); File.WriteAllBytes(path, saved_model.ToByteArray()); + //File.WriteAllText(path, saved_model.ToString()); if (options.save_debug_info) { @@ -83,45 +85,48 @@ public static partial class SavedModelUtils Dictionary>) _build_meta_graph(Trackable obj, ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) { - if (ops.inside_function()) + using (SaveContext.save_context(options)) { - throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + - "Move the call to the outer eagerly-executed context."); - } + if (ops.inside_function()) + { + throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. 
" + + "Move the call to the outer eagerly-executed context."); + } - if (meta_graph_def is null) - { - meta_graph_def = new MetaGraphDef(); - } + if (meta_graph_def is null) + { + meta_graph_def = new MetaGraphDef(); + } - AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); - if (signatures is null) - { - signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); - } - - // TODO: process of aignatures and wrapped_functions + AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); + if (signatures is null) + { + signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); + } - SaveableView saveable_view = new SaveableView(augmented_graph_view, options); - TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); - var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, - options.namespace_white_list, options.experimental_custom_gradients); - if (options.function_aliases is not null) - { - var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; - foreach (var pair in options.function_aliases) + // TODO: process of aignatures and wrapped_functions + + SaveableView saveable_view = new SaveableView(augmented_graph_view, options); + TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); + var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, + options.namespace_white_list, options.experimental_custom_gradients); + if (options.function_aliases is not null) { - var alias = pair.Key; - var func = pair.Value; - // TODO: complete it. - throw new NotImplementedException(); + var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; + foreach (var pair in options.function_aliases) + { + var alias = pair.Key; + var func = pair.Value; + // TODO: complete it. + throw new NotImplementedException(); + } } - } - var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); - meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); + var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index); + meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); - return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); + } } private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, @@ -134,7 +139,7 @@ public static partial class SavedModelUtils Dictionary object_map; Dictionary tensor_map; AssetInfo asset_info; - exported_graph.as_default(); + var g = exported_graph.as_default(); (object_map, tensor_map, asset_info) = saveable_view.map_resources(); // TODO: deal with signatures. if (save_custom_gradients) @@ -161,15 +166,23 @@ public static partial class SavedModelUtils // Lack `CopyFrom` API // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] + g.Exit(); + foreach (var obj in object_map.Values) { obj._maybe_initialize_trackable(); } + // TODO: add the implementation of `call_with_mapped_functions`. var (named_saveable_objects, registered_savers) = SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); - - // TODO: complete the save of checkpoints with `MultiDeviceSaver`. 
+ var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false); + + var eg = exported_graph.as_default(); + var saver_def = saver.to_proto(); + meta_graph_def.SaverDef = saver_def; + eg.Exit(); + saveable_view.dependency_sorted_node_ids(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs new file mode 100644 index 00000000..4cfe0b69 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/save_context.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.ModelSaving; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A context for building a graph of SavedModel. + /// + public static class SaveContext + { + // TODO: make it thead safe. + private static bool _in_save_context = false; + private static SaveOptions _save_options = null; + + public static bool in_save_context() => _in_save_context; + public static SaveOptions get_save_options() + { + if (!in_save_context()) + { + throw new ValueError("Not in a SaveContext."); + } + return _save_options; + } + public static SaveContextHandler save_context(SaveOptions options) + { + return new SaveContextHandler(options); + } + + public class SaveContextHandler: IDisposable + { + private bool _old_in_save_context; + private SaveOptions _old_save_options; + public SaveContextHandler(SaveOptions options) + { + if (SaveContext.in_save_context()) + { + throw new ValueError("Already in a SaveContext."); + } + _old_in_save_context = SaveContext._in_save_context; + SaveContext._in_save_context = true; + _old_save_options = SaveContext._save_options; + SaveContext._save_options = options; + } + public void Dispose() + { + SaveContext._in_save_context = _old_in_save_context; + SaveContext._save_options = _old_save_options; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs index 723419f6..2deff027 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/utils.cs @@ -28,6 +28,11 @@ public static partial class SavedModelUtils return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); } + public static string get_variables_path(string export_dir) + { + return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME)); + } + /// /// Return assets sub-directory, or create one if it doesn't exist. 
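// Usage sketch for the SaveContext added in save_context.cs above; the handler restores
// the previous state on Dispose, so the idiomatic pattern is a `using` block (the
// SaveOptions instance here is illustrative):
//
//     var options = new SaveOptions();
//     using (SaveContext.save_context(options))
//     {
//         bool saving = SaveContext.in_save_context();    // true
//         var active  = SaveContext.get_save_options();   // returns `options`
//         // nesting another save_context(...) here throws "Already in a SaveContext."
//     }
//     // after disposal, in_save_context() is false and get_save_options() throws again
//
// As the TODO in save_context.cs notes, the state lives in static fields and is not yet
// thread safe.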
/// diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 7066b366..582e2431 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using Tensorflow.Checkpoint; +using Tensorflow.Operations.Activation; using Tensorflow.Train; using Tensorflow.Training; using static Tensorflow.Binding; @@ -117,8 +118,7 @@ namespace Tensorflow } else { - Debug.Assert(variable is ResourceVariable); - yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); + yield return new ResourceVariableSaveable(variable, "", name); } } else @@ -215,7 +215,7 @@ namespace Tensorflow return names_to_saveables; } - public static IDictionary> saveable_objects_from_trackable(Trackable obj) + public static IDictionary> saveable_objects_from_trackable(Trackable obj) { // skip the process of type `PythonState` @@ -251,7 +251,7 @@ namespace Tensorflow specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); } } - Dictionary> res = new(); + Dictionary> res = new(); res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); return res; } @@ -270,25 +270,6 @@ namespace Tensorflow { return tf.compat.as_str(x); } - } - - public class SaveableCompatibilityConverter: Trackable - { - private Trackable _obj; - private IList _saveables; - public SaveableCompatibilityConverter(Trackable obj, IList saveables) - { - _obj= obj; - _saveables= saveables; - } - - public Trackable Obj => _obj; - public IList mySaveables=> _saveables; - - public override IDictionary>> serialize_to_tensors() - { - return saveable_object_to_tensor_dict(_saveables); - } /// /// Converts a list of SaveableObjects to a tensor dictionary. @@ -299,11 +280,11 @@ namespace Tensorflow Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { - foreach(var spec in saveable.specs) + foreach (var spec in saveable.specs) { // skip the check that if `spec` is callable. - var name = saveable_object_util.convert_to_string(spec.name); - var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + var name = convert_to_string(spec.name); + var slice_spec = convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { tensor_dict.SetDefault(name, new Dictionary()).GetValueB()[slice_spec] = spec.tensor; @@ -316,5 +297,81 @@ namespace Tensorflow } return tensor_dict; } + + /// + /// Generates `Trackable._restore_from_tensors` from SaveableObjects. 
+ /// + /// + public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) + { + return (restored_tensors) => + { + Dictionary restored_ops = new(); + + foreach(var saveable in saveables) + { + List saveable_restored_tensors = new(); + foreach(var spec in saveable.specs) + { + var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name)); + var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); + + var maybe_tensor = restored_tensors[name]; + IDictionary dict; + if(maybe_tensor.DataType == typeof(Tensor)) + { + dict = new Dictionary(); + dict[""] = maybe_tensor.GetValueA(); + } + else + { + dict = maybe_tensor.GetValueB(); + } + saveable_restored_tensors.Add(dict[slice_spec]); + } + restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null); + } + return restored_ops; + }; + } + } + + public class SaveableCompatibilityConverter: Trackable + { + private object _obj; + private IList _saveables; + public SaveableCompatibilityConverter(object obj, IList saveables) + { + _obj= obj; + _saveables= saveables; + } + + public object Obj => _obj; + public IList mySaveables=> _saveables; + + public override IDictionary>> serialize_to_tensors() + { + return saveable_object_util.saveable_object_to_tensor_dict(_saveables); + } + + /// + /// Returns the restore ops defined in the Saveables. + /// + /// + /// + public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + List expected_keys = new(); + foreach(var saveable in _saveables) + { + expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); + } + if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) + { + throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." + + $"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); + } + return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); + } } } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index a677044a..434d51b6 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -42,11 +42,11 @@ namespace Tensorflow.Train protected IList _unconditional_checkpoint_dependencies; - protected IDictionary> _self_saveable_object_factories = - new Dictionary>(); + protected IDictionary> _self_saveable_object_factories = + new Dictionary>(); private bool _manual_tracking = true; - private static Trackable _none = new Function(); + private static Trackable _none = new AutoTrackable(); /// /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. /// The `None` can be any object that inherits `Trackable`. 
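// Sketch of how the relocated SaveableCompatibilityConverter above adapts a legacy
// MySaveableObject to the tensor-dictionary checkpoint path, mirroring
// MultiDeviceSaver.from_saveables in functional_saver.cs; `legacy_saveable` and
// `restored_by_local_name` are assumed placeholders:
//
//     var converter = new SaveableCompatibilityConverter(
//         legacy_saveable, new List<MySaveableObject> { legacy_saveable });
//     // save path: specs are flattened into {spec name -> tensor (or slice-spec dict)}
//     var tensor_dict = converter.serialize_to_tensors();
//     // restore path: keys must be the *local* names of the specs
//     // (TrackableUtils.extract_local_name), otherwise a ValueError is raised.
//     var restore_ops = converter._restore_from_tensors(restored_by_local_name);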
@@ -225,7 +225,7 @@ namespace Tensorflow.Train return self_tensor_map.Keys.ToList(); } - public virtual IDictionary> gather_saveables_for_checkpoint() + public virtual IDictionary> gather_saveables_for_checkpoint() { if (saveable_object_util.trackable_has_serialize_to_tensor(this)) { @@ -251,6 +251,11 @@ namespace Tensorflow.Train { throw new NotImplementedException(); } + + public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + { + throw new NotImplementedException(); + } } public record class TrackableReference(string Name, Trackable Refer); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 756024db..4005d564 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -6,6 +6,8 @@ using Tensorflow.Train; using static Tensorflow.Binding; using System.Collections.Generic; using Tensorflow.ModelSaving; +using System.Diagnostics; +using Tensorflow.Checkpoint; namespace Tensorflow { @@ -13,6 +15,7 @@ namespace Tensorflow { protected string _name; public virtual string Name => _handle_name; + public virtual string SharedName => _name; protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; @@ -50,6 +53,7 @@ namespace Tensorflow public Graph Graph => handle.graph; public string Device => handle.Device; EagerResourceDeleter eager_resource_deleter; + public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; public BaseResourceVariable() { @@ -77,6 +81,11 @@ namespace Tensorflow _handle = handle.EagerTensorHandle.DangerousGetHandle(); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); } + else if(handle is null) + { + // TODO: fix this dangerous change. + _handle = IntPtr.Zero; + } else { _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle(); @@ -247,5 +256,60 @@ namespace Tensorflow else return value(); } + + public override (IDictionary, IDictionary) map_resources(SaveOptions save_options) + { + BaseResourceVariable new_variable; + if (save_options.experimental_variable_policy.save_variable_devices()) + { + tf.device(this.Device); + Debug.Assert(this is ResourceVariable); + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + else + { + new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + } + Dictionary obj_map = new(); + Dictionary resource_map = new(); + obj_map[this] = new_variable; + resource_map[this.handle] = new_variable.handle; + return (obj_map, resource_map); + } + + /// + /// Writes additional information of the variable into the SavedObject proto. + /// ubclasses of ResourceVariables could choose to override this method to + /// customize extra information to provide when saving a SavedModel. 
+ /// + /// + /// + public virtual void write_object_proto(SavedObject proto, SaveOptions options) + { + resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); + } + + public override IDictionary> gather_saveables_for_checkpoint() + { + var res = new Dictionary>(); + res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; + return res; + } + + public Tensor is_initialized(string name = null) + { + return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); + } + + public Tensor read_value_no_copy() + { + Tensor value = null; + tf_with(ops.name_scope("Read"), _ => + { + // TODO: `no_copy = true`. + value = _read_variable_op(); + }); + return array_ops.identity(value); + } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 6093f810..1645d713 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -41,6 +41,7 @@ namespace Tensorflow VariableAggregation aggregation = VariableAggregation.None, Shape shape = null) { + Aggregation = aggregation; if (variable_def != null) { if (initial_value != null) @@ -237,12 +238,5 @@ namespace Tensorflow { return _graph_element.eval(session); } - - public override IDictionary> gather_saveables_for_checkpoint() - { - var res = new Dictionary>(); - res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; - return res; - } } } diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs new file mode 100644 index 00000000..6c034995 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -0,0 +1,70 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Gradients; +using static Tensorflow.Binding; + +namespace Tensorflow.Variables +{ + /// + /// A variable with no initializer. + /// + public sealed class UninitializedVariable: BaseResourceVariable + { + // TODO: complete the arg list. + public UninitializedVariable( + bool trainable = true, + string caching_device = "", + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid, + VariableAggregation aggregation = VariableAggregation.None, + Shape shape = null, + Tensor extra_handle_data = null) + { + string unique_id = ""; + string handle_name = ""; + tf_with(ops.init_scope(), (x) => + { + _in_graph_mode = !tf.Context.executing_eagerly(); + tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name => + { + handle_name = ops.name_from_scope_name(name); + string? shared_name; + if (_in_graph_mode) + { + shared_name = handle_name; + unique_id = shared_name; + } + else + { + unique_id = $"{handle_name}-{ops.uid()}"; + shared_name = null; + } + var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( + shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); + // skip the assignment of `handle._parent_trackable` because of lack of API. + // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. 
+ + if (_in_graph_mode) + { + tf_with(ops.name_scope("Read"), _ => + { + tf.device(handle.Device); + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + // _maybe_set_handle_data(dtype, handle, value) + _graph_element = value; + }); + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); + } + else + { + _graph_element = null; + } + }); + }); + _shape = shape; + _dtype = dtype; + base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index 23c40fbf..a221444b 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Engine } } - var layer_config = generic_utils.serialize_keras_object(layer); + var layer_config = generic_utils.serialize_layer_to_config(layer); layer_config.Name = layer.Name; layer_config.InboundNodes = filtered_inbound_nodes; layer_configs.Add(layer_config); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index ffb6f71b..fc405d87 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine; public abstract partial class Layer { - public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); + public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 6b064716..03b4b742 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -18,6 +18,7 @@ using System.Linq; using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving.SavedModel; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -105,5 +106,7 @@ namespace Tensorflow.Keras.Layers { return new InputLayer(args as InputLayerArgs); } + + public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 6a6e418c..4ff8f02f 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -55,6 +55,7 @@ public partial class KerasSavedModelUtils var metadata = generate_keras_metadata(saved_nodes, node_paths); File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); + //File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString()); if (!include_optimizer) { @@ -100,7 +101,8 @@ public partial class KerasSavedModelUtils Identifier = layer.ObjectIdentifier, Metadata = layer.TrackingMetadata }; - + + metadata.Nodes.Add(saved_object); } return metadata; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs index fc7eab3a..f7e1bf45 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/SaveImpl.cs @@ -24,26 +24,26 @@ public partial class 
KerasSavedModelUtils // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. // TODO: change the inherits of `Variable` and revise the implmentation. - var variables = layer.Variables.Select(x => + var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); - var trainable_variables = layer.TrainableVariables.Select(x => + })); + var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); - var non_trainable_variables = layer.non_trainable_variables.Select(x => - { - if (x is ResourceVariable or RefVariable) return (Trackable)x; - else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - }); + })); + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + { + if (x is ResourceVariable or RefVariable) return (Trackable)x; + else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); + })); Dictionary res = new(); - res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); - res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); - res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); + res["variables"] = variables; + res["trainable_variables"] = trainable_variables; + res["non_trainable_variables"] = non_trainable_variables; res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); return res; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs index 0235f87b..60c4ee5b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/base_serialization.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; public abstract class SavedModelSaver { - private Trackable _obj; + protected Trackable _obj; public SavedModelSaver(Trackable obj) { _obj = obj; diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index b092b595..655127af 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -2,6 +2,7 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; using Tensorflow.Train; @@ -9,10 +10,11 @@ namespace Tensorflow.Keras.Saving.SavedModel; public class LayerSavedModelSaver: SavedModelSaver { - private Layer _obj; + private Layer _layer; public LayerSavedModelSaver(Layer obj): base(obj) { _obj = obj; + _layer = obj; } public override string ObjectIdentifier { @@ -68,8 +70,8 @@ public class LayerSavedModelSaver: SavedModelSaver /// private (IDictionary, IDictionary) get_serialized_attributes_internal(IDictionary> serialization_cache) { - var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); - var functions = 
KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); + var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache); + var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache); functions["_default_save_signature"] = null; @@ -81,17 +83,18 @@ public class LayerSavedModelSaver: SavedModelSaver get { JObject metadata = new JObject(); - metadata["name"] = _obj.Name; - metadata["trainable"] = _obj.Trainable; + metadata["name"] = _layer.Name; + metadata["trainable"] = _layer.Trainable; // metadata["expects_training_arg"] = _obj._expects_training_arg; // metadata["dtype"] = policy.serialize(_obj._dtype_policy) - metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); + metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; - metadata["autocast"] = _obj.AutoCast; - - metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings + metadata["autocast"] = _layer.AutoCast; + + var temp = JObject.FromObject(get_serialized(_layer)); + metadata.Merge(temp, new JsonMergeSettings { // Handle conflicts by using values from obj2 MergeArrayHandling = MergeArrayHandling.Merge @@ -108,4 +111,46 @@ public class LayerSavedModelSaver: SavedModelSaver return new Dictionary(); //return generic_utils.serialize_keras_object(obj); } +} + +public class InputLayerSavedModelSaver: SavedModelSaver +{ + public InputLayerSavedModelSaver(Layer obj) : base(obj) + { + + } + public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER; + + public override IDictionary functions_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override IDictionary objects_to_serialize(IDictionary> serialization_cache) + { + return new Dictionary(); + } + + public override string TrackingMetadata + { + get + { + if(_obj is not Layer) + { + throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); + } + var layer = (Layer)_obj; + var info = new + { + class_name = layer.GetType().Name, + name = layer.Name, + dtype = layer.DType, + //sparse = layer.sparse, + //ragged = layer.ragged, + batch_input_shape = layer.BatchInputShape, + config = layer.get_config() + }; + return JsonConvert.SerializeObject(info); + } + } } \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index c2839cdc..68903eb2 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Collections; +using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.Saving; @@ -22,7 +24,12 @@ namespace Tensorflow.Keras.Utils { public class generic_utils { - public static LayerConfig serialize_keras_object(ILayer instance) + /// + /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. + /// + /// + /// + public static LayerConfig serialize_layer_to_config(ILayer instance) { var config = instance.get_config(); return new LayerConfig