From 6a9ccea29ffb8af77e07bc31b4934a6f53ec3105 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Wed, 5 Apr 2023 12:56:23 +0800 Subject: [PATCH] Resolve some wrong implementations. --- src/TensorFlowNET.Core/Buffers/TF_Buffer.cs | 27 ++++++++++++++ src/TensorFlowNET.Core/Eager/execute.cs | 4 +++ src/TensorFlowNET.Core/Framework/importer.cs | 1 + .../Functions/EagerDefinedFunction.cs | 36 ++++++++++++++----- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 2 +- .../Operations/Operation.cs | 13 +++++++ .../Training/Saving/SavedModel/loader.cs | 28 ++------------- .../Variables/BaseResourceVariable.cs | 10 ++++++ .../Variables/ResourceVariable.cs | 10 ++++-- .../Variables/UninitializedVariable.cs | 10 +++--- src/TensorFlowNET.Keras/Engine/Model.cs | 6 ---- .../Saving/KerasObjectLoader.cs | 5 --- .../Saving/SavedModel/Save.cs | 10 +++--- .../SavedModel/serialized_attributes.cs | 10 ++++-- .../Callbacks/EarlystoppingTest.cs | 4 +-- 15 files changed, 114 insertions(+), 62 deletions(-) diff --git a/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs b/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs index 7ebdd5b8..c10f7b5f 100644 --- a/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs @@ -25,5 +25,32 @@ namespace Tensorflow public IntPtr data; public ulong length; public IntPtr data_deallocator; + + public unsafe Span AsSpan() where T: unmanaged + { + if(length > int.MaxValue) + { + throw new ValueError($"The length {length} is too large to use in the span."); + } + return new Span(data.ToPointer(), (int)length); + } + + public unsafe byte[] ToByteArray() + { + byte[] res = new byte[length]; + if(length > int.MaxValue) + { + byte* root = (byte*)data; + for(ulong i = 0; i < length; i++) + { + res[i] = *(root++); + } + } + else + { + new Span(data.ToPointer(), (int)length).CopyTo(res.AsSpan()); + } + return res; + } } } diff --git a/src/TensorFlowNET.Core/Eager/execute.cs b/src/TensorFlowNET.Core/Eager/execute.cs index 2926f8e2..1804992a 100644 --- a/src/TensorFlowNET.Core/Eager/execute.cs +++ b/src/TensorFlowNET.Core/Eager/execute.cs @@ -18,6 +18,10 @@ namespace Tensorflow.Eager var types = v.Select(t => t.dtype.as_datatype_enum()); return (types.ToArray(), v.ToArray()); } + public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) + { + return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); + } public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) { string device_name = ctx.DeviceName; diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs index a4e6c72e..b569c8e1 100644 --- a/src/TensorFlowNET.Core/Framework/importer.cs +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -149,6 +149,7 @@ namespace Tensorflow foreach (var new_op in graph._add_new_tf_operations()) { var original_device = new_op.Device; + new_op._set_device(original_device); } } diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index 4c2d4c37..fb9db8bd 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -1,9 +1,11 @@ using Google.Protobuf; using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Text; using Tensorflow.Contexts; +using Tensorflow.Eager; using Tensorflow.Graphs; using Tensorflow.Operations; using Tensorflow.Util; @@ -16,6 +18,8 @@ namespace Tensorflow.Functions public int _num_outputs; FuncGraph _func_graph; FunctionDef _definition; + OpDef _signature; + string _name; Tensor[] _func_graph_outputs; public string Name => _func_graph.FuncName; public DataType[] OutputTypes { get; protected set; } @@ -31,6 +35,18 @@ namespace Tensorflow.Functions return _definition; } } + + public OpDef Signature + { + get + { + if( _signature is null) + { + _signature = Definition.Signature; + } + return _signature; + } + } public EagerDefinedFunction(string name, FuncGraph graph, Tensors inputs, Tensors outputs, Dictionary attrs) @@ -75,12 +91,12 @@ namespace Tensorflow.Functions Tensor[] outputs; if (executing_eagerly) { - outputs = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - _func_graph.FuncName, - args, - attrs, - _num_outputs); + outputs = execute.executes( + Signature.Name, + _num_outputs, + args, + attrs, + tf.Context); } else { @@ -135,9 +151,13 @@ namespace Tensorflow.Functions private FunctionDef _get_definition() { var buffer = c_api_util.tf_buffer(); - // TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef + Status status = new(); + c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status); + status.Check(true); var proto_data = c_api.TF_GetBuffer(buffer); - throw new NotImplementedException(); + FunctionDef function_def = new(); + function_def.MergeFrom(proto_data.AsSpan()); + return function_def; } } } diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index b086907e..9367414e 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Graphs; /// public class FuncGraph : Graph, IDisposable { - SafeFuncGraphHandle _func_graph_handle; + internal SafeFuncGraphHandle _func_graph_handle; public string FuncName => _graph_key; public Tensors Inputs { get; set; } = new Tensors(); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 751ade5d..28e69886 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -238,6 +238,19 @@ namespace Tensorflow return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); } + [Obsolete("The implementation is not complete.")] + internal void _set_device_from_string(string device_str) + { + // TODO(Rinne): complete it with new C API `SetRequestedDevice`. + //c_api.TF_SetDevice(_handle, device_str); + } + + [Obsolete("The implementation is not complete.")] + internal void _set_device(string device) + { + _set_device_from_string(device); + } + private NodeDef GetNodeDef() { var buffer = new Buffer(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 6e6e62df..6f26e07b 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -45,11 +45,8 @@ namespace Tensorflow _asset_file_def = meta_graph.AssetFileDef; _operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr); _proto = object_graph_proto; - // Debug(Rinne) - var temp = _proto.ToString(); _export_dir = export_dir; - // TODO: `this._concrete_functions` and `this._restored_concrete_functions` - // TODO(Rinne): This method is very slow, needs to be accelareted. + // TODO(Rinne): This method is a bit slow (especially under debug mode), may need to be accelareted. _concrete_functions = function_deserialization.load_function_def_library( meta_graph.GraphDef.Library, _proto); _restored_concrete_functions = new HashSet(); @@ -322,11 +319,6 @@ namespace Tensorflow foreach(var (node_id, proto) in _iter_all_nodes()) { var node = get(node_id); - if(node is null) - { - // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. - continue; - } if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) { // Restore Trackable serialize- and restore-from-tensor functions. @@ -390,7 +382,7 @@ namespace Tensorflow var optimizer_object = nodes[optimizer_node_id]; var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; - // TODO: implement it. + // TODO(Rinne): implement it. throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); } @@ -508,21 +500,11 @@ namespace Tensorflow /// private void _add_object_graph_edges(SavedObject proto, int node_id) { - // Debug(Rinne) - if(node_id == 1) - { - Console.WriteLine(); - } var obj = _nodes[node_id]; var setter = _node_setters[node_id]; foreach(var refer in proto.Children) { - if(obj is null) - { - // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. - continue; - } setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); // TODO(Rinne): deal with "__call__" } @@ -553,12 +535,6 @@ namespace Tensorflow private (Trackable, Action) _recreate(SavedObject proto, int node_id, IDictionary nodes) { // skip the registered classes. - if(node_id == 16) - { - // Debug(Rinne) - Console.WriteLine(); - } - Dictionary, Trackable> dependencies = new(); foreach(var item in _get_node_dependencies(proto)) { diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index cc5ee542..faaa0274 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -65,6 +65,8 @@ namespace Tensorflow } public void __init__(bool trainable = true, + Shape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, Tensor handle = null, string name = null, string unique_id = null, @@ -75,6 +77,14 @@ namespace Tensorflow _unique_id = unique_id; this.handle = handle; _name = name; + if(shape is not null) + { + _shape = shape; + } + if(dtype != TF_DataType.DtInvalid) + { + _dtype = dtype; + } // After the handle has been created, set up a way to clean it up when // executing eagerly. We'll hold the only reference to the deleter, so that diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index dcf9fbe6..512e8152 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -116,7 +116,11 @@ namespace Tensorflow } }); - _shape = shape ?? _initial_value.shape; + if(shape is null) + { + shape = _initial_value.shape; + } + dtype = _initial_value.dtype; if (_in_graph_mode) { @@ -135,7 +139,7 @@ namespace Tensorflow { handle = resource_variable_ops.eager_safe_variable_handle( initial_value: _initial_value, - shape: _shape, + shape: shape, shared_name: shared_name, name: name, graph_mode: _in_graph_mode); @@ -154,6 +158,8 @@ namespace Tensorflow } base.__init__(trainable: trainable, + shape: shape, + dtype: dtype, handle: handle, name: name, unique_id: unique_id, diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs index 8ee3c62b..637d0983 100644 --- a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -50,9 +50,9 @@ namespace Tensorflow.Variables { tf_with(ops.name_scope("Read"), _ => { - tf.device(handle.Device); - var value = gen_resource_variable_ops.read_variable_op(handle, dtype); - resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + tf.device(created_handle.Device); + var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value); _graph_element = value; }); ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); @@ -63,9 +63,7 @@ namespace Tensorflow.Variables } }); }); - _shape = shape; - _dtype = dtype; - base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name); + base.__init__(trainable, shape, dtype, created_handle, unique_id: unique_id, handle_name: handle_name); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 1d9e9f06..83702b23 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -199,11 +199,5 @@ namespace Tensorflow.Keras.Engine //} base.SetAttr(name, value); } - - - void IModel.set_stopTraining_true() - { - stop_training = true; - } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 29c29405..aed6769a 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -307,11 +307,6 @@ namespace Tensorflow.Keras.Saving private (Trackable, Action) _load_layer(int node_id, string identifier, string metadata_json) { var metadata = JsonConvert.DeserializeObject(metadata_json); - // Debug(Rinne) - if(node_id == 11) - { - Console.WriteLine(); - } if (loaded_nodes.ContainsKey(node_id)) { diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 035b0c92..331b283a 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -124,18 +124,18 @@ public partial class KerasSavedModelUtils { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); + }).ToArray()); var trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.TrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); + }).ToArray()); var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); - })); - var layers = TrackableDataStructure.wrap_or_unwrap(KerasSavedModelUtils.list_all_layers(layer).Select(x => x.GetTrackable())); + }).ToArray()); + var layers = TrackableDataStructure.wrap_or_unwrap(list_all_layers(layer).Select(x => x.GetTrackable()).ToArray()); Dictionary res = new(); Debug.Assert(variables is Trackable); @@ -158,6 +158,8 @@ public partial class KerasSavedModelUtils /// public static IDictionary wrap_layer_functions(Layer layer, IDictionary> serialization_cache) { + + // high priority // TODO: deal with type `RevivedLayer` and `Sequential`. // skip the process because of lack of APIs of `Layer`. diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs index db3b782e..d7df6eb2 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs @@ -121,7 +121,10 @@ namespace Tensorflow.Keras.Saving.SavedModel } else { - throw new ValueError($"Function {key} missing from serialized function dict."); + // high priority + // TODO(Rinne): complete the implementation. + continue; + //throw new ValueError($"Function {key} missing from serialized function dict."); } } return Functions; @@ -151,7 +154,10 @@ namespace Tensorflow.Keras.Saving.SavedModel } else { - throw new ValueError($"Object {key} missing from serialized object dict."); + // high priority. + // TODO(Rinne): Add the implementation. + continue; + //throw new ValueError($"Object {key} missing from serialized object dict."); } } return CheckpointableObjects; diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs index 636b424f..0eee6904 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs @@ -13,12 +13,12 @@ using Tensorflow.Keras; namespace TensorFlowNET.Keras.UnitTest { [TestClass] - public class EarltstoppingTest + public class EarlystoppingTest { [TestMethod] // Because loading the weight variable into the model has not yet been implemented, // so you'd better not set patience too large, because the weights will equal to the last epoch's weights. - public void Earltstopping() + public void Earlystopping() { var layers = keras.layers; var model = keras.Sequential(new List