From 48d96f4afc407b75032e4b9bbf032a9ae080d51e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 23 Jan 2021 16:29:03 -0600 Subject: [PATCH] return weights for load_weights. --- .../Eager/EagerTensor.Creation.cs | 2 +- .../Operations/gen_resource_variable_ops.cs | 2 +- .../Sessions/BaseSession.cs | 178 +----------------- .../Tensorflow.Binding.csproj | 2 +- .../Tensors/Tensor.Creation.cs | 8 + .../Tensors/Tensor.Value.cs | 21 +-- .../Variables/BaseResourceVariable.cs | 3 + .../Engine/Model.Training.cs | 15 +- src/TensorFlowNET.Keras/Saving/hdf5_format.cs | 16 +- 9 files changed, 40 insertions(+), 207 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 9c423dce..4c550d89 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -7,7 +7,7 @@ namespace Tensorflow.Eager { public partial class EagerTensor { - public EagerTensor(SafeTensorHandleHandle handle) : base(IntPtr.Zero) + public EagerTensor(SafeTensorHandleHandle handle) { _id = ops.uid(); EagerTensorHandle = handle; diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index f190c2b4..a59dda67 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -63,7 +63,7 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "AssignVariableOp", name, null, resource, value); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index dba2b4f7..bfbe028c 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -265,182 +265,8 
@@ namespace Tensorflow private static unsafe NDArray fetchValue(IntPtr output) { - NDArray ret; - using (var tensor = new Tensor(output)) - { - var ndims = tensor.shape; - var srcAddress = c_api.TF_TensorData(output).ToInt64(); - - if (ndims.Length == 0) - { - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - ret = NDArray.Scalar(*(bool*)srcAddress); - break; - case TF_DataType.TF_STRING: - using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) - ret = new NDArray(reader.ReadBytes().ToByteArray()); - break; - case TF_DataType.TF_UINT8: - ret = NDArray.Scalar(*(byte*)srcAddress); - break; - case TF_DataType.TF_INT16: - ret = NDArray.Scalar(*(short*)srcAddress); - break; - case TF_DataType.TF_INT32: - ret = NDArray.Scalar(*(int*)srcAddress); - break; - case TF_DataType.TF_INT64: - ret = NDArray.Scalar(*(long*)srcAddress); - break; - case TF_DataType.TF_UINT16: - ret = NDArray.Scalar(*(ushort*)srcAddress); - break; - case TF_DataType.TF_UINT32: - ret = NDArray.Scalar(*(uint*)srcAddress); - break; - case TF_DataType.TF_UINT64: - ret = NDArray.Scalar(*(ulong*)srcAddress); - break; - case TF_DataType.TF_FLOAT: - ret = NDArray.Scalar(*(float*)srcAddress); - break; - case TF_DataType.TF_DOUBLE: - ret = NDArray.Scalar(*(double*)srcAddress); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - else - { - //var size = (long) tensor.size; - //var itemsize = (long) tensor.itemsize; - var bytesize = (long)tensor.bytesize; - var src = (void*)srcAddress; - -#if _REGEN - #region Compute - switch (tensor.dtype) - { - %foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")% - case TF_DataType.#3: - { - ret = new NDArray(NPTypeCode.#1, ndims, false); - System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize); - break; - } - % - case TF_DataType.TF_STRING: - 
{ - //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString - using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) - ret = NDArray.FromString(reader.ReadString()); - break; - } - default: - throw new NotSupportedException(); - } - #endregion -#else - - #region Compute - - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - { - ret = new NDArray(NPTypeCode.Boolean, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT8: - { - ret = new NDArray(NPTypeCode.Byte, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_INT16: - { - ret = new NDArray(NPTypeCode.Int16, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT16: - { - ret = new NDArray(NPTypeCode.UInt16, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_INT32: - { - ret = new NDArray(NPTypeCode.Int32, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT32: - { - ret = new NDArray(NPTypeCode.UInt32, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_INT64: - { - ret = new NDArray(NPTypeCode.Int64, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_UINT64: - { - ret = new NDArray(NPTypeCode.UInt64, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_DOUBLE: - { - ret = new NDArray(NPTypeCode.Double, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_FLOAT: - { - ret = new 
NDArray(NPTypeCode.Single, ndims, false); - System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); - break; - } - - case TF_DataType.TF_STRING: - { - throw new NotImplementedException(); - //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString -#pragma warning disable CS0162 // Unreachable code detected - using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) -#pragma warning restore CS0162 // Unreachable code detected - ret = NDArray.FromString(reader.ReadString()); - break; - } - - default: - throw new NotSupportedException(); - } - - #endregion - -#endif - } - } - - return ret; + var tensor = new Tensor(output); + return tensor.numpy(); } /// diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 3bdbd08c..e3d7f6ae 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -82,7 +82,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works - + diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index d834de91..a2838925 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -48,6 +48,11 @@ namespace Tensorflow public IntPtr TensorDataPointer => _handle == IntPtr.Zero ? 
IntPtr.Zero : TF_TensorData(_handle); + public Tensor() + { + + } + /// /// Create a Tensor object from an existing TF handle /// @@ -56,6 +61,9 @@ namespace Tensorflow { _handle = handle; //no need to set AllocationType = AllocationType.None; +#if TRACK_TENSOR_LIFE + print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}"); +#endif } public Tensor(int value) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index 0ca4d269..ed9e67e6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -163,6 +163,9 @@ namespace Tensorflow break; case TF_DataType.TF_STRING: return np.array(StringBytes()[0]); + case TF_DataType.TF_UINT8: + storage = new UnmanagedStorage(NPTypeCode.Byte); + break; case TF_DataType.TF_INT32: storage = new UnmanagedStorage(NPTypeCode.Int32); break; @@ -186,23 +189,6 @@ namespace Tensorflow return new NDArray(storage); } - /*protected unsafe NDArray GetScalar(TF_DataType dtype) - { - switch(dtype) - { - case TF_DataType.TF_STRING: - return (NDArray)StringData()[0]; - case TF_DataType.TF_INT32: - return *(int*)buffer; - case TF_DataType.TF_FLOAT: - return *(float*)buffer; - case TF_DataType.TF_DOUBLE: - return *(double*)buffer; - default: - return BufferToArray(); - } - }*/ - /// /// Copies the memory of current buffer onto newly allocated array. 
/// @@ -210,7 +196,6 @@ namespace Tensorflow public unsafe byte[] BufferToArray() { // ReSharper disable once LocalVariableHidesMember - var bytesize = (long)this.bytesize; var data = new byte[bytesize]; fixed (byte* dst = data) System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index c5e991e5..1fc23841 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -100,6 +100,9 @@ namespace Tensorflow if (read_value) return gen_resource_variable_ops.read_variable_op(handle, dtype); + if (assign_op == null) + return null; + return assign_op; } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs index 91c4c0f6..6bf0eed9 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs @@ -3,13 +3,14 @@ using System.Collections.Generic; using System.Text; using HDF.PInvoke; using HDF5CSharp; +using NumSharp; using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { public partial class Model { - public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) + public List<(IVariableV1, NDArray)> load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) { long fileId = Hdf5.OpenFile(filepath, true); @@ -17,20 +18,20 @@ namespace Tensorflow.Keras.Engine bool lsuccess = Hdf5.GroupExists(fileId, "layer_names"); if (!lsuccess && msuccess) - { fileId = H5G.open(fileId, "model_weights"); - } + if (by_name) - { //fdf5_format.load_weights_from_hdf5_group_by_name(); throw new NotImplementedException(""); - } else { - hdf5_format.load_weights_from_hdf5_group(fileId, Layers); + var weights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); 
+ Hdf5.CloseFile(fileId); + // Return the weights so the caller keeps a reference, preventing the GC from collecting the variables. + return weights; } - Hdf5.CloseFile(fileId); } + public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null) { long fileId = Hdf5.CreateFile(filepath); diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs index c87da45b..33e07441 100644 --- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -71,18 +71,20 @@ namespace Tensorflow.Keras.Saving var target_class = layer.GetType().Name; return weights; } + public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) { } + public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) { } - public static void load_weights_from_hdf5_group(long f, List layers) + public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List layers) { - string original_keras_version = "2.4.0"; + string original_keras_version = "2.5.0"; string original_backend = null; if (Hdf5.AttributeExists(f, "keras_version")) { @@ -156,15 +158,19 @@ namespace Tensorflow.Keras.Saving } keras.backend.batch_set_value(weight_value_tuples); + return weight_value_tuples; } + public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false) { } + public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dictionary custom_objects = null, bool compile = false) { } + public static void save_weights_to_hdf5_group(long f, List layers) { List layerName=new List(); @@ -260,8 +266,8 @@ namespace Tensorflow.Keras.Saving WriteAttrs(f, getType,name, data); } - } + private static void WriteDataset(long f, string name, Tensor data) { switch (data.dtype) @@ -283,6 +289,7 @@ namespace Tensorflow.Keras.Saving break; } } + 
private static void WriteAttrs(long f,string typename, string name, Array data) { switch (typename) @@ -307,6 +314,7 @@ namespace Tensorflow.Keras.Saving break; } } + private static List> Split(Array list, int chunkSize) { var splitList = new List>(); @@ -327,6 +335,7 @@ namespace Tensorflow.Keras.Saving return splitList; } + public static string[] load_attributes_from_hdf5_group(long group, string name) { if (Hdf5.AttributeExists(group, name)) @@ -337,6 +346,7 @@ namespace Tensorflow.Keras.Saving } return null; } + public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) {