From acae9b3e39b4b5b12bc2fdeff21c4d566c486d10 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Thu, 30 Mar 2023 15:42:38 +0800 Subject: [PATCH] Partially support the analysis of loaded functions. --- TensorFlow.NET.sln | 18 ++- .../Extensions/OneofExtension.cs | 13 ++ Tensorflow.Common/Tensorflow.Common.csproj | 11 ++ src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 20 +-- .../Checkpoint/SaveUtilV1.cs | 7 +- .../Checkpoint/checkpoint.cs | 9 +- .../Checkpoint/functional_saver.cs | 136 +++--------------- src/TensorFlowNET.Core/Checkpoint/restore.cs | 33 ++--- .../Eager/forwardprop_util.cs | 13 ++ .../Functions/ConcreteFunction.cs | 70 ++++++++- .../Functions/EagerDefinedFunction.cs | 53 ++++++- src/TensorFlowNET.Core/Functions/Function.cs | 40 +++++- .../Functions/TapeGradientFunctions.cs | 16 ++- .../Functions/function_saved_model_utils.cs | 1 + .../Functions/monomorphic_function.cs | 27 +++- .../Gradients/gradients_util.cs | 5 + src/TensorFlowNET.Core/Graphs/Graph.cs | 1 + .../Operations/c_api.ops.cs | 4 + .../Operations/functional_ops.cs | 70 +++++++++ .../Operations/gen_functional_ops.cs | 83 +++++++++++ .../Operations/handle_data_util.cs | 25 +++- .../Operations/resource_variable_ops.cs | 18 ++- .../Protobuf/CppShapeInference.cs | 2 +- .../Protobuf/SavedObjectGraph.cs | 25 ++-- .../Tensorflow.Binding.csproj | 5 + src/TensorFlowNET.Core/Tensors/Tensor.cs | 1 + .../Training/Saving/SaveableObject.cs | 7 +- .../SavedModel/function_deserialization.cs | 53 ++++++- .../Training/Saving/SavedModel/loader.cs | 36 +++-- .../Saving/saveable_object_util.py.cs | 36 ++--- src/TensorFlowNET.Core/Training/Trackable.cs | 17 +-- src/TensorFlowNET.Core/Util/function_utils.cs | 23 +++ .../Variables/BaseResourceVariable.cs | 7 +- .../Variables/ResourceVariable.cs | 10 +- .../Variables/UninitializedVariable.cs | 9 +- src/TensorFlowNET.Core/ops.cs | 8 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 4 +- .../Saving/KerasObjectLoader.cs | 79 +++++++--- .../Saving/SavedModel/ReviveUtils.cs | 13 +- .../Saving/SavedModel/RevivedInputLayer.cs | 15 ++ .../Saving/SavedModel/RevivedLayer.cs | 27 ++++ .../SavedModel/serialized_attributes.cs | 16 +-- 42 files changed, 782 insertions(+), 284 deletions(-) create mode 100644 Tensorflow.Common/Extensions/OneofExtension.cs create mode 100644 Tensorflow.Common/Tensorflow.Common.csproj create mode 100644 src/TensorFlowNET.Core/Eager/forwardprop_util.cs create mode 100644 src/TensorFlowNET.Core/Operations/gen_functional_ops.cs create mode 100644 src/TensorFlowNET.Core/Util/function_utils.cs create mode 100644 src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 8846d5bf..433cace0 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 16.0.31624.102 +# Visual Studio Version 17 +VisualStudioVersion = 17.4.33213.308 MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" EndProject @@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -153,6 +155,18 @@ Global {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64 {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU {3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU + {0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/Tensorflow.Common/Extensions/OneofExtension.cs b/Tensorflow.Common/Extensions/OneofExtension.cs new file mode 100644 index 00000000..c7fb8093 --- /dev/null +++ b/Tensorflow.Common/Extensions/OneofExtension.cs @@ -0,0 +1,13 @@ +using OneOf; +using System; + +namespace Tensorflow.Common.Extensions +{ + public static class OneofExtension + { + public static bool IsTypeOrDeriveFrom(this IOneOf src) + { + return src.Value is T; + } + } +} diff --git a/Tensorflow.Common/Tensorflow.Common.csproj b/Tensorflow.Common/Tensorflow.Common.csproj new file mode 100644 index 00000000..0501cded --- /dev/null +++ b/Tensorflow.Common/Tensorflow.Common.csproj @@ -0,0 +1,11 @@ + + + + netstandard2.0 + + + + + + + diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index c54cc93f..8b8cbf61 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -1,10 +1,12 @@ -using System; +using OneOf; +using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Text; using Tensorflow.Train; using Tensorflow.Training; +using Tensorflow.Common.Extensions; using pbc = global::Google.Protobuf.Collections; namespace Tensorflow.Checkpoint @@ -28,7 +30,7 @@ namespace Tensorflow.Checkpoint ); public static class SaveUtil { - public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -117,16 +119,16 @@ namespace Tensorflow.Checkpoint /// /// /// - private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) { - Dictionary>>> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach(var td in tensor_trackables) { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; Trackable trackable = null; - IDictionary>> tensor_dict; + IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); @@ -148,12 +150,12 @@ namespace Tensorflow.Checkpoint return serialized_tensors; } - private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { var trackable = trackable_data.object_to_save; // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. - IDictionary>> ret_tensor_dict; + IDictionary>> ret_tensor_dict; if (call_with_mapped_captures) { throw new NotImplementedException(); @@ -164,7 +166,7 @@ namespace Tensorflow.Checkpoint } // TODO: deal with the type `SaveSpce` (currently it will never be it). - Dictionary>> tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach(var pair in ret_tensor_dict) { var local_name = TrackableUtils.escape_local_name(pair.Key); @@ -200,7 +202,7 @@ namespace Tensorflow.Checkpoint /// /// /// - private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { Dictionary object_names = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 72372e41..c77c343c 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -8,6 +8,7 @@ using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; using Google.Protobuf; +using OneOf; namespace Tensorflow.Checkpoint; @@ -179,13 +180,13 @@ public static class SaveUtilV1 // TODO: tensorflow python has a process with callable `saveable_factory`. List saveables = new(); - if (maybe_saveable.TryGet(out var s)) + if (maybe_saveable.TryPickT1(out var s, out var variable)) { saveables.Add(s); } else { - saveables.AddRange(saveable_object_util.saveable_objects_for_op(maybe_saveable.GetValue() as Trackable, key)); + saveables.AddRange(saveable_object_util.saveable_objects_for_op(variable as Trackable, key)); } foreach (var saveable in saveables) @@ -217,7 +218,7 @@ public static class SaveUtilV1 public record class CheckpointFactoryData ( - Func> factory, + Func> factory, string name, string checkpoint_key ); diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 1934ffd5..445fd685 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -12,6 +12,7 @@ using static Tensorflow.Binding; using Tensorflow.Operations; using Newtonsoft.Json; using Tensorflow.Training; +using OneOf; namespace Tensorflow.Checkpoint; @@ -49,7 +50,7 @@ public class TrackableSaver } - private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -68,7 +69,7 @@ public class TrackableSaver Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); if (!serialized_tensors.ContainsKey(Trackable.None)) { - serialized_tensors[Trackable.None] = new Dictionary>>(); + serialized_tensors[Trackable.None] = new Dictionary>>(); } serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; return (serialized_tensors, feed_additions, registered_savers, graph_proto); @@ -400,7 +401,7 @@ public class CheckpointRestoreCoordinator // skip the callback. } - public List restore_saveables(Dictionary> tensor_saveables, List positions, object? registered_savers = null) + public List restore_saveables(Dictionary> tensor_saveables, List positions, object? registered_savers = null) { List restore_ops = new(); foreach(var position in positions) @@ -412,7 +413,7 @@ public class CheckpointRestoreCoordinator Dictionary variable_dict = new(); foreach(var item in tensor_saveables) { - if(item.Value.TryGet(out var variable)) + if(item.Value.TryPickT0(out var variable, out var _)) { variable_dict[item.Key] = variable; } diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 96e6c8dd..3b49fa8d 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -15,106 +15,14 @@ using Tensorflow.Graphs; using System.Xml.Linq; using System.Diagnostics; using RestoreFunc = System.Func; +using OneOf; namespace Tensorflow.Checkpoint { - public class Maybe - { - private TA? _valueA = default(TA); - private TB? _valueB = default(TB); - private Type _type; - private bool _assignedTA; - public Maybe(TA value) - { - _valueA = value; - _type= typeof(TA); - _assignedTA = true; - } - public Maybe(TB value) - { - _valueB = value; - _type = typeof(TB); - _assignedTA = false; - } - - public Type DataType => _type; - - /// - /// Try to get the type T member of this instance. It returns true when TA or TB derive from T and is correspondingly assigned. - /// It returns - /// - /// - /// - /// - public bool TryGet(out T? res) - { - if(_valueA is T && _valueB is not T) - { - res = (T)(object)_valueA; - return _assignedTA; - } - else if(_valueA is not T && _valueB is T) - { - res = (T)(object)_valueB; - return !_assignedTA; - } - res = default(T); - return false; - } - - public bool IsTypeOrDeriveFrom() - { - if (_valueA is T && _valueB is not T) - { - return _assignedTA; - } - else if (_valueA is not T && _valueB is T) - { - return !_assignedTA; - } - else if (_valueA is T && _valueB is T) - { - return true; - } - else - { - return false; - } - } - - public T GetValue() - { - if (_valueA is T && _valueB is not T) - { - return (T)(object)_valueA; - } - else if (_valueA is not T && _valueB is T) - { - return (T)(object)_valueB; - } - else if (_valueA is T && _valueB is T) - { - throw new TypeError("The type is vague, this is always because TA and TB both derive from T."); - } - else - { - throw new TypeError($"Expected {typeof(TA)} or {typeof(TB)}, but got typeof{typeof(T)}."); - } - } - - public static implicit operator Maybe(TA a) - { - return new Maybe(a); - } - public static implicit operator Maybe(TB b) - { - return new Maybe(b); - } - } internal class SingleDeviceSaver { - private IDictionary>> _tensor_slice_dict; - public SingleDeviceSaver(IDictionary>> tensor_slice_dict) + private IDictionary>> _tensor_slice_dict; + public SingleDeviceSaver(IDictionary>> tensor_slice_dict) { _tensor_slice_dict = tensor_slice_dict; } @@ -122,15 +30,15 @@ namespace Tensorflow.Checkpoint { _tensor_slice_dict = tensor_slice_dict.ToDictionary( x => x.Key, x => x.Value.ToDictionary( - y => y.Key, y => new Maybe(y.Value)) - as IDictionary>); + y => y.Key, y => OneOf.FromT0(y.Value)) + as IDictionary>); } public SingleDeviceSaver(IDictionary> tensor_slice_dict) { _tensor_slice_dict = tensor_slice_dict.ToDictionary( x => x.Key, x => x.Value.ToDictionary( - y => y.Key, y => new Maybe(y.Value)) - as IDictionary>); + y => y.Key, y => OneOf.FromT1(y.Value)) + as IDictionary>); } public Operation? save(Tensor file_prefix, CheckpointOptions? options = null) { @@ -149,7 +57,7 @@ namespace Tensorflow.Checkpoint { var slice_spec = slice.Key; var maybe_tensor = slice.Value; - if(maybe_tensor.TryGet(out var spec)) + if(maybe_tensor.TryPickT1(out var spec, out var tensor)) { var tensor_value = spec.tensor; if (tensor_value is not null) @@ -161,7 +69,6 @@ namespace Tensorflow.Checkpoint } else { - var tensor = maybe_tensor.GetValue(); tensor_names.Add(checkpoint_key); tensors.Add(tensor); slice_specs.Add(slice_spec); @@ -193,7 +100,7 @@ namespace Tensorflow.Checkpoint var slice_spec = slice.Key; var maybe_tensor = slice.Value; // TODO: deal with other types. Currently only `SaveSpec` is allowed. - if(maybe_tensor.TryGet(out var spec)) + if(maybe_tensor.TryPickT1(out var spec, out var tensor)) { tensor_dtypes.Add(spec.dtype); slice_specs.Add(spec.slice_spec); @@ -201,7 +108,6 @@ namespace Tensorflow.Checkpoint } else { - var tensor = maybe_tensor.GetValue(); tensor_dtypes.Add(tensor.dtype); slice_specs.Add(slice_spec); tensor_names.Add(checkpoint_key); @@ -254,7 +160,7 @@ namespace Tensorflow.Checkpoint /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. /// /// - public MultiDeviceSaver(IDictionary>>> serialized_tensors, + public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) { _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); @@ -274,9 +180,9 @@ namespace Tensorflow.Checkpoint { restore_fn = new RestoreFunc(x => { - if(x is IDictionary>>) + if(x is IDictionary>>) { - return obj._restore_from_tensors(x as IDictionary>>); + return obj._restore_from_tensors(x as IDictionary>>); } throw new TypeError($"Expected `IDictionary>>` as input, got{x.GetType()}."); }); @@ -286,14 +192,14 @@ namespace Tensorflow.Checkpoint { var checkpoint_key = item.Key; IDictionary spec_to_tensor; - if(item.Value.TryGet(out var t)) + if(item.Value.TryPickT0(out var t, out var dic)) { spec_to_tensor = new Dictionary(); spec_to_tensor[""] = t; } else { - spec_to_tensor = item.Value.GetValue>(); + spec_to_tensor = dic; } foreach(var spec in spec_to_tensor) @@ -399,7 +305,7 @@ namespace Tensorflow.Checkpoint IDictionary restore_func() { - Dictionary>>> restore_fn_inputs = new(); + Dictionary>>> restore_fn_inputs = new(); Dictionary restore_fn_input_count = _restore_fn_to_keys.ToDictionary(x => x.Key, x => x.Value.Count); Dictionary restore_ops = new(); @@ -419,29 +325,29 @@ namespace Tensorflow.Checkpoint var slice_spec = item.Key; var tensor = item.Value; var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; - var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); if (!string.IsNullOrEmpty(slice_spec)) { if (!internal_dict.ContainsKey(checkpoint_key)) { Dictionary dict = new(); dict[slice_spec] = tensor; - internal_dict[checkpoint_key] = new Maybe>(dict); + internal_dict[checkpoint_key] = OneOf>.FromT1(dict); } else { - internal_dict[checkpoint_key].GetValue>()[slice_spec] = tensor; + internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; } } else { - internal_dict[checkpoint_key] = new Maybe>(tensor); + internal_dict[checkpoint_key] = OneOf>.FromT0(tensor); } restore_fn_input_count[restore_fn]--; if (restore_fn_input_count[restore_fn] == 0) { - Dictionary>> restored_tensors = new(); + Dictionary>> restored_tensors = new(); foreach(var input in restore_fn_inputs[restore_fn]) { restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; @@ -519,7 +425,7 @@ namespace Tensorflow.Checkpoint public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? registered_savers = null, bool call_with_mapped_captures = false) { - Dictionary>>> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach (var saveable in saveables) { var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs index b27396a7..e2770487 100644 --- a/src/TensorFlowNET.Core/Checkpoint/restore.cs +++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs @@ -1,4 +1,5 @@ -using System; +using OneOf; +using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; @@ -61,13 +62,13 @@ public class CheckpointPosition } } - public (List, Dictionary>, List, object?) gather_ops_or_named_saveables() + public (List, Dictionary>, List, object?) gather_ops_or_named_saveables() { // skip the registered_saver if (ObjectProto.Attributes is null || ObjectProto.Attributes.Count == 0) { - return (new List(), new Dictionary>(), + return (new List(), new Dictionary>(), new List(), null); } @@ -75,7 +76,7 @@ public class CheckpointPosition List existing_restore_ops; List positions = new(); - Dictionary> named_saveables; + Dictionary> named_saveables; if (saveable_factories.Keys.Count == 1 && saveable_factories.Keys.First() == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) { (existing_restore_ops, named_saveables) = _create_serialize_to_tensor_saveable(saveable_factories); @@ -109,8 +110,8 @@ public class CheckpointPosition /// Creates a saveable using the _serialize_to_tensor method. /// /// - private (List, Dictionary>) _create_serialize_to_tensor_saveable( - IDictionary>> saveable_factories) + private (List, Dictionary>) _create_serialize_to_tensor_saveable( + IDictionary>> saveable_factories) { string suffix = SaveableCompat.get_saveable_name(this.Trackable); suffix = suffix ?? ""; @@ -124,23 +125,23 @@ public class CheckpointPosition var saveable = saveable_factories[TrackableUtils.SERIALIZE_TO_TENSORS_NAME](saveable_name); // skip the cache. - Dictionary> dict = new(); + Dictionary> dict = new(); dict[saveable_name] = saveable; return (new List(), dict); } - private (List, Dictionary>) _create_saveables_by_attribute_name( - IDictionary>> saveable_factories) + private (List, Dictionary>) _create_saveables_by_attribute_name( + IDictionary>> saveable_factories) { // TODO(Rinne): implement it. if(ObjectProto.Attributes is null) { - return (new List(), new Dictionary>()); + return (new List(), new Dictionary>()); } List existing_restore_ops = new(); HashSet created_compat_names = new(); - Dictionary> named_saveables = new(); + Dictionary> named_saveables = new(); foreach (var serialized_tensor in ObjectProto.Attributes) { Operation existing_op; @@ -172,12 +173,12 @@ public class CheckpointPosition _checkpoint.UnusedAttributes.SetDefault(_proto_id, new List()).Add(serialized_tensor.Name); continue; } - named_saveables[serialized_tensor.CheckpointKey] = saveable; + named_saveables[serialized_tensor.CheckpointKey] = saveable.Value; } return (existing_restore_ops, named_saveables); } - private Maybe _get_saveable_from_factory(IDictionary>> saveable_factories, + private OneOf? _get_saveable_from_factory(IDictionary>> saveable_factories, TrackableObjectGraph.Types.TrackableObject.Types.SerializedTensor serialized_tensor, HashSet created_compat_names) { var expected_factory_name = serialized_tensor.Name; @@ -221,7 +222,7 @@ public class CheckpointPosition Queue<(CheckpointPosition, Trackable)> visit_queue = new(); visit_queue.Enqueue((this, this.Trackable)); List restore_ops = new(); - Dictionary> tensor_saveables = new(); + Dictionary> tensor_saveables = new(); List positions = new(); CheckpointPosition current_position = null; @@ -306,7 +307,7 @@ public class CheckpointPosition } } - private (List, Dictionary>, List, object?) _single_restore() + private (List, Dictionary>, List, object?) _single_restore() { var trackable = this.Trackable; trackable._maybe_initialize_trackable(); @@ -318,7 +319,7 @@ public class CheckpointPosition } else { - return (new List(), new Dictionary>(), + return (new List(), new Dictionary>(), new List(), null); } } diff --git a/src/TensorFlowNET.Core/Eager/forwardprop_util.cs b/src/TensorFlowNET.Core/Eager/forwardprop_util.cs new file mode 100644 index 00000000..a53026d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/forwardprop_util.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Eager +{ + public class TangentInfo + { + // TODO(Rinne): implement it. + public object Indices { get; set; } + public object Tangents { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 9abcc61c..3cc27f25 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Framework.Models; using Tensorflow.Graphs; using Tensorflow.Train; @@ -17,11 +19,13 @@ namespace Tensorflow.Functions internal FuncGraph func_graph; protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; protected Dictionary _attrs; + protected FunctionSpec _function_spec; + protected FunctionSpec _pre_initialized_function_spec = null; internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; - public string Name => _delayed_rewrite_functions.forward().Name; + public string Name => _delayed_rewrite_functions.Forward().Name; public Tensor[] Outputs; public Type ReturnType; @@ -175,7 +179,13 @@ namespace Tensorflow.Functions var (forward_function, args_with_tangents) = forward_backward.Forward(); Tensors flat_outputs = null; if (executing_eagerly) + { + flat_outputs = forward_function.Call(args_with_tangents); + } + else + { flat_outputs = forward_function.Call(args_with_tangents); + } forward_backward.Record(flat_outputs); return flat_outputs; } @@ -186,7 +196,7 @@ namespace Tensorflow.Functions { g = ops.get_default_graph(); } - _delayed_rewrite_functions.forward().AddToGraph(g); + _delayed_rewrite_functions.Forward().AddToGraph(g); } public void SetExternalCaptures(IEnumerable captures) @@ -196,8 +206,60 @@ namespace Tensorflow.Functions ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) { - var functions = new FirstOrderTapeGradientFunctions(func_graph, false); - return new ForwardBackwardCall(functions, args, tape_watching: true); + TangentInfo input_tangents; + if (executing_eagerly) + { + throw new NotImplementedException(); + } + else + { + input_tangents = new TangentInfo(); + } + if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) + { + if(input_tangents.Indices is not null || executing_eagerly) + { + var functions = new FirstOrderTapeGradientFunctions(func_graph, false); + return new ForwardBackwardCall(functions, args, tape_watching: true); + } + else + { + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: true); + } + } + else if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER) + { + throw new NotImplementedException(); + } + + // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. + return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); + } + + internal void _set_function_spec(FunctionSpec spec) + { + _function_spec = null; + _pre_initialized_function_spec = spec; + _initialize_function_spec(); + } + + internal void _initialize_function_spec() + { + if(_pre_initialized_function_spec is null) + { + return; + } + Debug.Assert(_function_spec is null, "already initialized"); + var spec = _pre_initialized_function_spec; + //var args = spec.Fullargspec.DictValue.Fields["args"]; + // TODO(Rinne): self.structured_input_signature + + _function_spec = new FunctionSpec() + { + Fullargspec = spec.Fullargspec, + IsMethod = spec.IsMethod, + InputSignature = spec.InputSignature + }; } public override string ToString() diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index 40b61511..4c2d4c37 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -5,6 +5,8 @@ using System.Linq; using System.Text; using Tensorflow.Contexts; using Tensorflow.Graphs; +using Tensorflow.Operations; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -14,7 +16,10 @@ namespace Tensorflow.Functions public int _num_outputs; FuncGraph _func_graph; FunctionDef _definition; + Tensor[] _func_graph_outputs; public string Name => _func_graph.FuncName; + public DataType[] OutputTypes { get; protected set; } + public Shape[] OutputShapes { get; protected set; } public FunctionDef Definition { get @@ -36,27 +41,69 @@ namespace Tensorflow.Functions var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) .Select(x => x as Operation).ToArray(); var output_names = new string[0]; + OutputShapes = outputs.Select(x => x.shape).ToArray(); + OutputTypes = outputs.Select(x => x.dtype.as_datatype_enum()).ToArray(); _func_graph = new FuncGraph(graph, name, attrs); + _func_graph_outputs = new List(outputs).ToArray(); _func_graph.ToGraph(operations, inputs, outputs, output_names); } public Tensors Call(Tensors args) { + // TODO(Rinne): Add arg `CancellationManager`. + // TODO(Rinne): Check the arg length. + var function_call_options = tf.Context.FunctionCallOptions; + string config; + if (string.IsNullOrEmpty(function_call_options.config_proto_serialized())) + { + config = function_utils.get_disabled_rewriter_config(); + } + else + { + config = function_call_options.config_proto_serialized(); + } + // TODO(Rinne): executor_type + var executing_eagerly = tf.Context.executing_eagerly(); + var attrs = new object[] { "executor_type", "", "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() }; - var results = tf.Runner.TFE_Execute(tf.Context, + Tensor[] outputs; + if (executing_eagerly) + { + outputs = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, _func_graph.FuncName, args, attrs, _num_outputs); - - return results; + } + else + { + tf.GradientTape().stop_recording(); + outputs = functional_ops.partitioned_call(args, this, OutputTypes, + executing_eagerly, config, ""); + } + foreach(var (i, func_graph_output) in enumerate(_func_graph_outputs)) + { + handle_data_util.copy_handle_data(func_graph_output, outputs[i]); + } + if (executing_eagerly) + { + return outputs; + } + else + { + foreach(var (i, shape) in enumerate(OutputShapes)) + { + outputs[i].shape = shape; + } + return outputs; + } } public void AddToGraph(Graph g = null) diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index 45a13632..cfea3954 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -9,16 +9,46 @@ namespace Tensorflow #pragma warning disable CS0169 // The field 'Function._handle' is never used private IntPtr _handle; #pragma warning restore CS0169 // The field 'Function._handle' is never used - + + protected Func _function; + protected ConcreteFunction _concrete_variable_creation_fn; + protected bool _auto_graph; public string Name { get; set; } - public Function() + public Function(Func function, + string name, bool auto_graph = true) + { + _function = function; + Name = name; + _auto_graph = auto_graph; + } + + public virtual Tensors Apply(Tensors inputs) { + if (_run_functions_eagerly()) + { + return _function(inputs); + } + var result = _call(inputs); + return result; } - - public Function(string name) + + protected virtual Tensors _call(Tensors inputs) { - Name = name; + _initialize(); + + return _concrete_variable_creation_fn.CallFlat(inputs, + _concrete_variable_creation_fn.CapturedInputs); + } + + protected virtual bool _run_functions_eagerly() + { + return false; + } + + private void _initialize() + { + } } } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 9f216ff7..23889d44 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -15,11 +15,11 @@ namespace Tensorflow.Functions /// public abstract class TapeGradientFunctions { - string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; - string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; - string _FORWARD_PREFIX = "__forward_"; - string _BACKWARD_PREFIX = "__backward_"; - string _INFERENCE_PREFIX = "__inference_"; + protected string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; + protected string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; + protected string _FORWARD_PREFIX = "__forward_"; + protected string _BACKWARD_PREFIX = "__backward_"; + protected string _INFERENCE_PREFIX = "__inference_"; protected FuncGraph _func_graph; protected EagerDefinedFunction _forward; @@ -35,8 +35,9 @@ namespace Tensorflow.Functions _func_graph = func_graph; } - public EagerDefinedFunction Forward(Tensors inference_args) + public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) { + // TODO(Rinne): add input_tangents arg. return ForwardAndBackwardFunctions(inference_args); } @@ -45,8 +46,9 @@ namespace Tensorflow.Functions /// /// /// - public void Record(Tensors flat_outputs, Tensors inference_args) + public virtual void Record(Tensors flat_outputs, Tensors inference_args) { + // TODO(Rinne): add arg `input_tagents`. var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record, getBackwardFunction: backward_function); diff --git a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs index c39f2402..e92fa3a1 100644 --- a/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs +++ b/src/TensorFlowNET.Core/Functions/function_saved_model_utils.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Operations; using Tensorflow.Train; +using Tensorflow.Variables; using static Tensorflow.Binding; namespace Tensorflow.Functions diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs index df8b6d4e..a8769438 100644 --- a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs +++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs @@ -5,16 +5,13 @@ using Tensorflow.Graphs; namespace Tensorflow.Functions { - public class DelayedRewriteGradientFunctions + public class DelayedRewriteGradientFunctions: TapeGradientFunctions { - static readonly string _INFERENCE_PREFIX = "__inference_"; - static readonly string _BACKWARD_PREFIX = "__backward_"; - static readonly string _FORWARD_PREFIX = "__forward_"; - FuncGraph _func_graph; EagerDefinedFunction _inference_function; Dictionary _attrs; int _num_inference_outputs; public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) + :base(func_graph, false) { _func_graph= func_graph; _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), @@ -23,7 +20,7 @@ namespace Tensorflow.Functions _num_inference_outputs = _func_graph.Outputs.Length; } - public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null) + public override EagerDefinedFunction Forward(Tensors inference_args = null, Tensors input_tangents = null) { if(input_tangents is not null) { @@ -33,7 +30,23 @@ namespace Tensorflow.Functions return _inference_function; } - private static string _inference_name(string name) + public override void Record(Tensors flat_outputs, Tensors inference_args) + { + // TODO(Rinne): implement it. + throw new NotImplementedException(); + base.Record(flat_outputs, inference_args); + } + + //private (BackwardFunction, Tensors) _backward(Tensors outputs) + //{ + // Tensor[] backward_function(Tensor[] grads, long[] unneeded_gradients) + // { + // var call_op = outputs[0].op; + + // } + //} + + private string _inference_name(string name) { return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; } diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 40a83493..e6312c0d 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -25,6 +25,11 @@ namespace Tensorflow { public class gradients_util { + // Represents the output of TFE_Py_TapeSetPossibleGradientTypes. Real enums are + // unfortunately too slow to use here. + public static int POSSIBLE_GRADIENT_TYPES_NONE = 0; + public static int POSSIBLE_GRADIENT_TYPES_FIRST_ORDER = 1; + public static int POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER = 2; public static Tensor[] _GradientsHelper(Tensor[] ys, Tensor[] xs, Tensor[] grad_ys = null, diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index cf38d6b1..e583868e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -129,6 +129,7 @@ namespace Tensorflow protected Graph outer_graph; public Graph OuterGraph => outer_graph; public Dictionary Functions => _functions; + public SafeGraphHandle c_graph => _handle; public Graph() { diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 900db8ca..46a654e0 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -208,5 +208,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + [DllImport(TensorFlowLibName)] + public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data); } } diff --git a/src/TensorFlowNET.Core/Operations/functional_ops.cs b/src/TensorFlowNET.Core/Operations/functional_ops.cs index 908029f5..2d447207 100644 --- a/src/TensorFlowNET.Core/Operations/functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/functional_ops.cs @@ -14,10 +14,14 @@ limitations under the License. ******************************************************************************/ +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Framework; +using Tensorflow.Functions; +using Tensorflow.Operations; using Tensorflow.Util; using static Tensorflow.Binding; @@ -25,6 +29,72 @@ namespace Tensorflow { public class functional_ops { + public static Tensor[] partitioned_call(Tensors args, EagerDefinedFunction f, DataType[] tout, + bool executing_eagerly, string config, string executor_type) + { + if (tout is null) + { + throw new NotImplementedException(); + } + + if (config is null) + { + config = function_utils.get_disabled_rewriter_config(); + } + + if (executor_type is null) + { + executor_type = ""; + } + + if (executing_eagerly) + { + throw new NotImplementedException(); + } + + var converted_args = args.Select(x => ops.convert_to_tensor(x)).ToArray(); + AttrValue tin_attr = new() + { + List = new AttrValue.Types.ListValue() + }; + tin_attr.List.Type.AddRange(args.Select(x => x.dtype.as_datatype_enum())); + AttrValue tout_attr = new() + { + List = new AttrValue.Types.ListValue() + }; + tout_attr.List.Type.AddRange(tout); + AttrValue func_attr = new() + { + Func = new NameAttrList() + }; + func_attr.Func.Name = f.Name; + AttrValue executor_type_attr = new AttrValue() + { + S = tf.compat.as_bytes(executor_type) + }; + AttrValue config_proto = new AttrValue() + { + S = ByteString.CopyFromUtf8(executor_type) + }; + + var graph = ops.get_default_graph(); + f.AddToGraph(graph); + // TODO(Rinne): complete it with `f.stateful` + var op_name = "PartitionedCall"; + string xla_compile_attr = "_XlaMustCompile"; + Dictionary op_attrs = new(); + op_attrs["Tin"] = tin_attr; + op_attrs["Tout"] = tout_attr; + op_attrs["f"] = func_attr; + op_attrs["config_proto"] = config_proto; + op_attrs["executor_type"] = executor_type_attr; + // TODO(Rinne): deal with `f.definition`. + var op = graph.create_op(op_name, args, tout.Select(x => x.as_tf_dtype()).ToArray(), + name: op_name, attrs: op_attrs); + var outputs = op.outputs; + // TODO(Rinne): deal with `f.graph`. + return outputs; + } public static Tensor scan( Func fn, Tensor elems, diff --git a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs new file mode 100644 index 00000000..ce37ec7d --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Xml.Linq; +using Tensorflow.Contexts; +using Tensorflow.Eager; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + public class gen_functional_ops + { + public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, + string config = "", string config_proto = "", string executor_type = "", string name = null) + { + var ctx = tf.Context; + if (ctx.executing_eagerly()) + { + try + { + return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, + args, tout, f, config, config_proto, executor_type)); + } + catch (Exception) + { + + } + } + + if (config is null) + { + config = ""; + } + if (config_proto is null) + { + config_proto = ""; + } + if (executor_type is null) + { + executor_type = ""; + } + Dictionary kwargs = new(); + kwargs["args"] = args; + kwargs["Tout"] = tout; + kwargs["f"] = f; + kwargs["config"] = config; + kwargs["config_proto"] = config_proto; + kwargs["executor_type"] = executor_type; + var output = tf.OpDefLib._apply_op_helper("PartitionedCall", + name, kwargs); + var result = output.outputs; + if (execute.must_record_gradient()) + { + throw new NotImplementedException(); + } + return result; + } + + public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, + string config, string config_proto, string executor_type, string name, Context ctx) + { + // TODO(Rinne): implement it. + throw new NotImplementedException(); + if(config is null) + { + config = ""; + } + if(config_proto is null) + { + config_proto = ""; + } + if(executor_type is null) + { + executor_type = ""; + } + object[] attrs = new object[] + { + + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs index 6d4d8a19..ca690774 100644 --- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs +++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs @@ -1,7 +1,9 @@ -using System; +using Google.Protobuf; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Eager; +using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow.Operations { @@ -11,18 +13,31 @@ namespace Tensorflow.Operations { if(target_t.dtype == dtypes.resource || target_t.dtype == dtypes.variant) { - SafeTensorHandle handle_data; + HandleData handle_data; if(source_t is EagerTensor) { - handle_data = source_t.Handle; + handle_data = source_t.HandleData; } else { handle_data = ops.get_resource_handle_data(source_t); } - throw new NotImplementedException(); - //if(handle_data is not null && handle_data.) + if(handle_data is not null && handle_data.IsSet && handle_data.ShapeAndType is not null + && handle_data.ShapeAndType.Count > 0) + { + set_handle_data(target_t, handle_data); + } + } + } + + public static void set_handle_data(Tensor target_t, HandleData handle_data) + { + if(target_t is EagerTensor) + { + target_t.HandleData = handle_data; + return; } + c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); } } } diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 2b1d9a84..83ff50b1 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -39,7 +39,7 @@ namespace Tensorflow public static bool is_resource_variable(IVariableV1 var) { - return var is ResourceVariable; + return var is BaseResourceVariable; } public static bool is_resource_variable(Trackable var) @@ -231,5 +231,21 @@ namespace Tensorflow } } } + + public static void _maybe_set_handle_data(TF_DataType dtype, Tensor handle, Tensor tensor) + { + if(dtype == dtypes.variant) + { + var handle_data = get_eager_safe_handle_data(handle); + if(handle_data.IsSet && handle_data.ShapeAndType.Count > 1) + { + tensor.HandleData = new HandleData() + { + IsSet = true + }; + tensor.HandleData.ShapeAndType.AddRange(handle_data.ShapeAndType.Skip(1)); + } + } + } } } diff --git a/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs index f76bf2f0..7a601ed5 100644 --- a/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs +++ b/src/TensorFlowNET.Core/Protobuf/CppShapeInference.cs @@ -479,7 +479,7 @@ namespace Tensorflow { /// [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public pbc::RepeatedField ShapeAndType { - get { return shapeAndType_; } + get { return shapeAndType_; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index 32575213..3d056cae 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -277,15 +277,15 @@ namespace Tensorflow { get { return Descriptor; } } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public SavedObject() { + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject() { OnConstruction(); } partial void OnConstruction(); - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] - public SavedObject(SavedObject other) : this() { + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public SavedObject(SavedObject other) : this() { children_ = other.children_.Clone(); dependencies_ = other.dependencies_.Clone(); slotVariables_ = other.slotVariables_.Clone(); @@ -329,7 +329,9 @@ namespace Tensorflow { public const int ChildrenFieldNumber = 1; private static readonly pb::FieldCodec _repeated_children_codec = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); - private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); + private static readonly pb::FieldCodec _repeated_dependencies_codec + = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); + private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); /// /// Objects which this object depends on: named edges in the dependency @@ -501,7 +503,8 @@ namespace Tensorflow { return true; } if(!children_.Equals(other.children_)) return false; - if(!slotVariables_.Equals(other.slotVariables_)) return false; + if (!dependencies_.Equals(other.dependencies_)) return false; + if (!slotVariables_.Equals(other.slotVariables_)) return false; if (!object.Equals(UserObject, other.UserObject)) return false; if (!object.Equals(Asset, other.Asset)) return false; if (!object.Equals(Function, other.Function)) return false; @@ -519,6 +522,7 @@ namespace Tensorflow { public override int GetHashCode() { int hash = 1; hash ^= children_.GetHashCode(); + hash ^= dependencies_.GetHashCode(); hash ^= slotVariables_.GetHashCode(); if (kindCase_ == KindOneofCase.UserObject) hash ^= UserObject.GetHashCode(); if (kindCase_ == KindOneofCase.Asset) hash ^= Asset.GetHashCode(); @@ -544,6 +548,7 @@ namespace Tensorflow { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public void WriteTo(pb::CodedOutputStream output) { children_.WriteTo(output, _repeated_children_codec); + children_.WriteTo(output, _repeated_dependencies_codec); slotVariables_.WriteTo(output, _repeated_slotVariables_codec); if (kindCase_ == KindOneofCase.UserObject) { output.WriteRawTag(34); @@ -587,6 +592,7 @@ namespace Tensorflow { public int CalculateSize() { int size = 0; size += children_.CalculateSize(_repeated_children_codec); + size += children_.CalculateSize(_repeated_dependencies_codec); size += slotVariables_.CalculateSize(_repeated_slotVariables_codec); if (kindCase_ == KindOneofCase.UserObject) { size += 1 + pb::CodedOutputStream.ComputeMessageSize(UserObject); @@ -619,7 +625,7 @@ namespace Tensorflow { return size; } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + //[global::System.Diagnostics.DebuggerNonUserCodeAttribute] public void MergeFrom(SavedObject other) { if (other == null) { return; @@ -682,7 +688,7 @@ namespace Tensorflow { _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } - [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + //[global::System.Diagnostics.DebuggerNonUserCodeAttribute] public void MergeFrom(pb::CodedInputStream input) { uint tag; while ((tag = input.ReadTag()) != 0) { @@ -692,9 +698,10 @@ namespace Tensorflow { break; case 10: { children_.AddEntriesFrom(input, _repeated_children_codec); + dependencies_.AddRange(children_.Except(dependencies_)); break; } - case 26: { + case 26: { slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); break; } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 214b2777..6d226513 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -109,7 +109,12 @@ https://tensorflownet.readthedocs.io + + + + + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index ade00d5c..0bffbfba 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -87,6 +87,7 @@ namespace Tensorflow public object Tag { get; set; } protected new SafeTensorHandle _handle; public virtual SafeTensorHandle Handle => _handle; + public Tensorflow.CppShapeInferenceResult.Types.HandleData HandleData { get; internal set; } protected SafeEagerTensorHandle _eagerTensorHandle; /// diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs index 2fd0d1d8..f8c97975 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs @@ -14,18 +14,19 @@ limitations under the License. ******************************************************************************/ +using OneOf; using Tensorflow.Checkpoint; namespace Tensorflow { public class MySaveableObject { - protected Maybe _op; + protected OneOf _op; public Tensor op { get { - if(_op.TryGet(out var tensor)) + if(_op.TryPickT0(out var tensor, out var _)) { return tensor; } @@ -43,7 +44,7 @@ namespace Tensorflow { get { - if (_op.TryGet(out var v)) + if (_op.TryPickT1(out var v, out var _)) { return v; } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index 25697c6e..951d7d00 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -25,11 +25,32 @@ namespace Tensorflow.Training.Saving.SavedModel /// /// /// - public static ConcreteFunction recreate_function(SavedFunction saved_concrete_function, + public static Function recreate_function(SavedFunction saved_function, IDictionary concrete_functions) { - var function_spec = _deserialize_function_spec_as_nonmethod(saved_concrete_function.FunctionSpec); - return null; + var function_spec = _deserialize_function_spec_as_nonmethod(saved_function.FunctionSpec); + + List concrete_function_objects = new(); + foreach(var concrete_function_name in saved_function.ConcreteFunctions) + { + concrete_function_objects.Add(concrete_functions[concrete_function_name]); + } + foreach(var cf in concrete_function_objects) + { + cf._set_function_spec(function_spec); + } + + foreach(var function_name in saved_function.ConcreteFunctions) + { + var function = concrete_functions[function_name]; + if(_concrete_function_callable_with(function, null, false)) + { + return new RestoredFunction(null, function, "function_from_deserialization"); + } + } + return new RestoredFunction(x => x, new ConcreteFunction(x => x, TF_DataType.TF_FLOAT), "function_return_itself"); + //throw new ValueError("Unexpected runtime behavior, please submit an issue to " + + // "https://github.com/SciSharp/TensorFlow.NET/issues"); } public static Dictionary load_function_def_library(FunctionDefLibrary library, @@ -385,5 +406,31 @@ namespace Tensorflow.Training.Saving.SavedModel JitCompile = function_spec_proto.JitCompile }; } + + private static Tensors _call_concrete_function(ConcreteFunction function, Tensors inputs) + { + // TODO(Rinne): var expected_structure = function.func_graph.structured_input_signature + return function.CallFlat(inputs, function.CapturedInputs); + } + + private static bool _concrete_function_callable_with(ConcreteFunction function, Tensors inputs, bool allow_conversion) + { + // TODO(Rinne): revise it. + return true; + } + } + + public class RestoredFunction : Function + { + public RestoredFunction(Func function, ConcreteFunction concrete_function, + string name, bool auto_graph = true): base(function, name, auto_graph) + { + _concrete_variable_creation_fn = concrete_function; + } + + protected override bool _run_functions_eagerly() + { + return false; + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 3505da93..53ac9e2a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -14,6 +14,7 @@ using Tensorflow.Variables; using Tensorflow.Functions; using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Trackables; +using OneOf; namespace Tensorflow { @@ -44,6 +45,8 @@ namespace Tensorflow _asset_file_def = meta_graph.AssetFileDef; _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); _proto = object_graph_proto; + // Debug(Rinne) + var temp = _proto.ToString(); _export_dir = export_dir; // TODO: `this._concrete_functions` and `this._restored_concrete_functions` _concrete_functions = function_deserialization.load_function_def_library( @@ -259,9 +262,9 @@ namespace Tensorflow /// /// /// - private Dictionary, int> _get_node_dependencies(SavedObject proto) + private Dictionary, int> _get_node_dependencies(SavedObject proto) { - Dictionary, int> dependencies = new(); + Dictionary, int> dependencies = new(); foreach(var refer in proto.Dependencies) { dependencies[refer.LocalName] = refer.NodeId; @@ -375,11 +378,6 @@ namespace Tensorflow // Re-create everything. foreach (var (node_id, proto) in _iter_all_nodes()) { - if(node_id == 45) - { - // TODelete - Console.WriteLine(); - } if (nodes.ContainsKey(node_id)) { continue; @@ -474,7 +472,7 @@ namespace Tensorflow } } - private void _setup_function_captures(string concrete_function_name, IDictionary, Trackable> nodes) + private void _setup_function_captures(string concrete_function_name, IDictionary, Trackable> nodes) { if (_restored_concrete_functions.Contains(concrete_function_name)) { @@ -509,6 +507,11 @@ namespace Tensorflow /// private void _add_object_graph_edges(SavedObject proto, int node_id) { + // Debug(Rinne) + if(node_id == 1) + { + Console.WriteLine(); + } var obj = _nodes[node_id]; var setter = _node_setters[node_id]; @@ -549,8 +552,13 @@ namespace Tensorflow private (Trackable, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) { // skip the registered classes. + if(node_id == 16) + { + // Debug(Rinne) + Console.WriteLine(); + } - Dictionary, Trackable> dependencies = new(); + Dictionary, Trackable> dependencies = new(); foreach(var item in _get_node_dependencies(proto)) { dependencies[item.Key] = nodes[item.Value]; @@ -571,7 +579,7 @@ namespace Tensorflow /// /// /// - private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, Trackable> dependencies) + private (Trackable, Action) _recreate_default(SavedObject proto, int node_id, IDictionary, Trackable> dependencies) { return proto.KindCase switch { @@ -637,10 +645,10 @@ namespace Tensorflow } } - private (ConcreteFunction, Action) _recreate_function(SavedFunction proto, - Dictionary, Trackable> dependencies) + private (Function, Action) _recreate_function(SavedFunction proto, + Dictionary, Trackable> dependencies) { - var fn = function_deserialization.recreate_function(proto, null); + var fn = function_deserialization.recreate_function(proto, _concrete_functions); foreach (var name in proto.ConcreteFunctions) { _setup_function_captures(name, dependencies); @@ -649,7 +657,7 @@ namespace Tensorflow } private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, - IDictionary, Trackable> dependencies) + IDictionary, Trackable> dependencies) { var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); _setup_function_captures(proto.ConcreteFunctionName, dependencies); diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 20831122..5456669e 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using OneOf; using System; using System.Collections.Generic; using System.Diagnostics; @@ -174,7 +175,7 @@ namespace Tensorflow full_name = name + "_" + attr; } var op = factory(full_name); - if(op.TryGet(out var variable)) + if(op.TryPickT0(out var variable, out var saveable)) { foreach (var v in saveable_objects_for_op(variable as Trackable, variable.Name)) { @@ -183,7 +184,6 @@ namespace Tensorflow } else { - var saveable = op.GetValue(); foreach (var v in saveable_objects_for_op(saveable, saveable.name)) { yield return v; @@ -252,11 +252,11 @@ namespace Tensorflow return names_to_saveables; } - public static IDictionary>> saveable_objects_from_trackable(Trackable obj) + public static IDictionary>> saveable_objects_from_trackable(Trackable obj) { // skip the process of type `PythonState` - Maybe create_saveable(string name = "") + OneOf create_saveable(string name = "") { // skip the case that `obj._serialize_to_tensors` is `ConcreteFunction`. var tensor_dict = obj.serialize_to_tensors(); @@ -272,14 +272,14 @@ namespace Tensorflow string spec_name = name + TrackableUtils.escape_local_name(tensor_name); IDictionary internal_dict; - if (maybe_tensor.TryGet(out var tensor)) + if (maybe_tensor.TryPickT0(out var tensor, out var dic)) { internal_dict = new Dictionary(); internal_dict[""] = tensor; } else { - internal_dict = maybe_tensor.GetValue>(); + internal_dict = dic; } foreach (var item in internal_dict) @@ -292,7 +292,7 @@ namespace Tensorflow if (trackable_has_serialize_to_tensor(obj)) { - Dictionary>> res = new(); + Dictionary>> res = new(); res[TrackableUtils.SERIALIZE_TO_TENSORS_NAME] = create_saveable; return res; } @@ -316,9 +316,9 @@ namespace Tensorflow /// Converts a list of SaveableObjects to a tensor dictionary. /// /// - public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) { - Dictionary>> tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { foreach (var spec in saveable.specs) @@ -328,7 +328,7 @@ namespace Tensorflow var slice_spec = convert_to_string(spec.slice_spec); if (!string.IsNullOrEmpty(slice_spec)) { - tensor_dict.SetDefault(name, new Dictionary()).GetValue>()[slice_spec] = spec.tensor; + tensor_dict.SetDefault(name, new Dictionary()).AsT1[slice_spec] = spec.tensor; } else { @@ -343,7 +343,7 @@ namespace Tensorflow /// Generates `Trackable._restore_from_tensors` from SaveableObjects. /// /// - public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) + public static Func>>, IDictionary> saveable_object_to_restore_fn(IList saveables) { return (restored_tensors) => { @@ -359,14 +359,14 @@ namespace Tensorflow var maybe_tensor = restored_tensors[name]; IDictionary dict; - if(maybe_tensor.TryGet(out var tensor)) + if(maybe_tensor.TryPickT0(out var tensor, out var dic)) { dict = new Dictionary(); dict[""] = tensor; } else { - dict = maybe_tensor.GetValue>(); + dict = dic; } saveable_restored_tensors.Add(dict[slice_spec]); } @@ -381,18 +381,18 @@ namespace Tensorflow /// /// /// - public static IDictionary>> recreate_saveable_objects( + public static IDictionary>> recreate_saveable_objects( IDictionary saveable_fn_by_name, IEnumerable? temp_session) { if (saveable_fn_by_name.Count > 0) { throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); } - var res = new Dictionary>>(); + var res = new Dictionary>>(); return res; } - public static Maybe create_saveable_object(string name, string key, Func> factory, + public static OneOf create_saveable_object(string name, string key, Func> factory, bool call_with_mapped_captures = false) { return factory(key); @@ -412,7 +412,7 @@ namespace Tensorflow public object Obj => _obj; public IList mySaveables=> _saveables; - public override IDictionary>> serialize_to_tensors() + public override IDictionary>> serialize_to_tensors() { return saveable_object_util.saveable_object_to_tensor_dict(_saveables); } @@ -422,7 +422,7 @@ namespace Tensorflow /// /// /// - public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + public override IDictionary _restore_from_tensors(IDictionary>> restored_tensors) { List expected_keys = new(); foreach(var saveable in _saveables) diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 7c86a580..b64b5ebc 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using OneOf; using System; using System.Collections.Generic; using System.Diagnostics; @@ -43,8 +44,8 @@ namespace Tensorflow.Train protected IList _unconditional_checkpoint_dependencies; protected Dictionary> _unconditional_deferred_dependencies; - protected IDictionary>> _self_saveable_object_factories = - new Dictionary>>(); + protected IDictionary>> _self_saveable_object_factories = + new Dictionary>>(); private bool _manual_tracking = true; private static Trackable _none = new AutoTrackable(); @@ -73,7 +74,7 @@ namespace Tensorflow.Train public IDictionary UnconditionalDependencyNames { get => _unconditional_dependency_names; } public IList CheckpointDependencies { get => UnconditionalCheckpointDependencies; } public Dictionary> DeferredDependencies => _unconditional_deferred_dependencies; - public IDictionary>> SelfSaveableObjectFactories + public IDictionary>> SelfSaveableObjectFactories { get { @@ -249,9 +250,9 @@ namespace Tensorflow.Train return self_tensor_map.Keys.ToList(); } - public virtual IDictionary>> gather_saveables_for_checkpoint() + public virtual IDictionary>> gather_saveables_for_checkpoint() { - Maybe create_saveable(string name = "") + OneOf create_saveable(string name = "") { throw new NotImplementedException(); //return new TrackableSaveable(this, null, name, null, null); @@ -259,7 +260,7 @@ namespace Tensorflow.Train if (saveable_object_util.trackable_has_serialize_to_tensor(this)) { // TODO: complete the implementation (need to complete the class `saveable_object_util.TrackableSaveable`). - Dictionary>> res = new(); + Dictionary>> res = new(); res[""] = create_saveable; return res; } @@ -278,12 +279,12 @@ namespace Tensorflow.Train /// /// /// - public virtual IDictionary>> serialize_to_tensors() + public virtual IDictionary>> serialize_to_tensors() { throw new NotImplementedException(); } - public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) + public virtual IDictionary _restore_from_tensors(IDictionary>> restored_tensors) { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Util/function_utils.cs b/src/TensorFlowNET.Core/Util/function_utils.cs new file mode 100644 index 00000000..2944e88e --- /dev/null +++ b/src/TensorFlowNET.Core/Util/function_utils.cs @@ -0,0 +1,23 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Util +{ + internal static class function_utils + { + private static string _rewriter_config_optimizer_disabled; + public static string get_disabled_rewriter_config() + { + if(_rewriter_config_optimizer_disabled is null) + { + var config = new ConfigProto(); + var rewriter_config = config.GraphOptions.RewriteOptions; + rewriter_config.DisableMetaOptimizer = true; + _rewriter_config_optimizer_disabled = config.ToString(); + } + return _rewriter_config_optimizer_disabled; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 9427b87f..cc5ee542 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.Diagnostics; using Tensorflow.Checkpoint; using Tensorflow.Training.Saving.SavedModel; +using OneOf; namespace Tensorflow { @@ -155,7 +156,7 @@ namespace Tensorflow { variable_accessed(this); var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); - // _maybe_set_handle_data(_dtype, _handle, result); + resource_variable_ops._maybe_set_handle_data(_dtype, handle, result); // have to set shape when converting to substituent placeholder if (result.shape.ndim == -1) @@ -293,9 +294,9 @@ namespace Tensorflow resource_variable_ops.write_object_proto_for_resource_variable(this, proto, options); } - public override IDictionary>> gather_saveables_for_checkpoint() + public override IDictionary>> gather_saveables_for_checkpoint() { - var res = new Dictionary>>(); + var res = new Dictionary>>(); res[Trackable.Constants.VARIABLE_VALUE_KEY] = x => this; return res; } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 3b1f1e96..7d0ac4f8 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -124,7 +124,9 @@ namespace Tensorflow initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; ops.colocate_with(initializer_op); - + tf.device(handle.Device); + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); _graph_element = gen_array_ops.identity(handle, name = "read"); ops.add_to_collections(collections, this); _dtype = handle.dtype; @@ -141,6 +143,12 @@ namespace Tensorflow gen_resource_variable_ops.assign_variable_op(handle, _initial_value); initializer_op = null; _graph_element = null; + if (!string.IsNullOrEmpty(caching_device)) + { + tf.device(caching_device); + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + } _dtype = _initial_value.dtype.as_base_dtype(); // initial_value = _in_graph_mode ? initial_value : null; } diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs index 6c034995..8ee3c62b 100644 --- a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -9,7 +9,7 @@ namespace Tensorflow.Variables /// /// A variable with no initializer. /// - public sealed class UninitializedVariable: BaseResourceVariable + public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1 { // TODO: complete the arg list. public UninitializedVariable( @@ -23,6 +23,7 @@ namespace Tensorflow.Variables { string unique_id = ""; string handle_name = ""; + Tensor created_handle = null; tf_with(ops.init_scope(), (x) => { _in_graph_mode = !tf.Context.executing_eagerly(); @@ -40,7 +41,7 @@ namespace Tensorflow.Variables unique_id = $"{handle_name}-{ops.uid()}"; shared_name = null; } - var handle = resource_variable_ops.variable_handle_from_shape_and_dtype( + created_handle = resource_variable_ops.variable_handle_from_shape_and_dtype( shape, dtype, shared_name, name, _in_graph_mode, extra_handle_data); // skip the assignment of `handle._parent_trackable` because of lack of API. // skip the assignment of `handle._name` and `handle._unique_id` because of accessability. @@ -51,7 +52,7 @@ namespace Tensorflow.Variables { tf.device(handle.Device); var value = gen_resource_variable_ops.read_variable_op(handle, dtype); - // _maybe_set_handle_data(dtype, handle, value) + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); _graph_element = value; }); ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); @@ -64,7 +65,7 @@ namespace Tensorflow.Variables }); _shape = shape; _dtype = dtype; - base.__init__(trainable, handle, unique_id: unique_id, handle_name: handle_name); + base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name); } } } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 59081ecf..bce64198 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -26,6 +26,7 @@ using Tensorflow.Eager; using Tensorflow.Graphs; using Tensorflow.Util; using static Tensorflow.Binding; +using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow { @@ -572,9 +573,12 @@ namespace Tensorflow return get_default_graph().building_function; } - public static SafeTensorHandle get_resource_handle_data(Tensor graph_op) + public static HandleData get_resource_handle_data(Tensor graph_op) { - throw new NotImplementedException(); + // This implementation hasn't been checked for some reasons. + // If it throws an exception in the future, please check it. + var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); + return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); } public static void dismantle_graph(Graph graph) diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 0f809cba..99ee66c2 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -40,7 +40,7 @@ namespace Tensorflow.Keras.Engine /// /// Arguments initialize layer. /// - LayerArgs args; + internal LayerArgs args; /// /// Indicates whether `build` needs to be called upon layer call, to create @@ -147,7 +147,7 @@ namespace Tensorflow.Keras.Engine List outboundNodes; public List OutboundNodes => outboundNodes; - public JObject SerializedAttributes { get; set; } + public Dictionary SerializedAttributes { get; set; } ThreadLocal callContext = new ThreadLocal(); public CallContext CallContext => callContext.Value; diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 898eb18f..90612c07 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Saving { public class KerasObjectLoader { - internal static readonly IDictionary PUBLIC_ATTRIBUTES = new CommonEndPoints().CheckpointableObjects; + internal static readonly IDictionary PUBLIC_ATTRIBUTES; private SavedMetadata _metadata; private SavedObjectGraph _proto; private Dictionary _node_paths = new Dictionary(); @@ -39,7 +39,13 @@ namespace Tensorflow.Keras.Saving static KerasObjectLoader() { - PUBLIC_ATTRIBUTES[Keras.Saving.SavedModel.Constants.KERAS_ATTR] = null; + var endPoints = new CommonEndPoints(); + PUBLIC_ATTRIBUTES = new Dictionary(); + foreach (var key in endPoints._all_checkpointable_objects.Concat(endPoints._all_functions)) + { + PUBLIC_ATTRIBUTES[key] = null; + } + PUBLIC_ATTRIBUTES[SavedModel.Constants.KERAS_ATTR] = null; } public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def) @@ -125,8 +131,14 @@ namespace Tensorflow.Keras.Saving continue; } - // TODO: deal with `RevivedLayer` and `RevivedInputLayer`. - layers_revived_from_config.Add(node as Layer); + if(node is RevivedLayer or RevivedInputLayer) + { + layers_revived_from_saved_model.Add(node as Layer); + } + else + { + layers_revived_from_config.Add(node as Layer); + } } _finalize_saved_model_layers(layers_revived_from_saved_model); @@ -171,10 +183,13 @@ namespace Tensorflow.Keras.Saving // TODO(Rinne): implement it } } - - // `model.__init__(layers, config["name"])` - s.InitLayers(layers); - s.Name = config["name"].ToObject(); + + // `model.__init__(layers, config["name"])`InitLayers(layers); + s = new Sequential(new SequentialArgs(){ + Layers = layers.Select(x => x as ILayer).ToList(), + Name = config["name"].ToObject() + }); + //s.Name = config["name"].ToObject(); if(s.input is null || s.input.Length == 0) { var first_layer = _get_child_layer_node_ids(model_id)[0]; @@ -205,7 +220,12 @@ namespace Tensorflow.Keras.Saving private void _set_network_attributes_from_metadata(Model revived_object) { - // TODO: implement it. + var metadata = revived_object.SerializedAttributes["matadata"] as JObject; + if (metadata.ContainsKey("dtype")) + { + // TODO(Rinne): set_dtype_policy. + } + revived_object.args.Trainable = metadata["trainable"].Value(); } /// @@ -330,7 +350,7 @@ namespace Tensorflow.Keras.Saving private (Trackable, Action) _revive_from_config(string identifier, KerasMetaData metadata, int node_id) { Trackable obj; - if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) + if(identifier == SavedModel.Constants.METRIC_IDENTIFIER) { // TODO(Rinne): implement it. return (null, null); @@ -429,25 +449,26 @@ namespace Tensorflow.Keras.Saving return obj; } - private void _revive_setter(object layer, object name, object value) + private void _revive_setter(object obj, object name, object value) { Debug.Assert(name is string); - Debug.Assert(layer is Layer); + Debug.Assert(obj is Layer); + Layer layer = (Layer)obj; if(PUBLIC_ATTRIBUTES.ContainsKey(name as string)) { if(value is Trackable) { - (layer as Layer)._track_trackable(value as Trackable, name as string); + layer._track_trackable(value as Trackable, name as string); } - if((layer as Layer).SerializedAttributes is null) + if(layer.SerializedAttributes is null) { - (layer as Layer).SerializedAttributes = new JObject(); + layer.SerializedAttributes = new Dictionary(); } - (layer as Layer).SerializedAttributes[name as string] = JToken.FromObject(value); + layer.SerializedAttributes[name as string] = value; } - else if(layer is Functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) + else if(layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) { - (layer as Functional)._track_trackable(value as Trackable, name as string, overwrite: true); + functional._track_trackable(value as Trackable, name as string, overwrite: true); } else { @@ -521,7 +542,7 @@ namespace Tensorflow.Keras.Saving } var metric_list_node_id = _search_for_child_node(node_id, new string[] { - Keras.Saving.SavedModel.Constants.KERAS_ATTR, "layer_metrics" + SavedModel.Constants.KERAS_ATTR, "layer_metrics" }); if(metric_list_node_id is not null && obj is Model model && model.metrics is not null) { @@ -547,7 +568,7 @@ namespace Tensorflow.Keras.Saving // skip the check for registered identifier Action setter; - if (Keras.Saving.SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) + if (SavedModel.Constants.KERAS_OBJECT_IDENTIFIERS.Contains(obj_child.ObjectIdentifier)) { setter = _revive_setter; } @@ -659,7 +680,23 @@ namespace Tensorflow.Keras.Saving private void _maybe_add_serialized_attributes(Layer layer, KerasMetaData metadata) { - // TODO: deal with `RevivedLayer` + if(layer.SerializedAttributes is null || layer.SerializedAttributes.Count == 0) + { + layer.SerializedAttributes = new Dictionary(); + layer.SerializedAttributes["metadata"] = metadata; + } + } + + private static object _get_keras_attr(Layer layer) + { + if((layer.SerializedAttributes ?? new Dictionary()).TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) + { + return value; + } + else + { + return null; + } } /// diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs index 4dc56130..d2c4a55a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs @@ -24,17 +24,22 @@ namespace Tensorflow.Keras.Saving.SavedModel } } - public static void _revive_setter(object layer, object name, object value) + public static void _revive_setter(object obj, object name, object value) { Debug.Assert(name is string); - Debug.Assert(layer is Layer); + Debug.Assert(obj is Layer); + Layer layer = (Layer)obj; if (KerasObjectLoader.PUBLIC_ATTRIBUTES.ContainsKey(name as string)) { if (value is Trackable trackable) { - (layer as Layer)._track_trackable(trackable, name as string); + layer._track_trackable(trackable, name as string); } - (layer as Layer).SerializedAttributes[name] = JToken.FromObject(value); + if (layer.SerializedAttributes is null) + { + layer.SerializedAttributes = new Dictionary(); + } + layer.SerializedAttributes[name as string] = value; } else if (layer is Functional functional && Regex.Match(name as string, @"^layer(_with_weights)?-[\d+]").Success) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs new file mode 100644 index 00000000..639d3aa0 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Saving.SavedModel +{ + public class RevivedInputLayer: Layer + { + private RevivedInputLayer(): base(null) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs index cb375c9c..4df6613f 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs @@ -55,6 +55,21 @@ namespace Tensorflow.Keras.Saving.SavedModel private RevivedConfig _config = null; + public object keras_api + { + get + { + if (SerializedAttributes.TryGetValue(SavedModel.Constants.KERAS_ATTR, out var value)) + { + return value; + } + else + { + return null; + } + } + } + public RevivedLayer(LayerArgs args): base(args) { @@ -69,5 +84,17 @@ namespace Tensorflow.Keras.Saving.SavedModel { return _config; } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function) + { + return base.Call(inputs, state, training); + } + else + { + return (func as Function).Apply(inputs); + } + } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index ac194c00..db3b782e 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -19,8 +19,8 @@ namespace Tensorflow.Keras.Saving.SavedModel protected IDictionary _object_dict; protected IDictionary _function_dict; protected AutoTrackable _keras_trackable; - protected HashSet _all_functions; - protected HashSet _all_checkpointable_objects; + internal HashSet _all_functions; + internal HashSet _all_checkpointable_objects; private SerializedAttributes() { @@ -197,19 +197,15 @@ namespace Tensorflow.Keras.Saving.SavedModel public class CommonEndPoints: SerializedAttributes { public CommonEndPoints(IEnumerable checkpointable_objects, IEnumerable functions) : - //base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), - // functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) - base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables"}), - functions.Concat(new string[] { })) + base(checkpointable_objects.Concat(new string[] { "variables", "trainable_variables", "regularization_losses" }), + functions.Concat(new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" })) { } public CommonEndPoints() : - //base(new string[] { "variables", "trainable_variables", "regularization_losses" }, - // new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) - base(new string[] { "variables", "trainable_variables"}, - new string[] {}) + base(new string[] { "variables", "trainable_variables", "regularization_losses" }, + new string[] { "__call__", "call_and_return_all_conditional_losses", "_default_save_signature" }) { }