| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using Tensorflow.Train; | |||
| @@ -85,17 +86,18 @@ public static class CheckPointUtils | |||
| } | |||
| } | |||
| public static string get_full_name(Trackable var) | |||
| public static string get_full_name(Trackable variable) | |||
| { | |||
| // TODO: This state is not correct, the whole framework need to be updated in the future. | |||
| if (!(var is IVariableV1 || resource_variable_ops.is_resource_variable(var))) | |||
| if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable))) | |||
| { | |||
| return ""; | |||
| } | |||
| // skip the check of attribute `_save_slice_info` . | |||
| // TODO: Need to be revised!!! | |||
| return ((ResourceVariable)(object)var).Name; | |||
| Debug.Assert(variable is BaseResourceVariable); | |||
| return ((BaseResourceVariable)variable).Name; | |||
| } | |||
| public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto) | |||
| @@ -28,7 +28,7 @@ namespace Tensorflow.Checkpoint | |||
| ); | |||
| public static class SaveUtil | |||
| { | |||
| public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null) | |||
| { | |||
| var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); | |||
| @@ -39,7 +39,7 @@ namespace Tensorflow.Checkpoint | |||
| var serialized_tensors = get_and_write_tensors_to_serialize(tensor_trackables, node_ids, call_with_mapped_captures, cache, object_graph_proto); | |||
| var registered_savers = get_and_write_registered_savers(registered_trackables, object_graph_proto); | |||
| Dictionary<Tensor, string> feed_additions; | |||
| Dictionary<Tensor, object> feed_additions; | |||
| if(cache is null) | |||
| { | |||
| feed_additions = null; | |||
| @@ -125,7 +125,7 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| // TODO: deal with cache. | |||
| var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; | |||
| var trackable = td.object_to_save; | |||
| Trackable trackable = null; | |||
| IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict; | |||
| if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) | |||
| { | |||
| @@ -134,6 +134,7 @@ namespace Tensorflow.Checkpoint | |||
| else | |||
| { | |||
| tensor_dict = get_tensors_from_trackable(td, call_with_mapped_captures, object_graph_proto); | |||
| trackable = td.object_to_save; | |||
| } | |||
| if(trackable is not null) | |||
| { | |||
| @@ -44,19 +44,19 @@ public static class SaveUtilV1 | |||
| return (checkpoint_factory_map, null); | |||
| } | |||
| public static (List<MySaveableObject>, object?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||
| public static (List<MySaveableObject>, IDictionary<string, IDictionary<string, Trackable>>?) frozen_saveables_and_savers(ObjectGraphView graph_view, | |||
| IDictionary<Trackable, Trackable> object_map, Graph? to_graph, bool call_with_mapped_captures, | |||
| object? saveables_cache = null) | |||
| { | |||
| if (to_graph is not null) | |||
| { | |||
| to_graph.as_default(); | |||
| var g = to_graph.as_default(); | |||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
| object_map, call_with_mapped_captures, saveables_cache); | |||
| // tensorflow python: `with ops.device("/cpu:0")` | |||
| var serialized = graph_proto.ToByteString().ToString(); | |||
| var object_graph_tensor = constant_op.constant("aaaa", TF_DataType.TF_STRING); | |||
| tf.device("/cpu:0"); | |||
| var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| g.Exit(); | |||
| return (named_saveable_objects, registered_savers); | |||
| } | |||
| else | |||
| @@ -65,7 +65,7 @@ public static class SaveUtilV1 | |||
| { | |||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
| object_map, call_with_mapped_captures, saveables_cache); | |||
| // tensorflow python: `with ops.device("/cpu:0")` | |||
| tf.device("/cpu:0"); | |||
| var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| return (named_saveable_objects, registered_savers); | |||
| @@ -73,7 +73,7 @@ public static class SaveUtilV1 | |||
| } | |||
| } | |||
| public static (List<MySaveableObject>, TrackableObjectGraph, object?, object?) serialize_gathered_objects(ObjectGraphView graph_view, | |||
| public static (List<MySaveableObject>, TrackableObjectGraph, object?, IDictionary<string, IDictionary<string, Trackable>>?) serialize_gathered_objects(ObjectGraphView graph_view, | |||
| IDictionary<Trackable, Trackable> object_map, bool call_with_mapped_captures, object? saveables_cache = null) | |||
| { | |||
| var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); | |||
| @@ -129,7 +129,7 @@ public static class SaveUtilV1 | |||
| return object_graph_proto; | |||
| } | |||
| private static (List<MySaveableObject>, object?, object?) add_attributes_to_object_graph(IList<Trackable> trackable_objects, | |||
| private static (List<MySaveableObject>, object?, IDictionary<string, IDictionary<string, Trackable>>?) add_attributes_to_object_graph(IList<Trackable> trackable_objects, | |||
| TrackableObjectGraph object_graph_proto, IDictionary<Trackable, int> node_ids, | |||
| IDictionary<Trackable, string> object_names, IDictionary<Trackable, Trackable> object_map, | |||
| bool call_with_mapped_captures, object? saveables_cache = null) | |||
| @@ -216,7 +216,7 @@ public static class SaveUtilV1 | |||
| public record class CheckpointFactoryData | |||
| ( | |||
| Maybe<ResourceVariable, MySaveableObject> factory, | |||
| Maybe<BaseResourceVariable, MySaveableObject> factory, | |||
| string name, | |||
| string checkpoint_key | |||
| ); | |||
| @@ -33,7 +33,7 @@ public class TrackableSaver | |||
| } | |||
| private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, string>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| private (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph) | |||
| gather_serialized_tensors(Tensor? object_graph_tensor = null) | |||
| { | |||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); | |||
| @@ -42,26 +42,27 @@ public class TrackableSaver | |||
| if(object_graph_tensor is null) | |||
| { | |||
| // tensorflow python: `with ops.device("/cpu:0"):` | |||
| object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
| tf.device("/cpu:0"); | |||
| object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
| } | |||
| else | |||
| { | |||
| feed_additions[object_graph_tensor] = graph_proto.ToString(); | |||
| feed_additions[object_graph_tensor] = graph_proto.ToByteArray(); | |||
| } | |||
| Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| if (serialized_tensors.ContainsKey(Trackable.None)) | |||
| if (!serialized_tensors.ContainsKey(Trackable.None)) | |||
| { | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
| serialized_tensors[Trackable.None] = new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>(); | |||
| } | |||
| serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; | |||
| return (serialized_tensors, feed_additions, registered_savers, graph_proto); | |||
| } | |||
| private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||
| private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(Tensor file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||
| { | |||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); | |||
| Func<(Tensor, IDictionary<Tensor, string>)> run_save = () => | |||
| Func<(Tensor, IDictionary<Tensor, object>)> run_save = () => | |||
| { | |||
| if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) | |||
| { | |||
| @@ -86,11 +87,11 @@ public class TrackableSaver | |||
| return run_save(); | |||
| } | |||
| private (Tensor, IDictionary<Tensor, string>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||
| private (Tensor, IDictionary<Tensor, object>) save_cached_when_graph_building(string file_prefix, Tensor object_graph_tensor, CheckpointOptions options) | |||
| { | |||
| var (serialized_tensors, feed_additions, registered_savers, graph_proto) = gather_serialized_tensors(object_graph_tensor); | |||
| Func<(Tensor, IDictionary<Tensor, string>)> run_save = () => | |||
| Func<(Tensor, IDictionary<Tensor, object>)> run_save = () => | |||
| { | |||
| if (_last_save_object_graph != graph_proto || tf.Context.executing_eagerly() || ops.inside_function()) | |||
| { | |||
| @@ -124,7 +125,7 @@ public class TrackableSaver | |||
| options = new CheckpointOptions(); | |||
| } | |||
| Dictionary<Tensor, string> feed_dict = new(); | |||
| Dictionary<Tensor, object> feed_dict = new(); | |||
| bool use_session = (!tf.Context.executing_eagerly() && !ops.inside_function()); | |||
| if (checkpoint_number is not null) | |||
| { | |||
| @@ -12,6 +12,8 @@ using System.Linq; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Training; | |||
| using Tensorflow.Graphs; | |||
| using System.Xml.Linq; | |||
| using System.Diagnostics; | |||
| namespace Tensorflow.Checkpoint | |||
| { | |||
| @@ -31,6 +33,10 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| return Func.DynamicInvoke(args); | |||
| } | |||
| public TR Invoke() | |||
| { | |||
| return Func.Invoke(); | |||
| } | |||
| } | |||
| internal record class FunctionHolder<TA1, TR>(Func<TA1, TR> Func) : IFunctionHolder | |||
| { | |||
| @@ -164,7 +170,6 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var slice_spec = slice.Key; | |||
| var maybe_tensor = slice.Value; | |||
| // TODO: deal with other types. Currently only `SaveSpec` is allowed. | |||
| if(maybe_tensor.DataType == typeof(SaveSpec)) | |||
| { | |||
| var spec = maybe_tensor.GetValueB(); | |||
| @@ -284,14 +289,16 @@ namespace Tensorflow.Checkpoint | |||
| var obj = pair.Key; | |||
| var tensor_dict = pair.Value; | |||
| IFunctionHolder restore_fn; | |||
| if(obj is null) | |||
| if(obj == Trackable.None) | |||
| { | |||
| restore_fn = new FunctionHolder<object?>(() => null); | |||
| } | |||
| else | |||
| { | |||
| restore_fn = null; | |||
| // TODO: implement obj._restore_from_tensors | |||
| restore_fn = new FunctionHolder<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>>(x => | |||
| { | |||
| return obj._restore_from_tensors(x); | |||
| }); | |||
| } | |||
| foreach(var item in tensor_dict) | |||
| @@ -343,7 +350,7 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| } | |||
| public Operation save(string file_prefix, CheckpointOptions? options= null) | |||
| public Operation save(Tensor file_prefix, CheckpointOptions? options= null) | |||
| { | |||
| if(options is null) | |||
| { | |||
| @@ -351,9 +358,9 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| tf.device("CPU"); // may be risky. | |||
| // TODO: optimize the implementation with new APIs adding to `string_ops`. | |||
| string sharded_suffix = Regex.Match(file_prefix, "^s3://.*").Success ? ".part" : "_temp/part"; | |||
| var tmp_checkpoint_prefix = tf.constant(file_prefix + sharded_suffix); | |||
| var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
| constant_op.constant(".part"), constant_op.constant("_temp/part")); | |||
| var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
| IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
| Operation save_fn() | |||
| @@ -385,7 +392,7 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||
| tf.device(merge_device); | |||
| return gen_ops.merge_v2checkpoints(tf.concat(saved_prefixes, 0), tf.constant(file_prefix), delete_old_dirs: true); | |||
| return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
| } | |||
| } | |||
| @@ -400,9 +407,9 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| } | |||
| public Operation save(Tensor file_prefix, CheckpointOptions? options = null) => save(file_prefix.numpy().StringData()[0], options); | |||
| public Operation save(string file_prefix, CheckpointOptions? options = null) => save(tf.constant(file_prefix), options); | |||
| public IDictionary<string, Operation> restore(string file_prefix, CheckpointOptions? options = null) | |||
| public IDictionary<string, Operation> restore(Tensor file_prefix, CheckpointOptions? options = null) | |||
| { | |||
| if(options is null) | |||
| { | |||
| @@ -496,8 +503,10 @@ namespace Tensorflow.Checkpoint | |||
| public SaverDef to_proto() | |||
| { | |||
| var filename_tensor = array_ops.placeholder(TF_DataType.TF_STRING, new int[] { }, "saver_filename"); | |||
| var save_tensor = _traced_save(filename_tensor); | |||
| var restore_op = _traced_restore(filename_tensor).op; | |||
| var traced_save_func = tf.autograph.to_graph(_traced_save, TF_DataType.TF_STRING); | |||
| var traced_restore_func = tf.autograph.to_graph(_traced_restore, TF_DataType.TF_STRING); | |||
| var save_tensor = traced_save_func(filename_tensor); | |||
| var restore_op = traced_restore_func(filename_tensor).op; | |||
| return new SaverDef() | |||
| { | |||
| FilenameTensorName = filename_tensor.name, | |||
| @@ -507,10 +516,9 @@ namespace Tensorflow.Checkpoint | |||
| }; | |||
| } | |||
| [AutoGraph] | |||
| private Tensor _traced_save(Tensor file_prefix) | |||
| { | |||
| var save_op = save(file_prefix.StringData()[0]); | |||
| var save_op = save(file_prefix); | |||
| tf.device("cpu:0"); | |||
| using (ops.control_dependencies(new object[]{ save_op })) | |||
| { | |||
| @@ -518,24 +526,34 @@ namespace Tensorflow.Checkpoint | |||
| } | |||
| } | |||
| [AutoGraph] | |||
| private Tensor _traced_restore(Tensor file_prefix) | |||
| { | |||
| var restore_op = restore(file_prefix.StringData()[0]); | |||
| var restore_op = restore(file_prefix); | |||
| tf.device("cpu:0"); | |||
| using (ops.control_dependencies(new object[] { restore_op })) | |||
| using (ops.control_dependencies(restore_op.Values.ToArray())) | |||
| { | |||
| return array_ops.identity(file_prefix); | |||
| } | |||
| } | |||
| private static Tensor registered_saver_filename(string filename, string saver_name) | |||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
| { | |||
| Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| var trackable = new SaveableCompatibilityConverter(saveable, new List<MySaveableObject>() { saveable }); | |||
| serialized_tensors[trackable] = trackable.serialize_to_tensors(); | |||
| } | |||
| return new MultiDeviceSaver(serialized_tensors, registered_savers, call_with_mapped_captures); | |||
| } | |||
| private static Tensor registered_saver_filename(Tensor filename_tensor, string saver_name) | |||
| { | |||
| return tf.constant($"{filename}-{saver_name}"); | |||
| return gen_ops.string_join(new Tensor[] { filename_tensor, constant_op.constant($"-{saver_name}") }); | |||
| } | |||
| private static Tensor sharded_filename(Tensor filename_tensor, int shard, Tensor num_shards) | |||
| { | |||
| return filename_tensor; | |||
| return gen_ops.sharded_filename(filename_tensor, tf.constant(shard), num_shards); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,31 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Contexts; | |||
| using static Tensorflow.ApiDef.Types; | |||
| using static Tensorflow.CostGraphDef.Types; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Eager | |||
| { | |||
| internal class execute | |||
| { | |||
| public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||
| { | |||
| var v = values.Select(t => ops.convert_to_tensor(t, ctx:ctx)); | |||
| var types = v.Select(t => t.dtype.as_datatype_enum()); | |||
| return (types.ToArray(), v.ToArray()); | |||
| } | |||
| public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
| { | |||
| string device_name = ctx.DeviceName; | |||
| ctx.ensure_initialized(); | |||
| var tensors = tf.Runner.TFE_Execute(ctx, device_name, op_name, inputs, attrs, num_outputs); | |||
| return tensors; | |||
| } | |||
| } | |||
| } | |||
| @@ -406,5 +406,28 @@ namespace Tensorflow | |||
| meta_graph_def.MetaInfoDef.StrippedDefaultAttrs = true; | |||
| } | |||
| /// <summary> | |||
| /// Extract the Op name from a Tensor name. | |||
| /// </summary> | |||
| /// <param name="tensor_name"></param> | |||
| /// <returns></returns> | |||
| public static string op_name(string tensor_name) | |||
| { | |||
| if (string.IsNullOrEmpty(tensor_name)) | |||
| { | |||
| throw new ValueError($"Tensor name cannot be empty or None. Received: {tensor_name}."); | |||
| } | |||
| if (tensor_name.StartsWith("^")) | |||
| { | |||
| tensor_name = tensor_name.Substring(1); | |||
| } | |||
| if (tensor_name.Contains(":")) | |||
| { | |||
| return tensor_name.Split(':')[0]; | |||
| } | |||
| return tensor_name; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,11 +14,47 @@ namespace Tensorflow.ModelSaving | |||
| public IDictionary<string, object>? function_aliases { get; set; } = null; | |||
| public string? experimental_io_device { get; set; } = null; | |||
| // TODO: experimental | |||
| public Object? experimental_variable_polict { get; set; } = null; | |||
| public VariablePolicy experimental_variable_policy { get; set; } = VariablePolicy.None; | |||
| public bool experimental_custom_gradients { get; set; } = true; | |||
| public SaveOptions(bool save_debug_info = false) | |||
| { | |||
| this.save_debug_info = save_debug_info; | |||
| } | |||
| } | |||
| public class VariablePolicy | |||
| { | |||
| public string Policy { get; } | |||
| private VariablePolicy(string policy) | |||
| { | |||
| Policy = policy; | |||
| } | |||
| public static VariablePolicy None = new(null); | |||
| public static VariablePolicy SAVE_VARIABLE_DEVICES = new("save_variable_devices"); | |||
| public static VariablePolicy EXPAND_DISTRIBUTED_VARIABLES = new("expand_distributed_variables"); | |||
| public bool save_variable_devices() | |||
| { | |||
| return this != VariablePolicy.None; | |||
| } | |||
| /// <summary> | |||
| /// Tries to convert `obj` to a VariablePolicy instance. | |||
| /// </summary> | |||
| /// <param name="obj"></param> | |||
| /// <returns></returns> | |||
| public static VariablePolicy from_obj(object obj) | |||
| { | |||
| if (obj is null) return VariablePolicy.None; | |||
| if (obj is VariablePolicy) return (VariablePolicy)obj; | |||
| var key = obj.ToString().ToLower(); | |||
| return key switch | |||
| { | |||
| null => VariablePolicy.None, | |||
| "save_variable_devices" => VariablePolicy.SAVE_VARIABLE_DEVICES, | |||
| "expand_distributed_variables" => VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, | |||
| _ => throw new ValueError($"Received invalid VariablePolicy value: {obj}.") | |||
| }; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,6 +1,9 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Operations | |||
| @@ -17182,17 +17185,47 @@ namespace Tensorflow.Operations | |||
| /// path in the input checkpoint_prefixes. This is useful when those paths are non | |||
| /// user-facing temporary locations. | |||
| /// </remarks> | |||
| public static Operation merge_v2checkpoints(Tensor checkpoint_prefixes, Tensor destination_prefix, bool? delete_old_dirs = null, string name = "MergeV2Checkpoints") | |||
| { | |||
| public static Operation merge_v2_checkpoints(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs = true, bool allow_missing_files = false, string name = "MergeV2Checkpoints") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||
| checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); | |||
| result = null; | |||
| return null; | |||
| //try | |||
| //{ | |||
| // var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||
| // new object[] { checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files })); | |||
| // result = null; | |||
| // return null; | |||
| //} | |||
| //catch (System.Exception) | |||
| //{ | |||
| // return merge_v2_checkpoints_eager_fallback(checkpoint_prefixes, destination_prefix, delete_old_dirs: delete_old_dirs, | |||
| // allow_missing_files: allow_missing_files, name: name, ctx: ctx); | |||
| //} | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["checkpoint_prefixes"] = checkpoint_prefixes; | |||
| dict["destination_prefix"] = destination_prefix; | |||
| if (delete_old_dirs.HasValue) | |||
| dict["delete_old_dirs"] = delete_old_dirs.Value; | |||
| dict["delete_old_dirs"] = delete_old_dirs; | |||
| var op = tf.OpDefLib._apply_op_helper("MergeV2Checkpoints", name: name, keywords: dict); | |||
| return op; | |||
| } | |||
| //public static Operation merge_v2_checkpoints_eager_fallback(Tensor[] checkpoint_prefixes, Tensor destination_prefix, bool delete_old_dirs, bool allow_missing_files, string name, Context ctx) | |||
| //{ | |||
| // checkpoint_prefixes = ops.convert_to_tensor(checkpoint_prefixes, TF_DataType.TF_STRING); | |||
| // destination_prefix = ops.convert_to_tensor(destination_prefix, TF_DataType.TF_STRING); | |||
| // var inputs_flat = new Tensor[] { checkpoint_prefixes, destination_prefix }; | |||
| // var attrs = new object[] { "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files }; | |||
| // var result = execute.quick_execute("MergeV2Checkpoints", 0, inputs_flat, attrs, ctx, name); | |||
| // result = null; | |||
| // return null; | |||
| //} | |||
| /// <summary> | |||
| /// Transforms a spectrogram into a form that's useful for speech recognition. | |||
| /// </summary> | |||
| @@ -24259,6 +24292,12 @@ namespace Tensorflow.Operations | |||
| /// </remarks> | |||
| public static Tensor regex_full_match(Tensor input, Tensor pattern, string name = "RegexFullMatch") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); | |||
| return result[0]; | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["input"] = input; | |||
| dict["pattern"] = pattern; | |||
| @@ -29744,6 +29783,12 @@ namespace Tensorflow.Operations | |||
| /// </remarks> | |||
| public static Tensor sharded_filename(Tensor basename, Tensor shard, Tensor num_shards, string name = "ShardedFilename") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); | |||
| return result[0]; | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["basename"] = basename; | |||
| dict["shard"] = shard; | |||
| @@ -34668,6 +34713,12 @@ namespace Tensorflow.Operations | |||
| /// </remarks> | |||
| public static Tensor string_join(Tensor[] inputs, string separator = null, string name = "StringJoin") | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); | |||
| return result[0]; | |||
| } | |||
| var dict = new Dictionary<string, object>(); | |||
| dict["inputs"] = inputs; | |||
| if (separator != null) | |||
| @@ -14,7 +14,9 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Linq; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -23,11 +25,41 @@ namespace Tensorflow | |||
| { | |||
| public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (ctx.executing_eagerly()) | |||
| { | |||
| try | |||
| { | |||
| var result = tf.Runner.TFE_FastPathExecute( | |||
| new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); | |||
| result = null; | |||
| return null; | |||
| } | |||
| catch (System.Exception) | |||
| { | |||
| return save_v2_eager_fallback(prefix, tensor_names, shape_and_slices, tensors, name, ctx); | |||
| } | |||
| } | |||
| var _op = tf.OpDefLib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors }); | |||
| return _op; | |||
| } | |||
| public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) | |||
| { | |||
| DataType[] attr_dtypes; | |||
| (attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx); | |||
| prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||
| var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||
| var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||
| var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); | |||
| var attrs = new object[] { "dtypes", attr_dtypes }; | |||
| var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); | |||
| result = null; | |||
| return null; | |||
| } | |||
| public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | |||
| { | |||
| var _op = tf.OpDefLib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||
| @@ -17,7 +17,9 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Variables; | |||
| using static Tensorflow.CppShapeInferenceResult.Types; | |||
| namespace Tensorflow | |||
| @@ -177,5 +179,57 @@ namespace Tensorflow | |||
| return HandleData.Parser.ParseFrom(handle.BufferToArray()); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Copies an existing variable to a new graph, with no initializer. | |||
| /// </summary> | |||
| /// <param name="variable"></param> | |||
| public static UninitializedVariable copy_to_graph_uninitialized(ResourceVariable variable) | |||
| { | |||
| var new_variable = new UninitializedVariable( | |||
| trainable: variable.Trainable, | |||
| shape: variable.shape, | |||
| dtype: variable.dtype, | |||
| name: variable.SharedName, | |||
| aggregation: variable.Aggregation, | |||
| extra_handle_data: null); | |||
| new_variable._maybe_initialize_trackable(); | |||
| return new_variable; | |||
| } | |||
| /// <summary> | |||
| /// Writes additional information of the variable into the SavedObject proto. | |||
| /// </summary> | |||
| /// <param name="resource_variable"></param> | |||
| /// <param name="proto"></param> | |||
| /// <param name="options"></param> | |||
| /// <param name="enforcing_naming"></param> | |||
| public static void write_object_proto_for_resource_variable(BaseResourceVariable resource_variable, SavedObject proto, SaveOptions options, bool enforcing_naming = true) | |||
| { | |||
| // lack of API: `proto.Variable.SetInParent()`. | |||
| if(enforcing_naming && !resource_variable.Name.EndsWith(":0")) | |||
| { | |||
| throw new ValueError($"Cowardly refusing to save variable {resource_variable.Name} because of " + | |||
| $"unexpected suffix in the name (expected ':0') which won't be restored."); | |||
| } | |||
| if(proto.Variable is null) | |||
| { | |||
| proto.Variable = new SavedVariable(); | |||
| } | |||
| proto.Variable.Name = meta_graph.op_name(resource_variable.Name); | |||
| proto.Variable.Trainable = resource_variable.Trainable; | |||
| proto.Variable.Dtype = resource_variable.dtype.as_datatype_enum(); | |||
| // TODO: lack of API `proto.Variable.Synchronization = resource_variable.synchronization.value`. | |||
| proto.Variable.Aggregation = resource_variable.Aggregation; | |||
| proto.Variable.Shape = resource_variable.shape.as_proto(); | |||
| if (options.experimental_variable_policy.save_variable_devices()) | |||
| { | |||
| if (!string.IsNullOrEmpty(resource_variable.Device)) | |||
| { | |||
| proto.Variable.Device = resource_variable.Device; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -14,6 +14,8 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class ResourceVariableSaveable : MySaveableObject | |||
| @@ -35,6 +37,32 @@ namespace Tensorflow | |||
| this.name = name; | |||
| } | |||
| public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, string name) | |||
| { | |||
| _var_device = var.Device; | |||
| _var_shape = var.shape; | |||
| Tensor _read_variable_closure(BaseResourceVariable v) | |||
| { | |||
| tf.device(v.Device); | |||
| if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| tf.device("/device:CPU:0"); | |||
| return array_ops.identity(x); | |||
| } | |||
| this.handle_op = var.Handle; | |||
| var tensor = _read_variable_closure(var); | |||
| var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); | |||
| _op = var; | |||
| specs = new SaveSpec[] { spec }; | |||
| this.name = name; | |||
| } | |||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||
| { | |||
| var restored_tensor = restored_tensors[0]; | |||
| @@ -14,11 +14,31 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Tensorflow.Checkpoint; | |||
| namespace Tensorflow | |||
| { | |||
| public class MySaveableObject | |||
| { | |||
| public Tensor op; | |||
| protected Maybe<Tensor, BaseResourceVariable> _op; | |||
| public Tensor op | |||
| { | |||
| get | |||
| { | |||
| if(_op.DataType == typeof(Tensor)) | |||
| { | |||
| return _op.GetValueA(); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("The _op is not a tensor."); | |||
| } | |||
| } | |||
| set | |||
| { | |||
| _op = value; | |||
| } | |||
| } | |||
| public SaveSpec[] specs; | |||
| public string name; | |||
| public string device; | |||
| @@ -35,7 +55,7 @@ namespace Tensorflow | |||
| public MySaveableObject(Tensor op, SaveSpec[] specs, string name) | |||
| { | |||
| this.op = op; | |||
| this._op = op; | |||
| this.specs = specs; | |||
| this.name = name; | |||
| } | |||
| @@ -10,6 +10,7 @@ using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using pbc = global::Google.Protobuf.Collections; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow; | |||
| @@ -75,7 +76,7 @@ public class SaveableView | |||
| private void initialize_save_and_restore_functions() | |||
| { | |||
| // TODO: deal with the return value of `get_checkpoint_factories_and_keys`. | |||
| SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); | |||
| var (checkpoint_factory_map, registered_savers) = SaveUtilV1.get_checkpoint_factories_and_keys(_object_names); | |||
| // skip the process of registered savers and the generation of saveable_objects_map and _obj_to_registered_saver. | |||
| _obj_to_registered_saver = new(); | |||
| _saveable_objects_map = new(); | |||
| @@ -191,7 +192,7 @@ public class SaveableView | |||
| /// </summary> | |||
| /// <param name="asset_index"></param> | |||
| /// <returns></returns> | |||
| public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index, SaveOptions options) | |||
| public SavedObjectGraph serialize_object_graph(IDictionary<object, object> asset_file_def_index) | |||
| { | |||
| SavedObjectGraph proto = new(); | |||
| fill_object_graph_proto(proto); | |||
| @@ -203,21 +204,20 @@ public class SaveableView | |||
| { | |||
| var obj = _nodes[i]; | |||
| var obj_proto = proto.Nodes[i]; | |||
| write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x), | |||
| options); | |||
| write_object_proto(obj, obj_proto, asset_file_def_index, x => _augmented_graph_view.list_children(x)); | |||
| } | |||
| return proto; | |||
| } | |||
| private static void write_object_proto(Trackable obj, SavedObject proto, | |||
| IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn, SaveOptions options) | |||
| IDictionary<object, object> asset_file_def_index, Func<Trackable, List<TrackableReference>> list_children_fn) | |||
| { | |||
| // skip the process of type Asset | |||
| if (resource_variable_ops.is_resource_variable(obj)) | |||
| { | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| var options = SaveContext.get_save_options(); | |||
| (obj as BaseResourceVariable).write_object_proto(proto, options); | |||
| } | |||
| else if (obj is Function) | |||
| { | |||
| @@ -10,6 +10,7 @@ using Tensorflow.ModelSaving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Exceptions; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow; | |||
| @@ -43,7 +44,7 @@ public static partial class SavedModelUtils | |||
| { | |||
| SavedModelUtils.get_or_create_variables_dir(export_dir); | |||
| CheckpointOptions ckpt_options = new(options.experimental_io_device); | |||
| object_saver.save(SavedModelUtils.get_variables_dir(export_dir), options:ckpt_options); | |||
| object_saver.save(SavedModelUtils.get_variables_path(export_dir), options:ckpt_options); | |||
| } | |||
| BuilderUtils.copy_assets_to_destination_dir(asset_info.asset_filename_map, export_dir); | |||
| @@ -68,6 +69,7 @@ public static partial class SavedModelUtils | |||
| var path = Path.Combine(tf.compat.as_str(export_dir), tf.compat.as_str(Constants.SAVED_MODEL_FILENAME_PB)); | |||
| File.WriteAllBytes(path, saved_model.ToByteArray()); | |||
| //File.WriteAllText(path, saved_model.ToString()); | |||
| if (options.save_debug_info) | |||
| { | |||
| @@ -83,45 +85,48 @@ public static partial class SavedModelUtils | |||
| Dictionary<Trackable, IEnumerable<TrackableReference>>) _build_meta_graph(Trackable obj, | |||
| ConcreteFunction? signatures, SaveOptions options, MetaGraphDef? meta_graph_def = null) | |||
| { | |||
| if (ops.inside_function()) | |||
| using (SaveContext.save_context(options)) | |||
| { | |||
| throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + | |||
| "Move the call to the outer eagerly-executed context."); | |||
| } | |||
| if (ops.inside_function()) | |||
| { | |||
| throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " + | |||
| "Move the call to the outer eagerly-executed context."); | |||
| } | |||
| if (meta_graph_def is null) | |||
| { | |||
| meta_graph_def = new MetaGraphDef(); | |||
| } | |||
| if (meta_graph_def is null) | |||
| { | |||
| meta_graph_def = new MetaGraphDef(); | |||
| } | |||
| AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | |||
| if (signatures is null) | |||
| { | |||
| signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); | |||
| } | |||
| // TODO: process of aignatures and wrapped_functions | |||
| AugmentedGraphView augmented_graph_view = new AugmentedGraphView(obj); | |||
| if (signatures is null) | |||
| { | |||
| signatures = SignatureSerializationUtils.find_function_to_export(augmented_graph_view); | |||
| } | |||
| SaveableView saveable_view = new SaveableView(augmented_graph_view, options); | |||
| TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); | |||
| var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, | |||
| options.namespace_white_list, options.experimental_custom_gradients); | |||
| if (options.function_aliases is not null) | |||
| { | |||
| var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; | |||
| foreach (var pair in options.function_aliases) | |||
| // TODO: process of aignatures and wrapped_functions | |||
| SaveableView saveable_view = new SaveableView(augmented_graph_view, options); | |||
| TrackableSaver object_saver = new TrackableSaver(augmented_graph_view); | |||
| var (asset_info, exported_graph) = _fill_meta_graph_def(meta_graph_def, saveable_view, signatures, | |||
| options.namespace_white_list, options.experimental_custom_gradients); | |||
| if (options.function_aliases is not null) | |||
| { | |||
| var alias = pair.Key; | |||
| var func = pair.Value; | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| var function_aliases = meta_graph_def.MetaInfoDef.FunctionAliases; | |||
| foreach (var pair in options.function_aliases) | |||
| { | |||
| var alias = pair.Key; | |||
| var func = pair.Value; | |||
| // TODO: complete it. | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index, options); | |||
| meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); | |||
| var object_graph_proto = saveable_view.serialize_object_graph(asset_info.asset_index); | |||
| meta_graph_def.ObjectGraphDef = new SavedObjectGraph(object_graph_proto); | |||
| return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); | |||
| return (meta_graph_def, exported_graph, object_saver, asset_info, saveable_view.Nodes, saveable_view.NodePaths); | |||
| } | |||
| } | |||
| private static (AssetInfo, Graph) _fill_meta_graph_def(MetaGraphDef meta_graph_def, SaveableView saveable_view, | |||
| @@ -134,7 +139,7 @@ public static partial class SavedModelUtils | |||
| Dictionary<Trackable, Trackable> object_map; | |||
| Dictionary<Tensor, Tensor> tensor_map; | |||
| AssetInfo asset_info; | |||
| exported_graph.as_default(); | |||
| var g = exported_graph.as_default(); | |||
| (object_map, tensor_map, asset_info) = saveable_view.map_resources(); | |||
| // TODO: deal with signatures. | |||
| if (save_custom_gradients) | |||
| @@ -161,15 +166,23 @@ public static partial class SavedModelUtils | |||
| // Lack `CopyFrom` API | |||
| // meta_graph_def.SignatureDef[Tensorflow.Constants.INIT_OP_SIGNATURE_KEY] | |||
| g.Exit(); | |||
| foreach (var obj in object_map.Values) | |||
| { | |||
| obj._maybe_initialize_trackable(); | |||
| } | |||
| // TODO: add the implementation of `call_with_mapped_functions`. | |||
| var (named_saveable_objects, registered_savers) = | |||
| SaveUtilV1.frozen_saveables_and_savers(saveable_view.AugmentedGraphView, object_map, exported_graph, false); | |||
| // TODO: complete the save of checkpoints with `MultiDeviceSaver`. | |||
| var saver = MultiDeviceSaver.from_saveables(named_saveable_objects, registered_savers, false); | |||
| var eg = exported_graph.as_default(); | |||
| var saver_def = saver.to_proto(); | |||
| meta_graph_def.SaverDef = saver_def; | |||
| eg.Exit(); | |||
| saveable_view.dependency_sorted_node_ids(); | |||
| @@ -0,0 +1,53 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.ModelSaving; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| /// <summary> | |||
| /// A context for building a graph of SavedModel. | |||
| /// </summary> | |||
| public static class SaveContext | |||
| { | |||
| // TODO: make it thead safe. | |||
| private static bool _in_save_context = false; | |||
| private static SaveOptions _save_options = null; | |||
| public static bool in_save_context() => _in_save_context; | |||
| public static SaveOptions get_save_options() | |||
| { | |||
| if (!in_save_context()) | |||
| { | |||
| throw new ValueError("Not in a SaveContext."); | |||
| } | |||
| return _save_options; | |||
| } | |||
| public static SaveContextHandler save_context(SaveOptions options) | |||
| { | |||
| return new SaveContextHandler(options); | |||
| } | |||
| public class SaveContextHandler: IDisposable | |||
| { | |||
| private bool _old_in_save_context; | |||
| private SaveOptions _old_save_options; | |||
| public SaveContextHandler(SaveOptions options) | |||
| { | |||
| if (SaveContext.in_save_context()) | |||
| { | |||
| throw new ValueError("Already in a SaveContext."); | |||
| } | |||
| _old_in_save_context = SaveContext._in_save_context; | |||
| SaveContext._in_save_context = true; | |||
| _old_save_options = SaveContext._save_options; | |||
| SaveContext._save_options = options; | |||
| } | |||
| public void Dispose() | |||
| { | |||
| SaveContext._in_save_context = _old_in_save_context; | |||
| SaveContext._save_options = _old_save_options; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -28,6 +28,11 @@ public static partial class SavedModelUtils | |||
| return Path.Combine(tf.compat.as_text(export_dir), tf.compat.as_text(Constants.VARIABLES_DIRECTORY)); | |||
| } | |||
| public static string get_variables_path(string export_dir) | |||
| { | |||
| return Path.Combine(tf.compat.as_text(get_variables_dir(export_dir)), tf.compat.as_text(Constants.VARIABLES_FILENAME)); | |||
| } | |||
| /// <summary> | |||
| /// Return assets sub-directory, or create one if it doesn't exist. | |||
| /// </summary> | |||
| @@ -19,6 +19,7 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Operations.Activation; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| using static Tensorflow.Binding; | |||
| @@ -117,8 +118,7 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| Debug.Assert(variable is ResourceVariable); | |||
| yield return new ResourceVariableSaveable((ResourceVariable)variable, "", name); | |||
| yield return new ResourceVariableSaveable(variable, "", name); | |||
| } | |||
| } | |||
| else | |||
| @@ -215,7 +215,7 @@ namespace Tensorflow | |||
| return names_to_saveables; | |||
| } | |||
| public static IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||
| public static IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> saveable_objects_from_trackable(Trackable obj) | |||
| { | |||
| // skip the process of type `PythonState` | |||
| @@ -251,7 +251,7 @@ namespace Tensorflow | |||
| specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); | |||
| } | |||
| } | |||
| Dictionary<string, Maybe<ResourceVariable, MySaveableObject>> res = new(); | |||
| Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> res = new(); | |||
| res[name] = new TrackableSaveable(obj, specs, name, local_names, prefix); | |||
| return res; | |||
| } | |||
| @@ -270,25 +270,6 @@ namespace Tensorflow | |||
| { | |||
| return tf.compat.as_str(x); | |||
| } | |||
| } | |||
| public class SaveableCompatibilityConverter: Trackable | |||
| { | |||
| private Trackable _obj; | |||
| private IList<MySaveableObject> _saveables; | |||
| public SaveableCompatibilityConverter(Trackable obj, IList<MySaveableObject> saveables) | |||
| { | |||
| _obj= obj; | |||
| _saveables= saveables; | |||
| } | |||
| public Trackable Obj => _obj; | |||
| public IList<MySaveableObject> mySaveables=> _saveables; | |||
| public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| { | |||
| return saveable_object_to_tensor_dict(_saveables); | |||
| } | |||
| /// <summary> | |||
| /// Converts a list of SaveableObjects to a tensor dictionary. | |||
| @@ -299,11 +280,11 @@ namespace Tensorflow | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new(); | |||
| foreach (var saveable in saveables) | |||
| { | |||
| foreach(var spec in saveable.specs) | |||
| foreach (var spec in saveable.specs) | |||
| { | |||
| // skip the check that if `spec` is callable. | |||
| var name = saveable_object_util.convert_to_string(spec.name); | |||
| var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); | |||
| var name = convert_to_string(spec.name); | |||
| var slice_spec = convert_to_string(spec.slice_spec); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| tensor_dict.SetDefault(name, new Dictionary<string, Tensor>()).GetValueB()[slice_spec] = spec.tensor; | |||
| @@ -316,5 +297,81 @@ namespace Tensorflow | |||
| } | |||
| return tensor_dict; | |||
| } | |||
| /// <summary> | |||
| /// Generates `Trackable._restore_from_tensors` from SaveableObjects. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| public static Func<IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>, IDictionary<string, Operation>> saveable_object_to_restore_fn(IList<MySaveableObject> saveables) | |||
| { | |||
| return (restored_tensors) => | |||
| { | |||
| Dictionary<string, Operation> restored_ops = new(); | |||
| foreach(var saveable in saveables) | |||
| { | |||
| List<Tensor> saveable_restored_tensors = new(); | |||
| foreach(var spec in saveable.specs) | |||
| { | |||
| var name = TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(spec.name)); | |||
| var slice_spec = saveable_object_util.convert_to_string(spec.slice_spec); | |||
| var maybe_tensor = restored_tensors[name]; | |||
| IDictionary<string, Tensor> dict; | |||
| if(maybe_tensor.DataType == typeof(Tensor)) | |||
| { | |||
| dict = new Dictionary<string, Tensor>(); | |||
| dict[""] = maybe_tensor.GetValueA(); | |||
| } | |||
| else | |||
| { | |||
| dict = maybe_tensor.GetValueB(); | |||
| } | |||
| saveable_restored_tensors.Add(dict[slice_spec]); | |||
| } | |||
| restored_ops[saveable.name] = saveable.restore(saveable_restored_tensors.ToArray(), null); | |||
| } | |||
| return restored_ops; | |||
| }; | |||
| } | |||
| } | |||
| public class SaveableCompatibilityConverter: Trackable | |||
| { | |||
| private object _obj; | |||
| private IList<MySaveableObject> _saveables; | |||
| public SaveableCompatibilityConverter(object obj, IList<MySaveableObject> saveables) | |||
| { | |||
| _obj= obj; | |||
| _saveables= saveables; | |||
| } | |||
| public object Obj => _obj; | |||
| public IList<MySaveableObject> mySaveables=> _saveables; | |||
| public override IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> serialize_to_tensors() | |||
| { | |||
| return saveable_object_util.saveable_object_to_tensor_dict(_saveables); | |||
| } | |||
| /// <summary> | |||
| /// Returns the restore ops defined in the Saveables. | |||
| /// </summary> | |||
| /// <param name="restored_tensors"></param> | |||
| /// <returns></returns> | |||
| public override IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||
| { | |||
| List<string> expected_keys = new(); | |||
| foreach(var saveable in _saveables) | |||
| { | |||
| expected_keys.AddRange(saveable.specs.Select(x => TrackableUtils.extract_local_name(saveable_object_util.convert_to_string(x.name)))); | |||
| } | |||
| if (!expected_keys.Distinct().SequenceEqual(restored_tensors.Keys)) | |||
| { | |||
| throw new ValueError($"Could not restore object {_obj} because not all expected tensors were in the checkpoint." + | |||
| $"\n\tExpected: {expected_keys} \n\tGot: {list(restored_tensors.Keys)}"); | |||
| } | |||
| return saveable_object_util.saveable_object_to_restore_fn(_saveables).Invoke(restored_tensors); | |||
| } | |||
| } | |||
| } | |||
| @@ -42,11 +42,11 @@ namespace Tensorflow.Train | |||
| protected IList<TrackableReference> _unconditional_checkpoint_dependencies; | |||
| protected IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||
| new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>(); | |||
| protected IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> _self_saveable_object_factories = | |||
| new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| private bool _manual_tracking = true; | |||
| private static Trackable _none = new Function(); | |||
| private static Trackable _none = new AutoTrackable(); | |||
| /// <summary> | |||
| /// This is a trick for that CSharp does not allow the key of `Dictionary` to be null. | |||
| /// The `None` can be any object that inherits `Trackable`. | |||
| @@ -225,7 +225,7 @@ namespace Tensorflow.Train | |||
| return self_tensor_map.Keys.ToList(); | |||
| } | |||
| public virtual IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| public virtual IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| { | |||
| if (saveable_object_util.trackable_has_serialize_to_tensor(this)) | |||
| { | |||
| @@ -251,6 +251,11 @@ namespace Tensorflow.Train | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| public virtual IDictionary<string, Operation> _restore_from_tensors(IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| public record class TrackableReference(string Name, Trackable Refer); | |||
| @@ -6,6 +6,8 @@ using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.ModelSaving; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Checkpoint; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -13,6 +15,7 @@ namespace Tensorflow | |||
| { | |||
| protected string _name; | |||
| public virtual string Name => _handle_name; | |||
| public virtual string SharedName => _name; | |||
| protected TF_DataType _dtype; | |||
| public TF_DataType dtype => _dtype; | |||
| protected string _handle_name; | |||
| @@ -50,6 +53,7 @@ namespace Tensorflow | |||
| public Graph Graph => handle.graph; | |||
| public string Device => handle.Device; | |||
| EagerResourceDeleter eager_resource_deleter; | |||
| public VariableAggregation Aggregation { get; protected set; } = VariableAggregation.None; | |||
| public BaseResourceVariable() | |||
| { | |||
| @@ -77,6 +81,11 @@ namespace Tensorflow | |||
| _handle = handle.EagerTensorHandle.DangerousGetHandle(); | |||
| eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | |||
| } | |||
| else if(handle is null) | |||
| { | |||
| // TODO: fix this dangerous change. | |||
| _handle = IntPtr.Zero; | |||
| } | |||
| else | |||
| { | |||
| _handle = handle.Handle == null ? IntPtr.Zero : handle.Handle.DangerousGetHandle(); | |||
| @@ -247,5 +256,60 @@ namespace Tensorflow | |||
| else | |||
| return value(); | |||
| } | |||
| public override (IDictionary<Trackable, Trackable>, IDictionary<Tensor, Tensor>) map_resources(SaveOptions save_options) | |||
| { | |||
| BaseResourceVariable new_variable; | |||
| if (save_options.experimental_variable_policy.save_variable_devices()) | |||
| { | |||
| tf.device(this.Device); | |||
| Debug.Assert(this is ResourceVariable); | |||
| new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
| } | |||
| else | |||
| { | |||
| new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
| } | |||
| Dictionary<Trackable, Trackable> obj_map = new(); | |||
| Dictionary<Tensor, Tensor> resource_map = new(); | |||
| obj_map[this] = new_variable; | |||
| resource_map[this.handle] = new_variable.handle; | |||
| return (obj_map, resource_map); | |||
| } | |||
| /// <summary> | |||
| /// Writes additional information of the variable into the SavedObject proto. | |||
| /// ubclasses of ResourceVariables could choose to override this method to | |||
| /// customize extra information to provide when saving a SavedModel. | |||
| /// </summary> | |||
| /// <param name="proto"></param> | |||
| /// <param name="options"></param> | |||
| public virtual void write_object_proto(SavedObject proto, SaveOptions options) | |||
| { | |||
| resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); | |||
| } | |||
| public override IDictionary<string, Maybe<BaseResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| { | |||
| var res = new Dictionary<string, Maybe<BaseResourceVariable, MySaveableObject>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||
| return res; | |||
| } | |||
| public Tensor is_initialized(string name = null) | |||
| { | |||
| return gen_resource_variable_ops.var_is_initialized_op(this.handle, name); | |||
| } | |||
| public Tensor read_value_no_copy() | |||
| { | |||
| Tensor value = null; | |||
| tf_with(ops.name_scope("Read"), _ => | |||
| { | |||
| // TODO: `no_copy = true`. | |||
| value = _read_variable_op(); | |||
| }); | |||
| return array_ops.identity(value); | |||
| } | |||
| } | |||
| } | |||
| @@ -41,6 +41,7 @@ namespace Tensorflow | |||
| VariableAggregation aggregation = VariableAggregation.None, | |||
| Shape shape = null) | |||
| { | |||
| Aggregation = aggregation; | |||
| if (variable_def != null) | |||
| { | |||
| if (initial_value != null) | |||
| @@ -237,12 +238,5 @@ namespace Tensorflow | |||
| { | |||
| return _graph_element.eval(session); | |||
| } | |||
| public override IDictionary<string, Maybe<ResourceVariable, MySaveableObject>> gather_saveables_for_checkpoint() | |||
| { | |||
| var res = new Dictionary<string, Maybe<ResourceVariable, MySaveableObject>>(); | |||
| res[Trackable.Constants.VARIABLE_VALUE_KEY] = this; | |||
| return res; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,70 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Gradients; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Variables | |||
| { | |||
| /// <summary> | |||
| /// A variable with no initializer. | |||
| /// </summary> | |||
| public sealed class UninitializedVariable: BaseResourceVariable | |||
| { | |||
| // TODO: complete the arg list. | |||
| public UninitializedVariable( | |||
| bool trainable = true, | |||
| string caching_device = "", | |||
| string name = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| VariableAggregation aggregation = VariableAggregation.None, | |||
| Shape shape = null, | |||
| Tensor extra_handle_data = null) | |||
| { | |||
| string unique_id = ""; | |||
| string handle_name = ""; | |||
| tf_with(ops.init_scope(), (x) => | |||
| { | |||
| _in_graph_mode = !tf.Context.executing_eagerly(); | |||
| tf_with(ops.name_scope(name, "Variable", skip_on_eager: false), name => | |||
| { | |||
| handle_name = ops.name_from_scope_name(name); | |||
| string? shared_name; | |||
| if (_in_graph_mode) | |||
| { | |||
| shared_name = handle_name; | |||
| unique_id = shared_name; | |||
| } | |||
| else | |||
| { | |||
| unique_id = $"{handle_name}-{ops.uid()}"; | |||
| shared_name = null; | |||
| } | |||
| var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( | |||
| shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); | |||
| // skip the assignment of `handle._parent_trackable` because of lack of API. | |||
| // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. | |||
| if (_in_graph_mode) | |||
| { | |||
| tf_with(ops.name_scope("Read"), _ => | |||
| { | |||
| tf.device(handle.Device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| // _maybe_set_handle_data(dtype, handle, value) | |||
| _graph_element = value; | |||
| }); | |||
| ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | |||
| } | |||
| else | |||
| { | |||
| _graph_element = null; | |||
| } | |||
| }); | |||
| }); | |||
| _shape = shape; | |||
| _dtype = dtype; | |||
| base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); | |||
| } | |||
| } | |||
| } | |||
| @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| var layer_config = generic_utils.serialize_keras_object(layer); | |||
| var layer_config = generic_utils.serialize_layer_to_config(layer); | |||
| layer_config.Name = layer.Name; | |||
| layer_config.InboundNodes = filtered_inbound_nodes; | |||
| layer_configs.Add(layer_config); | |||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine; | |||
| public abstract partial class Layer | |||
| { | |||
| public LayerSavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | |||
| public virtual SavedModelSaver TrackableSavedModelSaver => new LayerSavedModelSaver(this); | |||
| public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
| @@ -18,6 +18,7 @@ using System.Linq; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| @@ -105,5 +106,7 @@ namespace Tensorflow.Keras.Layers | |||
| { | |||
| return new InputLayer(args as InputLayerArgs); | |||
| } | |||
| public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | |||
| } | |||
| } | |||
| @@ -55,6 +55,7 @@ public partial class KerasSavedModelUtils | |||
| var metadata = generate_keras_metadata(saved_nodes, node_paths); | |||
| File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray()); | |||
| //File.WriteAllText(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToString()); | |||
| if (!include_optimizer) | |||
| { | |||
| @@ -100,7 +101,8 @@ public partial class KerasSavedModelUtils | |||
| Identifier = layer.ObjectIdentifier, | |||
| Metadata = layer.TrackingMetadata | |||
| }; | |||
| metadata.Nodes.Add(saved_object); | |||
| } | |||
| return metadata; | |||
| @@ -24,26 +24,26 @@ public partial class KerasSavedModelUtils | |||
| // TODO: deal with losses and metrics. Currently, `Layer` lacks these two APIs. | |||
| // TODO: change the inherits of `Variable` and revise the implmentation. | |||
| var variables = layer.Variables.Select(x => | |||
| var variables = TrackableDataStructure.wrap_or_unwrap(layer.Variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| }); | |||
| var trainable_variables = layer.TrainableVariables.Select(x => | |||
| })); | |||
| var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| }); | |||
| var non_trainable_variables = layer.non_trainable_variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| }); | |||
| })); | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| Dictionary<string, Trackable> res = new(); | |||
| res["variables"] = TrackableDataStructure.wrap_or_unwrap(variables); | |||
| res["trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(trainable_variables); | |||
| res["non_trainable_variables"] = TrackableDataStructure.wrap_or_unwrap(non_trainable_variables); | |||
| res["variables"] = variables; | |||
| res["trainable_variables"] = trainable_variables; | |||
| res["non_trainable_variables"] = non_trainable_variables; | |||
| res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); | |||
| return res; | |||
| @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public abstract class SavedModelSaver | |||
| { | |||
| private Trackable _obj; | |||
| protected Trackable _obj; | |||
| public SavedModelSaver(Trackable obj) | |||
| { | |||
| _obj = obj; | |||
| @@ -2,6 +2,7 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow.Train; | |||
| @@ -9,10 +10,11 @@ namespace Tensorflow.Keras.Saving.SavedModel; | |||
| public class LayerSavedModelSaver: SavedModelSaver | |||
| { | |||
| private Layer _obj; | |||
| private Layer _layer; | |||
| public LayerSavedModelSaver(Layer obj): base(obj) | |||
| { | |||
| _obj = obj; | |||
| _layer = obj; | |||
| } | |||
| public override string ObjectIdentifier | |||
| { | |||
| @@ -68,8 +70,8 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
| /// <param name="serialization_cache"></param> | |||
| private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache); | |||
| var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache); | |||
| var objects = KerasSavedModelUtils.wrap_layer_objects(_layer, serialization_cache); | |||
| var functions = KerasSavedModelUtils.wrap_layer_functions(_layer, serialization_cache); | |||
| functions["_default_save_signature"] = null; | |||
| @@ -81,17 +83,18 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
| get | |||
| { | |||
| JObject metadata = new JObject(); | |||
| metadata["name"] = _obj.Name; | |||
| metadata["trainable"] = _obj.Trainable; | |||
| metadata["name"] = _layer.Name; | |||
| metadata["trainable"] = _layer.Trainable; | |||
| // metadata["expects_training_arg"] = _obj._expects_training_arg; | |||
| // metadata["dtype"] = policy.serialize(_obj._dtype_policy) | |||
| metadata["batch_input_shape"] = _obj.BatchInputShape is null ? null : JToken.FromObject(_obj.BatchInputShape); | |||
| metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); | |||
| // metadata["stateful"] = _obj.stateful; | |||
| // metadata["must_restore_from_config"] = _obj.must_restore_from_config; | |||
| // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | |||
| metadata["autocast"] = _obj.AutoCast; | |||
| metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings | |||
| metadata["autocast"] = _layer.AutoCast; | |||
| var temp = JObject.FromObject(get_serialized(_layer)); | |||
| metadata.Merge(temp, new JsonMergeSettings | |||
| { | |||
| // Handle conflicts by using values from obj2 | |||
| MergeArrayHandling = MergeArrayHandling.Merge | |||
| @@ -108,4 +111,46 @@ public class LayerSavedModelSaver: SavedModelSaver | |||
| return new Dictionary<string, object>(); | |||
| //return generic_utils.serialize_keras_object(obj); | |||
| } | |||
| } | |||
| public class InputLayerSavedModelSaver: SavedModelSaver | |||
| { | |||
| public InputLayerSavedModelSaver(Layer obj) : base(obj) | |||
| { | |||
| } | |||
| public override string ObjectIdentifier => Constants.INPUT_LAYER_IDENTIFIER; | |||
| public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, IDictionary<Trackable, ISerializedAttributes>> serialization_cache) | |||
| { | |||
| return new Dictionary<string, Trackable>(); | |||
| } | |||
| public override string TrackingMetadata | |||
| { | |||
| get | |||
| { | |||
| if(_obj is not Layer) | |||
| { | |||
| throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); | |||
| } | |||
| var layer = (Layer)_obj; | |||
| var info = new | |||
| { | |||
| class_name = layer.GetType().Name, | |||
| name = layer.Name, | |||
| dtype = layer.DType, | |||
| //sparse = layer.sparse, | |||
| //ragged = layer.ragged, | |||
| batch_input_shape = layer.BatchInputShape, | |||
| config = layer.get_config() | |||
| }; | |||
| return JsonConvert.SerializeObject(info); | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,8 @@ | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Keras.Saving; | |||
| @@ -22,7 +24,12 @@ namespace Tensorflow.Keras.Utils | |||
| { | |||
| public class generic_utils | |||
| { | |||
| public static LayerConfig serialize_keras_object(ILayer instance) | |||
| /// <summary> | |||
| /// This method does not have corresponding method in python. It's close to `serialize_keras_object`. | |||
| /// </summary> | |||
| /// <param name="instance"></param> | |||
| /// <returns></returns> | |||
| public static LayerConfig serialize_layer_to_config(ILayer instance) | |||
| { | |||
| var config = instance.get_config(); | |||
| return new LayerConfig | |||