From 0060039a73b6a6f596d11126e986baebf47ac161 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 2 Mar 2023 23:24:47 +0800 Subject: [PATCH] Resolve the comments and errors. --- .../Checkpoint/CheckpointReader.cs | 28 ++++---- .../Common/CustomizedDTypeJsonConverter.cs | 5 +- .../Utils/generic_utils.cs | 66 ++++++------------- 3 files changed, 38 insertions(+), 61 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs index c6896ad7..2a8e2382 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -4,26 +4,26 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; +using Tensorflow.Util; namespace Tensorflow.Checkpoint { - public class CheckpointReader : IDisposable + public class CheckpointReader : SafeTensorflowHandle { - private IntPtr _reader; public Dictionary VariableToDataTypeMap { get; set; } public Dictionary VariableToShapeMap { get; set; } public CheckpointReader(string filename) { Status status = new Status(); - _reader = c_api.TF_NewCheckpointReader(filename, status.Handle); + handle = c_api.TF_NewCheckpointReader(filename, status.Handle); status.Check(true); ReadAllShapeAndType(); } public int HasTensor(string name) { - return c_api.TF_CheckpointReaderHasTensor(_reader, name); + return c_api.TF_CheckpointReaderHasTensor(handle, name); } /// @@ -33,17 +33,17 @@ namespace Tensorflow.Checkpoint /// public string GetVariable(int index) { - return c_api.TF_CheckpointReaderGetVariable(_reader, index); + return c_api.TF_CheckpointReaderGetVariable(handle, index); } public int Size() { - return c_api.TF_CheckpointReaderSize(_reader); + return c_api.TF_CheckpointReaderSize(handle); } public TF_DataType GetVariableDataType(string name) { - return c_api.TF_CheckpointReaderGetVariableDataType(_reader, name); + return c_api.TF_CheckpointReaderGetVariableDataType(handle, name); } public Shape GetVariableShape(string name) @@ -52,20 +52,20 @@ namespace Tensorflow.Checkpoint int num_dims = GetVariableNumDims(name); long[] dims = new long[num_dims]; Status status = new Status(); - c_api.TF_CheckpointReaderGetVariableShape(_reader, name, dims, num_dims, status.Handle); + c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle); status.Check(true); return new Shape(dims); } public int GetVariableNumDims(string name) { - return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); + return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name); } public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) { Status status = new Status(); - var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); + var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle); status.Check(true); var shape = GetVariableShape(name); if(dtype == TF_DataType.DtInvalid) @@ -90,9 +90,15 @@ namespace Tensorflow.Checkpoint } } + protected override bool ReleaseHandle() + { + c_api.TF_DeleteCheckpointReader(handle); + return true; + } + public void Dispose() { - c_api.TF_DeleteCheckpointReader(_reader); + c_api.TF_DeleteCheckpointReader(handle); } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs index 110f6b25..fce7bec5 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs @@ -1,8 +1,5 @@ using Newtonsoft.Json.Linq; using Newtonsoft.Json; -using System; -using System.Collections.Generic; -using System.Text; namespace Tensorflow.Keras.Common { @@ -19,7 +16,7 @@ namespace Tensorflow.Keras.Common public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) { - var token = JToken.FromObject(value); + var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value)); token.WriteTo(writer); } diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 216df0ef..03acce0c 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -22,6 +22,7 @@ using System.Collections.Generic; using System.Data; using System.Diagnostics; using System.Linq; +using System.Reflection; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; @@ -58,59 +59,32 @@ namespace Tensorflow.Keras.Utils public static Layer deserialize_keras_object(string class_name, JToken config) { - return class_name switch - { - "Sequential" => new Sequential(config.ToObject()), - "InputLayer" => new InputLayer(config.ToObject()), - "Flatten" => new Flatten(config.ToObject()), - "ELU" => new ELU(config.ToObject()), - "Dense" => new Dense(config.ToObject()), - "Softmax" => new Softmax(config.ToObject()), - "Conv2D" => new Conv2D(config.ToObject()), - "BatchNormalization" => new BatchNormalization(config.ToObject()), - "MaxPooling2D" => new MaxPooling2D(config.ToObject()), - "Dropout" => new Dropout(config.ToObject()), - _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + - $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") - }; + var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) + .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); + var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); + var args = deserializationGenericMethod.Invoke(config, null); + var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + Debug.Assert(layer is Layer); + return layer as Layer; } public static Layer deserialize_keras_object(string class_name, LayerArgs args) { - return class_name switch - { - "Sequential" => new Sequential(args as SequentialArgs), - "InputLayer" => new InputLayer(args as InputLayerArgs), - "Flatten" => new Flatten(args as FlattenArgs), - "ELU" => new ELU(args as ELUArgs), - "Dense" => new Dense(args as DenseArgs), - "Softmax" => new Softmax(args as SoftmaxArgs), - "Conv2D" => new Conv2D(args as Conv2DArgs), - "BatchNormalization" => new BatchNormalization(args as BatchNormalizationArgs), - "MaxPooling2D" => new MaxPooling2D(args as MaxPooling2DArgs), - "Dropout" => new Dropout(args as DropoutArgs), - _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + - $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") - }; + var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null); + Debug.Assert(layer is Layer); + return layer as Layer; } - public static LayerArgs? deserialize_layer_args(string class_name, JToken config) + public static LayerArgs deserialize_layer_args(string class_name, JToken config) { - return class_name switch - { - "Sequential" => config.ToObject(), - "InputLayer" => config.ToObject(), - "Flatten" => config.ToObject(), - "ELU" => config.ToObject(), - "Dense" => config.ToObject(), - "Softmax" => config.ToObject(), - "Conv2D" => config.ToObject(), - "BatchNormalization" => config.ToObject(), - "MaxPooling2D" => config.ToObject(), - "Dropout" => config.ToObject(), - _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + - $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") - }; + var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args"); + var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public) + .Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0); + var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType); + var args = deserializationGenericMethod.Invoke(config, null); + Debug.Assert(args is LayerArgs); + return args as LayerArgs; } public static ModelConfig deserialize_model_config(JToken json)