From 3943375b67a437ee8e8d763296c471d8c04ace80 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Tue, 4 Apr 2023 01:24:01 +0800 Subject: [PATCH] Support loading weights for customized layer. --- src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs | 25 +- .../Checkpoint/checkpoint.cs | 10 +- .../Checkpoint/functional_saver.cs | 31 +- src/TensorFlowNET.Core/Checkpoint/restore.cs | 3 +- src/TensorFlowNET.Core/Contexts/Context.cs | 5 + .../Framework/Models/DenseSpec.cs | 7 +- src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 2 +- .../Operations/handle_data_util.cs | 15 +- .../Operations/resource_variable_ops.cs | 96 +++-- .../Protobuf/SavedObjectGraph.cs | 8 +- .../Training/AutoTrackable.cs | 19 + .../Saving/ResourceVariableSaveable.cs | 24 +- .../Training/Saving/SaveSpec.cs | 47 ++- .../Saving/SavedModel/RevivedTypes.cs | 33 +- .../Training/Saving/SavedModel/loader.cs | 27 +- .../Saving/saveable_object_util.py.cs | 83 ++-- src/TensorFlowNET.Core/Training/Trackable.cs | 68 +++- .../Training/data_structures.cs | 371 ++++++++++++++++-- src/TensorFlowNET.Core/Util/nest.py.cs | 8 + .../Variables/ResourceVariable.cs | 2 +- src/TensorFlowNET.Keras/BackendImpl.cs | 8 + src/TensorFlowNET.Keras/Engine/Layer.cs | 57 +++ src/TensorFlowNET.Keras/Engine/Model.cs | 53 ++- .../Saving/KerasObjectLoader.cs | 51 ++- .../Saving/SavedModel/ReviveUtils.cs | 14 +- .../Saving/SavedModel/Save.cs | 15 +- .../Utils/base_layer_utils.cs | 8 + .../Utils/compile_utils.cs | 22 ++ src/TensorFlowNET.Keras/Utils/tf_utils.cs | 25 ++ .../SaveModel/SequentialModelLoad.cs | 3 + 30 files changed, 942 insertions(+), 198 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Utils/compile_utils.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs index 8b8cbf61..84e5f75c 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Checkpoint ); public static class SaveUtil { - public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + public static (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) serialize_graph_view(ObjectGraphView graph_view, IDictionary? object_map = null, bool call_with_mapped_captures = false, object? cache = null) { var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map); @@ -119,16 +119,16 @@ namespace Tensorflow.Checkpoint /// /// /// - private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, + private static IDictionary>>> get_and_write_tensors_to_serialize(IList tensor_trackables, IDictionary node_ids, bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto) { - Dictionary>>> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach(var td in tensor_trackables) { // TODO: deal with cache. var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? ""; Trackable trackable = null; - IDictionary>> tensor_dict; + IDictionary>> tensor_dict; if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0) { (trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto); @@ -150,12 +150,12 @@ namespace Tensorflow.Checkpoint return serialized_tensors; } - private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) + private static IDictionary>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { var trackable = trackable_data.object_to_save; // TODO: complete it. Note that actually `call_with_mapped_captures` is of function type. - IDictionary>> ret_tensor_dict; + IDictionary>> ret_tensor_dict; if (call_with_mapped_captures) { throw new NotImplementedException(); @@ -165,8 +165,7 @@ namespace Tensorflow.Checkpoint ret_tensor_dict = trackable.serialize_to_tensors(); } - // TODO: deal with the type `SaveSpce` (currently it will never be it). - Dictionary>> tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach(var pair in ret_tensor_dict) { var local_name = TrackableUtils.escape_local_name(pair.Key); @@ -175,10 +174,12 @@ namespace Tensorflow.Checkpoint tensor_dict[checkpoint_key] = maybe_tensor; - if(maybe_tensor.IsTypeOrDeriveFrom()) + foreach(var key in maybe_tensor.Keys) { - throw new NotImplementedException(); - //((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name; + if (maybe_tensor[key].IsTypeOrDeriveFrom()) + { + maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name; + } } if(object_graph_proto is not null) @@ -202,7 +203,7 @@ namespace Tensorflow.Checkpoint /// /// /// - private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, + private static (Trackable, IDictionary>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary node_ids, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto) { Dictionary object_names = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 445fd685..c736c164 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -45,12 +45,12 @@ public class TrackableSaver _graph_view = graph_view; // TODO: cache when not executing eagerly. - // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder`, + // including `_cache`, `_file_prefix_feed_tensor`, `_file_prefix_placeholder` // `_object_graph_feed_tensor`, `_object_map`, `_restore_op_cache`, `_saveables_cache` } - private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) + private (IDictionary>>>, IDictionary, IDictionary>, TrackableObjectGraph) gather_serialized_tensors(Tensor? object_graph_tensor = null) { var (serialized_tensors, feed_additions, registered_savers, graph_proto) = SaveUtil.serialize_graph_view(_graph_view, _object_map, cache:_cache); @@ -69,9 +69,10 @@ public class TrackableSaver Debug.Assert(!serialized_tensors.ContainsKey(Trackable.None) || !serialized_tensors[Trackable.None].ContainsKey(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); if (!serialized_tensors.ContainsKey(Trackable.None)) { - serialized_tensors[Trackable.None] = new Dictionary>>(); + serialized_tensors[Trackable.None] = new Dictionary>>(); } - serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = object_graph_tensor; + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY] = new Dictionary>(); + serialized_tensors[Trackable.None][Trackable.Constants.OBJECT_GRAPH_PROTO_KEY].Add(saveable_object_util.NO_SLICE_SPEC_KEY, object_graph_tensor); return (serialized_tensors, feed_additions, registered_savers, graph_proto); } @@ -387,6 +388,7 @@ public class CheckpointRestoreCoordinator /// public List AllTrackables => _all_trackables; public HashSet MatchedProtoIds => _matched_proto_ids; + // TODO(Rinne): change to weak ref. public Dictionary ObjectByProtoId => _object_by_proto_id; public int RestoreUid => _restore_uid; public TrackableObjectGraph ObjectGraphProto => _object_graph_proto; diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 3b49fa8d..c383c219 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -160,12 +160,12 @@ namespace Tensorflow.Checkpoint /// A dictionary mapping `Trackable` to a tensor dict, which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. /// /// - public MultiDeviceSaver(IDictionary>>> serialized_tensors, + public MultiDeviceSaver(IDictionary>>> serialized_tensors, IDictionary>? registered_savers = null, bool call_with_mapped_capture = false) { _keys_to_restore_fn = new Dictionary<(string, string), RestoreFunc>(); _restore_fn_to_keys = new Dictionary>(); - Dictionary>> tensors_by_device= new(); + Dictionary>>> tensors_by_device= new(); foreach(var pair in serialized_tensors) { @@ -191,16 +191,7 @@ namespace Tensorflow.Checkpoint foreach(var item in tensor_dict) { var checkpoint_key = item.Key; - IDictionary spec_to_tensor; - if(item.Value.TryPickT0(out var t, out var dic)) - { - spec_to_tensor = new Dictionary(); - spec_to_tensor[""] = t; - } - else - { - spec_to_tensor = dic; - } + var spec_to_tensor = item.Value; foreach(var spec in spec_to_tensor) { @@ -216,11 +207,19 @@ namespace Tensorflow.Checkpoint _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); // skip the process of device name because lack of API. - var host_device = tensor.Device; - var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>()); + string host_device; + if (tensor.IsT0) + { + host_device = tensor.AsT0.Device; + } + else + { + host_device = tensor.AsT1.device; + } + var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>>()); if (!internal_dict.ContainsKey(checkpoint_key)) { - internal_dict[checkpoint_key] = new Dictionary(); + internal_dict[checkpoint_key] = new Dictionary>(); } internal_dict[checkpoint_key][slice_spec] = tensor; } @@ -425,7 +424,7 @@ namespace Tensorflow.Checkpoint public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? registered_savers = null, bool call_with_mapped_captures = false) { - Dictionary>>> serialized_tensors = new(); + Dictionary>>> serialized_tensors = new(); foreach (var saveable in saveables) { var trackable = new SaveableCompatibilityConverter(saveable, new List() { saveable }); diff --git a/src/TensorFlowNET.Core/Checkpoint/restore.cs b/src/TensorFlowNET.Core/Checkpoint/restore.cs index e2770487..0e1a300e 100644 --- a/src/TensorFlowNET.Core/Checkpoint/restore.cs +++ b/src/TensorFlowNET.Core/Checkpoint/restore.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Security; using System.Text; using Tensorflow.Train; using Tensorflow.Training; @@ -50,7 +51,7 @@ public class CheckpointPosition { _checkpoint.AllTrackables.Add(trackable); _checkpoint.MatchedProtoIds.Add(_proto_id); - if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment)) + if(_checkpoint.ObjectByProtoId.TryGetValue(_proto_id, out var current_assignment) && current_assignment is not null) { // skip the `logging.warning`. return false; diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index e1cce1b0..deb67920 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -120,6 +120,11 @@ namespace Tensorflow.Contexts name : "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + public string anonymous_name() + { + return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + } + public void graph_mode(bool isFunc = false) => context_switches.Push(false, isFunc); diff --git a/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs b/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs index 1af29e22..5a89b90e 100644 --- a/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs +++ b/src/TensorFlowNET.Core/Framework/Models/DenseSpec.cs @@ -6,8 +6,11 @@ public class DenseSpec : TypeSpec { protected Shape _shape; - public Shape shape => _shape; - + public Shape shape + { + get { return _shape; } + set { _shape = value; } + } protected TF_DataType _dtype; public TF_DataType dtype => _dtype; diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 6221354f..e0c58966 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -311,7 +311,7 @@ namespace Tensorflow /// const TF_DataType* /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output, + public static extern void TF_GraphSetOutputHandleShapesAndTypes(SafeGraphHandle graph, TF_Output output, int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types, SafeStatusHandle status); diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs index ca690774..5d5fbebb 100644 --- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs +++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs @@ -30,6 +30,18 @@ namespace Tensorflow.Operations } } + public static HandleData create_handle_data(Shape shape, TF_DataType dtype) + { + HandleData handle_data = new(); + handle_data.IsSet = true; + handle_data.ShapeAndType.Add(new HandleShapeAndType() + { + Shape = shape.as_proto(), + Dtype = dtype.as_datatype_enum() + }); + return handle_data; + } + public static void set_handle_data(Tensor target_t, HandleData handle_data) { if(target_t is EagerTensor) @@ -37,7 +49,8 @@ namespace Tensorflow.Operations target_t.HandleData = handle_data; return; } - c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); + // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.) + //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray()); } } } diff --git a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs index 83ff50b1..7921f28b 100644 --- a/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/resource_variable_ops.cs @@ -21,6 +21,9 @@ using Tensorflow.Train; using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Variables; using static Tensorflow.CppShapeInferenceResult.Types; +using static Tensorflow.Binding; +using Tensorflow.Operations; +using System.Buffers; namespace Tensorflow { @@ -31,6 +34,7 @@ namespace Tensorflow { public static Operation shape_safe_assign_variable_handle(Tensor handle, int[] shape, Tensor value, string name = null) { + // TODO(Rinne): deal with `_handle_graph`. var value_tensor = ops.convert_to_tensor(value); return gen_resource_variable_ops.assign_variable_op(handle, value_tensor, @@ -78,6 +82,18 @@ namespace Tensorflow string shared_name, string name, bool graph_mode, Tensor initial_value = null) { var container = ops.get_default_graph().Container; + if(container is null) + { + container = ""; + } + if (!graph_mode) + { + if(shared_name is not null) + { + throw new Exception("Using an explicit shared_name is not allowed when executing eagerly."); + } + shared_name = tf.Context.anonymous_name(); + } var handle = gen_resource_variable_ops.var_handle_op(shape: shape, dtype: dtype, shared_name: shared_name, @@ -95,26 +111,20 @@ namespace Tensorflow } else { - // We do not want two distinct ResourceVariable objects for the same - // underlying resource in the runtime. - // When in eager mode, explicitly ensure so here. When in graph mode, it's - // ensured by always generating different variable names. - var exists = gen_resource_variable_ops.var_is_initialized_op(handle); - - // We create an assert Op instead of checking right away in order to be - // compatible with ASYNC execution mode. Further, since not all devices - // support string tensors, we encode the assertion string in the Op name - /*gen_logging_ops.assert(gen_math_ops.logical_not(exists), - new[] { exists }, - name: "EagerVariableNameReuse");*/ - - var handle_data = new HandleData(); - handle_data.IsSet = true; - handle_data.ShapeAndType.Add(new HandleShapeAndType + var handle_data = handle_data_util.create_handle_data(shape, dtype); + if (initial_value is not null && initial_value.dtype == dtypes.variant) { - Dtype = dtype.as_datatype_enum(), - Shape = shape.as_proto() - }); + var extra_handle_data = get_eager_safe_handle_data(initial_value); + if (extra_handle_data is not null && extra_handle_data.IsSet) + { + if (!handle_data.IsSet || handle_data.ShapeAndType.Count != 1) + { + throw new RuntimeError($"Expected VarHandleOp to return a length==1 shape_and_type, " + + $"but saw: '{handle_data}'"); + } + handle_data.ShapeAndType.AddRange(extra_handle_data.ShapeAndType); + } + } _set_handle_shapes_and_types(handle, handle_data, graph_mode); return handle; } @@ -126,24 +136,48 @@ namespace Tensorflow /// /// /// - internal static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) + internal unsafe static void _set_handle_shapes_and_types(Tensor tensor, HandleData handle_data, bool graph_mode) { + tensor.HandleData = handle_data; if (!graph_mode) return; - var size = handle_data.ShapeAndType.Count; + //var shapes = handle_data.ShapeAndType.Select(x => x.Shape); + //var types = handle_data.ShapeAndType.Select(x => x.Dtype).ToArray(); + //var ranks = shapes.Select(s => s.UnknownRank ? -1 : s.Dim.Count).ToArray(); + //var converted_shapes = shapes.Select>(s => + //{ + // if (!s.UnknownRank) + // { + // return s.Dim.Select(d => (int)d.Size).ToArray(); + // } + // else + // { + // return Memory.Empty; + // } + //}).ToArray(); - var shapes = new IntPtr[size]; - var types = new DataType[size]; - var ranks = new int[size]; + //List handles = new(); + //IntPtr[] shapes_with_ptr = new IntPtr[converted_shapes.Length]; + //foreach(var (i, m) in enumerate(converted_shapes)) + //{ + // if(m.IsEmpty) + // { + // shapes_with_ptr[i] = IntPtr.Zero; + // } + // else + // { + // var handle = m.Pin(); + // handles.Add(handle); + // shapes_with_ptr[i] = new IntPtr(handle.Pointer); + // } + //} - for (int i = 0; i < size; i++) - { - var shapeAndType = handle_data.ShapeAndType[i]; - types[i] = shapeAndType.Dtype; - ranks[i] = shapeAndType.Shape.UnknownRank ? -1 : shapeAndType.Shape.Dim.Count; - var dims = shapeAndType.Shape.Dim.Select(x => x.Size).ToArray(); - } + //Status status = new(); + //// TODO(Rinne): enable it. + //c_api.TF_GraphSetOutputHandleShapesAndTypes(tensor.op.graph.c_graph, tensor._as_tf_output(), + // shapes_with_ptr.Length, shapes_with_ptr, ranks, types, status); + //handles = null; } /// diff --git a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs index 3d056cae..e75820a9 100644 --- a/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs +++ b/src/TensorFlowNET.Core/Protobuf/SavedObjectGraph.cs @@ -330,7 +330,7 @@ namespace Tensorflow { private static readonly pb::FieldCodec _repeated_children_codec = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); private static readonly pb::FieldCodec _repeated_dependencies_codec - = pb::FieldCodec.ForMessage(10, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); + = pb::FieldCodec.ForMessage(122, global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.ObjectReference.Parser); private readonly pbc::RepeatedField children_ = new pbc::RepeatedField(); private readonly pbc::RepeatedField dependencies_ = new pbc::RepeatedField(); /// @@ -698,9 +698,13 @@ namespace Tensorflow { break; case 10: { children_.AddEntriesFrom(input, _repeated_children_codec); - dependencies_.AddRange(children_.Except(dependencies_)); break; } + case 122: + { + dependencies_.AddEntriesFrom(input, _repeated_dependencies_codec); + break; + } case 26: { slotVariables_.AddEntriesFrom(input, _repeated_slotVariables_codec); break; diff --git a/src/TensorFlowNET.Core/Training/AutoTrackable.cs b/src/TensorFlowNET.Core/Training/AutoTrackable.cs index 4ba3e407..20631ce8 100644 --- a/src/TensorFlowNET.Core/Training/AutoTrackable.cs +++ b/src/TensorFlowNET.Core/Training/AutoTrackable.cs @@ -3,6 +3,7 @@ using System.Linq; using Tensorflow.Functions; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Operations.Activation; +using Tensorflow.Training; using static Tensorflow.Binding; namespace Tensorflow.Train @@ -25,6 +26,13 @@ namespace Tensorflow.Train } } + public override void SetAttr(string name, object value) + { + // TODO(Rinne): deal with `self_setattr_tracking`. + value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); + base.SetAttr(name, value); + } + public override IDictionary _trackable_children(SaveType save_type, IDictionary>? cache = null) { if(save_type != SaveType.SAVEDMODEL) @@ -34,6 +42,7 @@ namespace Tensorflow.Train Dictionary functions = new(); // TODO: process of logs. + // TODO(Rinne): deal with members. var properties = this.GetType().GetProperties(); foreach ( var property in properties ) { @@ -45,6 +54,16 @@ namespace Tensorflow.Train } } + foreach(var item in CustomizedFields) + { + var name = item.Key; + var value = item.Value; + if (value is Function or ConcreteFunction) + { + functions[name] = (Trackable)value; + } + } + // TODO: process the type `core_types.GenericFunction`. Dictionary children = new(); diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 2d23a325..e2bdafab 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -42,22 +42,25 @@ namespace Tensorflow _var_device = var.Device; _var_shape = var.shape; - Tensor _read_variable_closure(BaseResourceVariable v) + Func _read_variable_closure(BaseResourceVariable v) { - tf.device(v.Device); - if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + return () => { - return null; - } - var x = v.read_value_no_copy(); - tf.device("/device:CPU:0"); - return array_ops.identity(x); + tf.device(v.Device); + if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + tf.device("/device:CPU:0"); + return array_ops.identity(x); + }; } this.handle_op = var.Handle; - var tensor = _read_variable_closure(var); + var tensor_creator = _read_variable_closure(var); - var spec = new SaveSpec(tensor, slice_spec, name, dtype: var.dtype); + var spec = new SaveSpec(tensor_creator, slice_spec, name, dtype: var.dtype, device: var.Device); _op = var; specs = new SaveSpec[] { spec }; this.name = name; @@ -66,6 +69,7 @@ namespace Tensorflow public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) { var restored_tensor = restored_tensors[0]; + tf.device(_var_device); restored_tensor = array_ops.identity(restored_tensor); return resource_variable_ops.shape_safe_assign_variable_handle( handle_op, _var_shape, restored_tensor); diff --git a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs index 393a6a98..2b300c2a 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SaveSpec.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Exceptions; + namespace Tensorflow { /// @@ -21,8 +23,24 @@ namespace Tensorflow /// public class SaveSpec { - private Tensor _tensor; - public Tensor tensor => _tensor; + private Tensor _tensor = null; + private Func _tensor_creator = null; + public Tensor tensor + { + get + { + if(_tensor is not null || _tensor_creator is null) + { + return _tensor; + } + else + { + return _tensor_creator(); + } + } + } + + internal Func TensorCreator => _tensor_creator; private string _slice_spec; public string slice_spec => _slice_spec; @@ -32,13 +50,36 @@ namespace Tensorflow private TF_DataType _dtype; public TF_DataType dtype => _dtype; + private string _device; + public string device => _device; - public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid) + public SaveSpec(Tensor tensor, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) { _tensor = tensor; _slice_spec = slice_spec; _name = name; _dtype = dtype; + if(device is not null) + { + _device = device; + } + else + { + _device = tensor.Device; + } + } + + public SaveSpec(Func tensor_creator, string slice_spec, string name, TF_DataType dtype = TF_DataType.DtInvalid, string device = null) + { + _tensor_creator = tensor_creator; + _slice_spec = slice_spec; + _name = name; + if(dtype == TF_DataType.DtInvalid || device is null) + { + throw new AssertionError("When passing a callable `tensor` to a SaveSpec, an explicit dtype and device must be provided."); + } + _dtype = dtype; + _device = device; } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs index 60188293..5bb7238e 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs @@ -1,10 +1,20 @@ using System; +using System.Diagnostics; using Tensorflow.Train; +using Tensorflow.Training; namespace Tensorflow; public class RevivedTypes { + private static Dictionary _registered_revived_creator = new(); + static RevivedTypes() + { + var list_wrapper = new ListWrapper(new Trackable[] { }); + _registered_revived_creator[list_wrapper.Identifier] = list_wrapper; + var dict_wrapper = new DictWrapper(new Dictionary()); + _registered_revived_creator[dict_wrapper.Identifier] = dict_wrapper; + } /// /// Create a SavedUserObject from a trackable object. /// @@ -12,13 +22,28 @@ public class RevivedTypes /// public static SavedUserObject? serialize(Trackable obj) { - // TODO: complete the implementation. + // TODO(Rinne): complete the implementation. return null; } - public static Tuple> deserialize(object proto) + public static (Trackable, Action) deserialize(SavedUserObject proto) { - // TODO: complete the implementation. - return null; + if(_registered_revived_creator.TryGetValue(proto.Identifier, out var wrapper)) + { + return (wrapper.FromProto(proto), (x, y, z) => + { + if (x is not ITrackableWrapper trackable) + { + throw new TypeError($"The type is expected to be `ITrackableWrapper`, but got {x.GetType()}."); + } + Debug.Assert(y is string); + trackable.SetValue(y, z); + } + ); + } + else + { + return (null, null); + } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 53ac9e2a..6e6e62df 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -49,6 +49,7 @@ namespace Tensorflow var temp = _proto.ToString(); _export_dir = export_dir; // TODO: `this._concrete_functions` and `this._restored_concrete_functions` + // TODO(Rinne): This method is very slow, needs to be accelareted. _concrete_functions = function_deserialization.load_function_def_library( meta_graph.GraphDef.Library, _proto); _restored_concrete_functions = new HashSet(); @@ -523,7 +524,7 @@ namespace Tensorflow continue; } setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); - // skip the process of "__call__" + // TODO(Rinne): deal with "__call__" } } @@ -595,13 +596,12 @@ namespace Tensorflow private (Trackable, Action) _recreate_user_object(SavedUserObject? proto, int node_id) { // skip the check of proto identifier because of lack of property. - - var looked_up = RevivedTypes.deserialize(proto); - if(looked_up is null) + var (trackable, setter) = RevivedTypes.deserialize(proto); + if(trackable is null) { return _recreate_base_user_object(proto, node_id); } - return (looked_up.Item1, looked_up.Item2); + return (trackable, setter); } private (Trackable, Action) _recreate_base_user_object(SavedUserObject? proto = null, int? node_id = null) @@ -668,13 +668,20 @@ namespace Tensorflow public static Action setattr = (x, y, z) => { Debug.Assert(y is string); - var properties = x.GetType().GetProperties(); - foreach(var p in properties) + if(x is Trackable trackable) + { + trackable.SetAttr(y as string, z); + } + else { - if((string)y == p.Name) + var properties = x.GetType().GetProperties(); + foreach (var p in properties) { - p.SetValue(x, z); - return; + if ((string)y == p.Name) + { + p.SetValue(x, z); + return; + } } } // TODO(Rinne): check if the property has been set successfully. diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index 5456669e..c4ef751b 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -50,6 +50,10 @@ namespace Tensorflow } public static class saveable_object_util { + public static string NO_SLICE_SPEC_KEY = ""; + private static HashSet _VARIABLE_OPS = new HashSet(new string[] { + "Variable", "VariableV2", "AutoReloadVariable", "VarHandleOp", "ReadVariableOp" + }); /// /// Returns the variables and names that will be used for a Saver. /// @@ -123,19 +127,12 @@ namespace Tensorflow /// public static IEnumerable saveable_objects_for_op(Tensor op, string name) { - if (false) - { - - } + ops.init_scope(); + var variable = ops.convert_to_tensor(op, as_ref: true); + if (variable.dtype.is_ref_dtype()) + yield return new ReferenceVariableSaveable(variable, "", name); else - { - ops.init_scope(); - var variable = ops.convert_to_tensor(op, as_ref: true); - if (variable.dtype.is_ref_dtype()) - yield return new ReferenceVariableSaveable(variable, "", name); - else - yield return new ResourceVariableSaveable(variable, "", name); - } + yield return new ResourceVariableSaveable(variable, "", name); } /// @@ -159,7 +156,7 @@ namespace Tensorflow yield return new ResourceVariableSaveable(variable, "", name); } } - else + else if(obj is not IVariableV1) { foreach(var pair in saveable_objects_from_trackable(obj)) { @@ -191,6 +188,30 @@ namespace Tensorflow } } } + else + { + // Variable + if (tf.Context.executing_eagerly()) + { + throw new ValueError($"Can only save/restore ResourceVariables when " + + $"executing eagerly, got type: {obj.GetType()}."); + } + var variable = ops.convert_to_tensor(obj, as_ref: true); + if (!_tensor_comes_from_variable(variable)) + { + throw new TypeError($"names_to_saveables must be a dict mapping string " + + $"names to Tensors/Variables. Not a variable: {variable}"); + } + if(variable.op.type == "Variable" || variable.op.type == "VariableV2" || + variable.op.type == "AutoReloadVariable") + { + yield return new ReferenceVariableSaveable(variable, "", name); + } + else + { + yield return new ResourceVariableSaveable(variable, "", name); + } + } } /// @@ -267,24 +288,14 @@ namespace Tensorflow foreach (var pair in tensor_dict) { var tensor_name = pair.Key; - var maybe_tensor = pair.Value; + var internal_dict = pair.Value; local_names.Add(tensor_name); string spec_name = name + TrackableUtils.escape_local_name(tensor_name); - IDictionary internal_dict; - if (maybe_tensor.TryPickT0(out var tensor, out var dic)) - { - internal_dict = new Dictionary(); - internal_dict[""] = tensor; - } - else - { - internal_dict = dic; - } - foreach (var item in internal_dict) { - specs.Add(new SaveSpec(item.Value, item.Key, spec_name)); + Debug.Assert(item.Value.IsT0); + specs.Add(new SaveSpec(item.Value.AsT0, item.Key, spec_name)); } } return new TrackableSaveable(obj, specs, name, local_names, prefix); @@ -316,9 +327,9 @@ namespace Tensorflow /// Converts a list of SaveableObjects to a tensor dictionary. /// /// - public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) + public static Dictionary>> saveable_object_to_tensor_dict(IList saveables) { - Dictionary>> tensor_dict = new(); + Dictionary>> tensor_dict = new(); foreach (var saveable in saveables) { foreach (var spec in saveable.specs) @@ -326,14 +337,11 @@ namespace Tensorflow // skip the check that if `spec` is callable. var name = convert_to_string(spec.name); var slice_spec = convert_to_string(spec.slice_spec); - if (!string.IsNullOrEmpty(slice_spec)) - { - tensor_dict.SetDefault(name, new Dictionary()).AsT1[slice_spec] = spec.tensor; - } - else + if (string.IsNullOrEmpty(slice_spec)) { - tensor_dict[name] = spec.tensor; + slice_spec = NO_SLICE_SPEC_KEY; } + tensor_dict.SetDefault(name, new Dictionary>())[slice_spec] = spec.TensorCreator is null ? spec.tensor : spec; } } return tensor_dict; @@ -397,6 +405,11 @@ namespace Tensorflow { return factory(key); } + + private static bool _tensor_comes_from_variable(object v) + { + return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); + } } public class SaveableCompatibilityConverter: Trackable @@ -412,7 +425,7 @@ namespace Tensorflow public object Obj => _obj; public IList mySaveables=> _saveables; - public override IDictionary>> serialize_to_tensors() + public override IDictionary>> serialize_to_tensors() { return saveable_object_util.saveable_object_to_tensor_dict(_saveables); } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index b64b5ebc..2b5bf2a7 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -85,6 +85,72 @@ namespace Tensorflow.Train _self_saveable_object_factories = value; } } + public Dictionary CustomizedFields { get; set; } = new Dictionary(); + + public virtual void SetAttr(string name, object value) + { + var t = this.GetType(); + var field_info = t.GetField(name); + if(field_info is not null) + { + field_info.SetValue(this, value); + } + else + { + CustomizedFields[name] = value; + } + + // On account of performance, we don't use reflection to set the attribute if it exists in `Trackable`. + // When adding new members or properties to this class, please add corresponding process to this method. + //switch (name) + //{ + // case "_manual_tracking": + // { + // _manual_tracking = (bool)value; + // break; + // } + // case "_self_saveable_object_factories": + // { + // _self_saveable_object_factories = (IDictionary>>)value; + // break; + // } + // case "_self_update_uid": + // { + // _self_update_uid = (int)value; + // break; + // } + // case "_unconditional_checkpoint_dependencies": + // { + // _unconditional_checkpoint_dependencies = (IList)value; + // break; + // } + // case "_unconditional_deferred_dependencies": + // { + // _unconditional_deferred_dependencies = (Dictionary>)value; + // break; + // } + // case "_unconditional_dependency_names": + // { + // _unconditional_dependency_names = (IDictionary)value; + // break; + // } + // case "SelfSaveableObjectFactories": + // { + // SelfSaveableObjectFactories = (IDictionary>>)value; + // break; + // } + // case "UpdateUid": + // { + // UpdateUid = (int)value; + // break; + // } + // default: + // { + // CustomizedAttributes[name] = value; + // break; + // } + // } + } /// /// Restore-on-create for a variable be saved with this `Checkpointable`. @@ -279,7 +345,7 @@ namespace Tensorflow.Train /// /// /// - public virtual IDictionary>> serialize_to_tensors() + public virtual IDictionary>> serialize_to_tensors() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Training/data_structures.cs b/src/TensorFlowNET.Core/Training/data_structures.cs index 6e3336c9..a8033f59 100644 --- a/src/TensorFlowNET.Core/Training/data_structures.cs +++ b/src/TensorFlowNET.Core/Training/data_structures.cs @@ -2,6 +2,8 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO.Compression; using System.Linq; using System.Linq.Expressions; @@ -25,6 +27,48 @@ namespace Tensorflow.Training } } + static class TrackableWrapperUtils + { + internal static bool ShouldLoad(ITrackableWrapper wrapper, SavedUserObject proto) + { + if (proto.Identifier != wrapper.Identifier) + { + return false; + } + if (wrapper.Version < proto.Version.MinConsumer) + { + return false; + } + if (proto.Version.Producer < wrapper.MinProducerVersion) + { + return false; + } + foreach (var bad_version in proto.Version.BadConsumers) + { + if (bad_version == wrapper.Version) + { + return false; + } + } + return true; + } + + internal static bool is_function(Trackable x) + { + return x is Function or ConcreteFunction; + } + } + + public interface ITrackableWrapper + { + void SetValue(object name, object value); + String Identifier { get; } + int Version { get; } + int MinConsumerVersion { get; } + int MinProducerVersion { get; } + Trackable FromProto(SavedUserObject proto); + } + public abstract class TrackableDataStructure : Trackable { private bool _self_trainable; @@ -36,7 +80,7 @@ namespace Tensorflow.Training _self_extra_variables = new List(); } - public abstract IEnumerable Values { get; } + public abstract ICollection Values { get; } public bool Trainable { get => _self_trainable; set => _self_trainable = value; } public IEnumerable Layers { @@ -134,7 +178,7 @@ namespace Tensorflow.Training /// protected virtual Trackable _track_value(Trackable value, string name) { - value = sticky_attribute_assignment(this, name, value); + value = (Trackable)sticky_attribute_assignment(this, name, value); if(value is IVariableV1) { _self_extra_variables.Add(value as IVariableV1); @@ -148,44 +192,273 @@ namespace Tensorflow.Training return value.Value; } - public static Trackable wrap_or_unwrap(Trackable value) + public static object wrap_or_unwrap(object value) { + if(value is NoDependency dependency) + { + return dependency.Value; + } + if(value is Trackable trackable) + { + return trackable; + } + else if(value is IDictionary obj_dict) + { + return new DictWrapper(obj_dict); + } + else if(value is IList list) + { + return new ListWrapper(list); + } + else + { + return value; + } + } + + public static object sticky_attribute_assignment(Trackable trackable, string name, object value) + { + bool add_dependency = value is not NoDependency; + value = wrap_or_unwrap(value); + if (!add_dependency) + { + return value; + } + if(value is Trackable trackable_obj) + { + trackable._track_trackable(trackable_obj, name, true); + } return value; } + } + // TODO(Rinne): Add Dict wrapper and Tuple wrapper + + public class DictWrapper : TrackableDataStructure, IDictionary, ICloneable, ITrackableWrapper + { + private IDictionary _storage; + private bool _non_string_key; + private bool _external_modification; + private IDictionary _last_wrapped_dict_snapshot; + + public DictWrapper(IDictionary wrapped_dict = null) + { + if(wrapped_dict is not null) + { + _storage = new Dictionary(wrapped_dict); + } + else + { + _storage = new Dictionary(); + } + _update_snapshot(); + } - public static Trackable wrap_or_unwrap(IList value) + public void SetValue(object name, object value) { - return new ListWrapper(value); + Debug.Assert(value is Trackable); + this[name] = value as Trackable; + } + public String Identifier => "trackable_dict_wrapper"; + public int Version => 1; + public int MinConsumerVersion => 1; + public int MinProducerVersion => 1; + public Trackable FromProto(SavedUserObject proto) + { + return new DictWrapper(new Dictionary()); } - public static Trackable wrap_or_unwrap(IEnumerable value) + public Trackable this[object key] { - return new ListWrapper(value.ToList()); + get + { + return _storage[key]; + } + set + { + _check_self_external_modification(); + _maybe_initialize_trackable(); + bool no_dep = value is NoDependency; + if(key is string) + { + value = _track_value(value, key); + } + else + { + value = (Trackable)wrap_or_unwrap(value); + if(!no_dep && value is Trackable) + { + _non_string_key = true; + } + } + _storage[key] = value; + _update_snapshot(); + } } - protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, Trackable value) + public ICollection Keys => _storage.Keys; + + public override ICollection Values => _storage.OrderBy(x => x.Key).Select(x => x.Value).ToArray(); + + public void Add(object key, Trackable value) { - value = wrap_or_unwrap(value); - trackable._track_trackable(value, name, true); - return value; + _storage[key] = value; } - protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, NoDependency value) + public bool ContainsKey(object key) { - var wrapped_value = wrap_or_unwrap(value); - trackable._track_trackable(wrapped_value, name, true); - return wrapped_value; + return _storage.ContainsKey(key); } - protected static Trackable sticky_attribute_assignment(Trackable trackable, string name, IList value) + public bool Remove(object key) { - var wrapped_value = wrap_or_unwrap(value); - trackable._track_trackable(wrapped_value, name, true); - return wrapped_value; + _check_self_external_modification(); + var res = _storage.Remove(key); + _update_snapshot(); + return res; } - } - public class ListWrapper : TrackableDataStructure, IList, ICloneable + public bool TryGetValue(object key, out Trackable value) + { + return _storage.TryGetValue(key, out value); + } + + public int Count => _storage.Count; + + public bool IsReadOnly => _storage.IsReadOnly; + + public void Add(KeyValuePair item) + { + Add(item.Key, item.Value); + } + + public void Clear() + { + _storage.Clear(); + _update_snapshot(); + } + + public bool Contains(KeyValuePair item) + { + return _storage.Contains(item); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + _storage.CopyTo(array, arrayIndex); + } + + public bool Remove(KeyValuePair item) + { + _check_self_external_modification(); + var res = Remove(item); + _update_snapshot(); + return res; + } + + public IEnumerator> GetEnumerator() + { + return _storage.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() => _storage.GetEnumerator(); + + public object Clone() + { + var copied = new DictWrapper(_storage); + copied._external_modification = _external_modification; + copied._non_string_key = _non_string_key; + return copied; + } + + public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? cache = null) + { + _check_self_external_modification(); + if (_non_string_key) + { + throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed \"" + + $"automatically on attribute assignment). The wrapped dictionary " + + $"contains a non-string key which maps to a trackable object or " + + $"mutable data structure.\n\nIf you don't need this dictionary " + + $"checkpointed, wrap it in a non-trackable " + + $"object; it will be subsequently ignored."); + } + if (_external_modification) + { + throw new ValueError($"Unable to save the object {this} (a dictionary wrapper constructed " + + $"automatically on attribute assignment). The wrapped dictionary was " + + $"modified outside the wrapper (its final value was {this}, its value" + + $" when a checkpoint dependency was added was " + + $"{this._last_wrapped_dict_snapshot}), which breaks " + + $"restoration on object creation.\n\nIf you don't need this " + + $"dictionary checkpointed, wrap it in a " + + $"non-trackable object; it will be subsequently ignored."); + } + Debug.Assert(!Dirty); + var children = base._trackable_children(save_type, cache); + + if(save_type == SaveType.SAVEDMODEL) + { + foreach(var item in _storage) + { + var key = item.Key; + var value = item.Value; + if (TrackableWrapperUtils.is_function(value)) + { + Debug.Assert(key is string); + children[key as string] = value; + } + } + } + + return children; + } + + protected Trackable _track_value(Trackable value, object name) + { + bool string_key = name is string; + if (!string_key) + { + name = "-non_string_key"; + } + try + { + bool no_dependency = value is NoDependency; + value = base._track_value(value, name as string); + if(!(string_key || no_dependency)) + { + _non_string_key = true; + } + return value; + } + catch (ValueError) + { + return (Trackable)sticky_attribute_assignment(this, name as string, value); + } + } + + private bool Dirty => _external_modification || _non_string_key; + + private void _check_self_external_modification() + { + if (Dirty) + { + return; + } + if(!this._storage.SequenceEqual(_last_wrapped_dict_snapshot)) + { + _external_modification = true; + _last_wrapped_dict_snapshot = null; + } + } + + private void _update_snapshot() + { + // TODO(Rinne): deal with attribute_sentinel. + if (Dirty) return; + _last_wrapped_dict_snapshot = new Dictionary(_storage); + } + } + public class ListWrapper : TrackableDataStructure, IList, ICloneable, ITrackableWrapper { private IList _storage; private bool _non_append_mutation_value; @@ -198,11 +471,51 @@ namespace Tensorflow.Training /// modified directly after constructing the `ListWrapper`, and if changes are detected the `ListWrapper` will throw an exception on save. public ListWrapper(IList wrapped_list) { - _storage = wrapped_list; + _storage = new List(wrapped_list); _non_append_mutation_value = _external_modification_value = false; _last_wrapped_list_snapshot = new List(_storage); } + public string Identifier => "trackable_list_wrapper"; + public int Version => 1; + public int MinConsumerVersion => 1; + public int MinProducerVersion => 1; + public Trackable FromProto(SavedUserObject proto) + { + if(TrackableWrapperUtils.ShouldLoad(this, proto)) + { + return new ListWrapper(new Trackable[] { }); + } + else + { + return null; + } + } + public void SetValue(object name, object value) + { + Debug.Assert(name is string); + if(int.TryParse(name as string, out var index)) + { + if(value is not Trackable trackable) + { + throw new TypeError("Cannot set an object which is not trackable to ListWrapper."); + } + if(Count <= index) + { + Add(trackable); + } + else + { + this[index] = trackable; + } + } + else + { + throw new NotImplementedException("Encounter an unexpected behavior in , please " + + "submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues"); + } + } + protected bool NonAppendMuation { get => _non_append_mutation_value; set @@ -222,7 +535,7 @@ namespace Tensorflow.Training } } - public override IEnumerable Values => this; + public override ICollection Values => this; public bool IsReadOnly { get => _storage.IsReadOnly; } /// @@ -239,7 +552,7 @@ namespace Tensorflow.Training private void update_snapshot() { - // TODO: deal with `attribute_sentinel`. + // TODO(Rinne): deal with `attribute_sentinel`. if (_external_modification_value || _non_append_mutation_value) return; _last_wrapped_list_snapshot = new List(_storage); } @@ -286,9 +599,9 @@ namespace Tensorflow.Training { base._track_value(value, name); } - catch(ValueError ex) + catch(ValueError) { - value = sticky_attribute_assignment(this, name, value); + value = (Trackable)sticky_attribute_assignment(this, name, value); } return value; } @@ -343,7 +656,11 @@ namespace Tensorflow.Training update_snapshot(); } - public void Clear() => _storage.Clear(); + public void Clear() + { + _storage.Clear(); + update_snapshot(); + } public bool Contains(Trackable item) => _storage.Contains(item); diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index d04e6bff..c4537896 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -519,6 +519,14 @@ namespace Tensorflow.Util return pack_sequence_as(structure, mapped_flat_structure) as Tensor; } + public static T2 map_structure(Func func, T1 structure) where T2: class + { + var flat_structure = flatten(structure); + var mapped_flat_structure = flat_structure.Select(func).Select(x => (object)x); + + return pack_sequence_as(structure, mapped_flat_structure) as T2; + } + /// /// Same as map_structure, but with only one structure (no combining of multiple structures) /// diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 7d0ac4f8..dcf9fbe6 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -97,7 +97,7 @@ namespace Tensorflow else { unique_id = $"{handle_name}_{ops.uid()}"; - shared_name = tf.Context.shared_name(); + shared_name = null; } var attr = new AttrValue(); diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 01aa59b9..d13990a0 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -60,7 +60,15 @@ namespace Tensorflow.Keras public void track_variable(IVariableV1 v) { + if (tf.Context.executing_eagerly()) + { + return; + } var graph = v.Graph; + if(graph is null) + { + graph = get_graph(); + } _GRAPH_VARIABLES[graph.graph_key] = v; } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 99ee66c2..0a06df2c 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -21,10 +21,13 @@ using System.Linq; using System.Threading; using Tensorflow.Eager; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; using Tensorflow.NumPy; using Tensorflow.Train; +using Tensorflow.Training; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -349,5 +352,59 @@ namespace Tensorflow.Keras.Engine { } + + public override void SetAttr(string name, object value) + { + // TODO(Rinne): deal with "_self_setattr_tracking". + + value = TrackableDataStructure.sticky_attribute_assignment(this, name, value); + + foreach(var val in nest.flatten(value)) + { + if(val is Metric) + { + // TODO(Rinne): deal with metrics. + } + } + + // TODO(Rinne): deal with "_auto_track_sub_layers". + + foreach(var val in nest.flatten(value)) + { + if(val is not IVariableV1 variable) + { + continue; + } + if (variable.Trainable) + { + if (_trainable_weights.Contains(variable)) + { + continue; + } + _trainable_weights.Add(variable); + } + else + { + if (_non_trainable_weights.Contains(variable)) + { + continue; + } + _non_trainable_weights.Add(variable); + } + keras.backend.track_variable(variable); + } + + // Directly use the implementation of `Trackable`. + var t = this.GetType(); + var field_info = t.GetField(name); + if (field_info is not null) + { + field_info.SetValue(this, value); + } + else + { + CustomizedFields[name] = value; + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index c1d29f59..a3676007 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -1,7 +1,12 @@ -using Tensorflow.Keras.ArgsDefinition; +using System.Diagnostics; +using Tensorflow.Framework.Models; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving.SavedModel; +using Tensorflow.Keras.Utils; using Tensorflow.Train; +using Tensorflow.Util; namespace Tensorflow.Keras.Engine { @@ -22,14 +27,16 @@ namespace Tensorflow.Keras.Engine IOptimizer optimizer; IVariableV1 _steps_per_execution; protected bool _is_graph_network; - protected Tensors inputs; + public Tensors inputs; protected Tensors outputs; + protected List input_names; public string[] output_names; IVariableV1 _train_counter; IVariableV1 _test_counter; IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; + TensorSpec _saved_model_inputs_spec; public bool IsGraphNetwork => _is_graph_network; @@ -45,6 +52,38 @@ namespace Tensorflow.Keras.Engine _init_batch_counters(); } + public void _set_inputs(TensorSpec inputs) + { + _set_save_spec(inputs); + } + + internal void _set_save_spec(TensorSpec inputs) + { + if(_saved_model_inputs_spec is not null) + { + return; + } + var input_names = this.input_names; + if(input_names is null || input_names.Count == 0) + { + input_names = compile_utils.create_pseudo_input_names(inputs); + } + + var flat_inputs = nest.flatten(inputs); + List specs = new(); + foreach(var (name, tensor) in zip(input_names, flat_inputs)) + { + specs.Add(tf_utils.get_tensor_spec(tensor, dynamic_batch: false, name: name)); + } + var packed_specs = nest.pack_sequence_as(inputs, specs) as TensorSpec; + Debug.Assert(specs is not null); + _saved_model_inputs_spec = packed_specs; + if(this is Sequential && _buildInputShape is null) + { + _buildInputShape = nest.map_structure(x => x is null ? null : x.shape, packed_specs); + } + } + internal override void Initialize(LayerArgs args) { _init_batch_counters(); @@ -145,6 +184,16 @@ namespace Tensorflow.Keras.Engine return children; } + public override void SetAttr(string name, object value) + { + // TODO(Rinne): deal with "_self_setattr_tracking". + //if(nest.flatten(value).All(v => v is Layer or IVariableV1 || base_layer_utils.has_weights(v))) + //{ + // this._base_model_initialized; + //} + base.SetAttr(name, value); + } + void IModel.set_stopTraining_true() { diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 90612c07..3b5d3274 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -1,12 +1,14 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using System; +using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; using System.Linq; using System.Reflection; using System.Text.RegularExpressions; +using Tensorflow.Framework.Models; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; @@ -17,6 +19,8 @@ using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Utils; using Tensorflow.Train; using Tensorflow.Training; +using Tensorflow.Training.Saving.SavedModel; +using Tensorflow.Util; using ThirdParty.Tensorflow.Python.Keras.Protobuf; using static Tensorflow.ApiDef.Types; using static Tensorflow.Binding; @@ -190,12 +194,13 @@ namespace Tensorflow.Keras.Saving Name = config["name"].ToObject() }); //s.Name = config["name"].ToObject(); - if(s.input is null || s.input.Length == 0) + if(s.inputs is null || s.inputs.Length == 0) { var first_layer = _get_child_layer_node_ids(model_id)[0]; var input_specs = _infer_inputs(first_layer); - var input_shapes = _infer_inputs(first_layer, true); + var input_shapes = _infer_input_shapes(first_layer); // `model._set_inputs(input_specs)` + s._set_inputs(input_specs); // skip the check of input_specs is Dictionary if (!s.Built) @@ -220,12 +225,12 @@ namespace Tensorflow.Keras.Saving private void _set_network_attributes_from_metadata(Model revived_object) { - var metadata = revived_object.SerializedAttributes["matadata"] as JObject; - if (metadata.ContainsKey("dtype")) + var metadata = revived_object.SerializedAttributes["metadata"] as KerasMetaData; + if (metadata.DType != TF_DataType.DtInvalid) { // TODO(Rinne): set_dtype_policy. } - revived_object.args.Trainable = metadata["trainable"].Value(); + revived_object.args.Trainable = metadata.Trainable; } /// @@ -305,6 +310,11 @@ namespace Tensorflow.Keras.Saving private (Trackable, Action) _load_layer(int node_id, string identifier, string metadata_json) { var metadata = JsonConvert.DeserializeObject(metadata_json); + // Debug(Rinne) + if(node_id == 11) + { + Console.WriteLine(); + } if (loaded_nodes.ContainsKey(node_id)) { @@ -472,15 +482,7 @@ namespace Tensorflow.Keras.Saving } else { - var properties = layer.GetType().GetProperties(); - foreach(var p in properties) - { - if(p.Name == name as string && p.GetValue(layer) is not null) - { - return; - } - } - Loader.setattr(layer, name, value); + layer.SetAttr(name as string, value); } } @@ -607,7 +609,7 @@ namespace Tensorflow.Keras.Saving if(build_input_shape is null) { - build_input_shape = _infer_inputs(node_id, convert_to_shapes: true); + build_input_shape = _infer_input_shapes(node_id); } if(build_input_shape is not null) @@ -633,7 +635,7 @@ namespace Tensorflow.Keras.Saving /// /// /// - private Shape _infer_inputs(int layer_node_id, bool convert_to_shapes = false) + private TensorSpec _infer_inputs(int layer_node_id) { var call_fn_id = _search_for_child_node(layer_node_id, new string[] { "call_and_return_all_conditional_losses" }); if(call_fn_id is null) @@ -648,7 +650,22 @@ namespace Tensorflow.Keras.Saving } var call_fn_name = concrete_functions[0]; var call_fn_proto = _proto.ConcreteFunctions[call_fn_name]; - throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + var structured_input_signature = nested_structure_coder.decode_proto(call_fn_proto.CanonicalizedInputSignature); + Debug.Assert(structured_input_signature is IEnumerable); + var first_enumerator = (structured_input_signature as IEnumerable).GetEnumerator(); + first_enumerator.MoveNext(); + var first = first_enumerator.Current; + Debug.Assert(first is IEnumerable); + var inputs_enumerator = (first as IEnumerable).GetEnumerator(); + inputs_enumerator.MoveNext(); + var inputs = inputs_enumerator.Current as TensorSpec; + return inputs; + } + + private Shape _infer_input_shapes(int layer_node_id) + { + var inputs = _infer_inputs(layer_node_id); + return nest.map_structure(x => x.shape, inputs); } private int? _search_for_child_node(int parent_id, IEnumerable path_to_child) diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs index d2c4a55a..6970b04e 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/ReviveUtils.cs @@ -48,19 +48,7 @@ namespace Tensorflow.Keras.Saving.SavedModel } else { - var properties = layer.GetType().GetProperties(); - foreach (var p in properties) - { - if ((string)name == p.Name) - { - if(p.GetValue(layer) is not null) - { - return; - } - p.SetValue(layer, value); - return; - } - } + layer.SetAttr(name as string, value); } } } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 2d2de28b..035b0c92 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -11,7 +11,7 @@ using Tensorflow.Keras.Optimizers; using ThirdParty.Tensorflow.Python.Keras.Protobuf; using static Tensorflow.Binding; using Tensorflow.Training; - +using System.Diagnostics; namespace Tensorflow.Keras.Saving.SavedModel; @@ -135,12 +135,17 @@ public partial class KerasSavedModelUtils if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); })); + var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); Dictionary res = new(); - res["variables"] = variables; - res["trainable_variables"] = trainable_variables; - res["non_trainable_variables"] = non_trainable_variables; - res["layers"] = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + Debug.Assert(variables is Trackable); + Debug.Assert(trainable_variables is Trackable); + Debug.Assert(non_trainable_variables is Trackable); + Debug.Assert(layers is Trackable); + res["variables"] = variables as Trackable; + res["trainable_variables"] = trainable_variables as Trackable; + res["non_trainable_variables"] = non_trainable_variables as Trackable; + res["layers"] = layers as Trackable; return res; } diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index d845f3ca..56190a22 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -165,6 +165,14 @@ namespace Tensorflow.Keras.Utils } } + public static bool has_weights(object obj) + { + var obj_type = obj.GetType(); + return obj_type.GetField("trainable_weights") is not null && + obj_type.GetField("non_trainable_weights") is not null && + obj is not Type; + } + // recusive static bool uses_keras_history(Tensor op_input) { diff --git a/src/TensorFlowNET.Keras/Utils/compile_utils.cs b/src/TensorFlowNET.Keras/Utils/compile_utils.cs new file mode 100644 index 00000000..cd411261 --- /dev/null +++ b/src/TensorFlowNET.Keras/Utils/compile_utils.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework.Models; +using Tensorflow.Util; + +namespace Tensorflow.Keras.Utils +{ + internal static class compile_utils + { + public static List create_pseudo_input_names(TensorSpec inputs) + { + return _create_pseudo_names(inputs, "input_"); + } + + private static List _create_pseudo_names(TensorSpec tensors, string prefix) + { + // TODO(Rinne): align with tensorflow + return new List() { $"{prefix}1" }; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Keras/Utils/tf_utils.cs index b144ec9f..ad31fd7c 100644 --- a/src/TensorFlowNET.Keras/Utils/tf_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/tf_utils.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using Tensorflow.Framework; +using Tensorflow.Framework.Models; namespace Tensorflow.Keras.Utils { @@ -69,5 +70,29 @@ namespace Tensorflow.Keras.Utils false_fn: false_fn, name: name); } + + public static TensorSpec get_tensor_spec(Tensor t, bool dynamic_batch = false, string name = null) + { + throw new NotImplementedException("The function is waited to be implemented in the future."); + } + + public static TensorSpec get_tensor_spec(TensorSpec t, bool dynamic_batch = false, string name = null) + { + var spec = t; + if (!dynamic_batch) + { + return spec; + } + var dynamic_batch_spec = new TensorSpec(t.shape, t.dtype, t.name); + var shape = dynamic_batch_spec.shape; + if(shape.rank > 0) + { + var shape_list = shape.as_int_list(); + // TODO(Rinne): check if -1 is equivalent to None in python. + shape_list[0] = -1; + dynamic_batch_spec.shape = new Shape(shape_list); + } + return dynamic_batch_spec; + } } } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 74f610c8..eeb5f9e4 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -64,5 +64,8 @@ public class SequentialModelLoad { var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); model.summary(); + + var x = tf.ones((2, 10)); + var y = model.Apply(x); } }