From 2f3a17d4482f54c23a1cd62ca6c8dd4b07884771 Mon Sep 17 00:00:00 2001
From: dataangel
Date: Sun, 17 Jan 2021 09:11:39 +0800
Subject: [PATCH] update: keras.save_weights and keras.load_weights

---
 .../Engine/Model.Training.cs                  |  10 +-
 src/TensorFlowNET.Keras/Losses/LogCosh.cs     |   2 +-
 .../Saving/{fdf5_format.cs => hdf5_format.cs} | 180 +++++++++++++++++-
 3 files changed, 181 insertions(+), 11 deletions(-)
 rename src/TensorFlowNET.Keras/Saving/{fdf5_format.cs => hdf5_format.cs} (54%)

diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs
index 2ba215e8..91c4c0f6 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs
@@ -27,9 +27,15 @@ namespace Tensorflow.Keras.Engine
             }
             else
             {
-                fdf5_format.load_weights_from_hdf5_group(fileId, Layers);
+                hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
             }
-            H5G.close(fileId);
+            Hdf5.CloseFile(fileId);
+        }
+
+        public void save_weights(string filepath, bool overwrite = true, string save_format = null, object options = null)
+        {
+            long fileId = Hdf5.CreateFile(filepath);
+            hdf5_format.save_weights_to_hdf5_group(fileId, Layers);
+            Hdf5.CloseFile(fileId);
         }
     }
 }
diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs
index 6db10bc8..1c894904 100644
--- a/src/TensorFlowNET.Keras/Losses/LogCosh.cs
+++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Losses
         public LogCosh(
             string reduction = null,
             string name = null) :
-            base(reduction: reduction, name: name == null ? "huber" : name){ }
+            base(reduction: reduction, name: name == null ? "log_cosh" : name){ }
 
         public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
         {
diff --git a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
similarity index 54%
rename from src/TensorFlowNET.Keras/Saving/fdf5_format.cs
rename to src/TensorFlowNET.Keras/Saving/hdf5_format.cs
index a2a9e537..04990c55 100644
--- a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs
+++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
@@ -8,12 +8,12 @@ using HDF5CSharp;
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
 using System.Linq;
-
+using Tensorflow.Util;
 namespace Tensorflow.Keras.Saving
 {
-    public class fdf5_format
+    public class hdf5_format
     {
-
+        private static int HDF5_OBJECT_HEADER_LIMIT = 64512;
         public static void load_model_from_hdf5(string filepath = "", Dictionary<string, object> custom_objects = null, bool compile = false)
         {
             long root = Hdf5.OpenFile(filepath, true);
@@ -79,10 +79,7 @@ namespace Tensorflow.Keras.Saving
         {
 
         }
-        public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
-        {
-
-        }
+
         public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
         {
             string original_keras_version = "2.4.0";
@@ -136,9 +133,14 @@ namespace Tensorflow.Keras.Saving
                 var weight_values = new List<NDArray>();
                 long g = H5G.open(f, name);
                 var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
+                var get_Name = "";
                 foreach (var i_ in weight_names)
                 {
-                    (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
+                    get_Name = i_;
+                    if (get_Name.IndexOf("/") > 1) {
+                        get_Name = get_Name.Split('/')[1];
+                    }
+                    (bool success, Array result) = Hdf5.ReadDataset<float>(g, get_Name);
                     if (success)
                         weight_values.Add(np.array(result));
                 }
@@ -165,9 +167,171 @@ namespace Tensorflow.Keras.Saving
         {
 
         }
-        public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
+        public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
+        {
+            List<string> layerName = new List<string>();
+            foreach (var layer in layers)
+            {
+                layerName.Add(layer.Name);
+            }
+            save_attributes_to_hdf5_group(f, "layer_names", layerName.ToArray());
+            Hdf5.WriteAttribute(f, "backend", "tensorflow");
+            Hdf5.WriteAttribute(f, "keras_version", "2.5.0");
+
+            long g = 0, crDataGroup = 0;
+            List<IVariableV1> weights = new List<IVariableV1>();
+            //List<NDArray> weight_values = new List<NDArray>();
+            List<string> weight_names = new List<string>();
+            foreach (var layer in layers)
+            {
+                weight_names = new List<string>();
+                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
+                weights = _legacy_weights(layer);
+                //weight_values = keras.backend.batch_get_value(weights);
+                foreach (var weight in weights)
+                {
+                    weight_names.Add(weight.Name);
+                }
+                save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
+                Tensor tensor = null;
+                string get_Name = "";
+                foreach (var (name, val) in zip(weight_names, weights))
+                {
+                    get_Name = name;
+                    tensor = val.AsTensor();
+                    if (get_Name.IndexOf("/") > 1)
+                    {
+                        get_Name = name.Split('/')[1];
+                        crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(get_Name));
+                        Hdf5.CloseGroup(crDataGroup);
+                    }
+                    WriteDataset(g, get_Name, tensor);
+                    tensor = null;
+                }
+                Hdf5.CloseGroup(g);
+                weight_names = null;
+            }
+            weights = null;
+            //weight_values = null;
+        }
+
+        private static void save_attributes_to_hdf5_group(long f, string name, Array data)
+        {
+            int num_chunks = 1;
+            var chunked_data = Split(data, num_chunks);
+            int getSize = 0;
+            string getType = data.Length > 0 ? data.GetValue(0).GetType().Name.ToLower() : "string";
+
+            switch (getType)
+            {
+                case "single":
+                    getSize = sizeof(float);
+                    break;
+                case "double":
+                    getSize = sizeof(double);
+                    break;
+                case "string":
+                    getSize = -1;
+                    break;
+                case "int32":
+                    getSize = sizeof(int);
+                    break;
+                case "int64":
+                    getSize = sizeof(long);
+                    break;
+                default:
+                    getSize = -1;
+                    break;
+            }
+            int getCount = chunked_data.Count;
+
+            if (getSize != -1)
+            {
+                num_chunks = (int)Math.Ceiling((double)(getCount * getSize) / (double)HDF5_OBJECT_HEADER_LIMIT);
+                if (num_chunks > 1) chunked_data = Split(data, num_chunks);
+            }
+
+            if (num_chunks > 1)
+            {
+                foreach (var (chunk_id, chunk_data) in enumerate(chunked_data))
+                {
+                    WriteAttrs(f, getType, $"{name}{chunk_id}", chunk_data.ToArray());
+                }
+            }
+            else
+            {
+                WriteAttrs(f, getType, name, data);
+            }
+        }
+
+        private static void WriteDataset(long f, string name, Tensor data)
+        {
+            switch (data.dtype)
+            {
+                case TF_DataType.TF_FLOAT:
+                    Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+                case TF_DataType.TF_DOUBLE:
+                    Hdf5.WriteDatasetFromArray<double>(f, name, data.numpy().ToMuliDimArray<double>());
+                    break;
+                case TF_DataType.TF_INT32:
+                    Hdf5.WriteDatasetFromArray<int>(f, name, data.numpy().ToMuliDimArray<int>());
+                    break;
+                case TF_DataType.TF_INT64:
+                    Hdf5.WriteDatasetFromArray<long>(f, name, data.numpy().ToMuliDimArray<long>());
+                    break;
+                default:
+                    Hdf5.WriteDatasetFromArray<float>(f, name, data.numpy().ToMuliDimArray<float>());
+                    break;
+            }
+        }
+
+        private static void WriteAttrs(long f, string typename, string name, Array data)
+        {
+            switch (typename)
+            {
+                case "single":
+                    Hdf5.WriteAttributes<float>(f, name, data);
+                    break;
+                case "double":
+                    Hdf5.WriteAttributes<double>(f, name, data);
+                    break;
+                case "string":
+                    Hdf5.WriteAttributes<string>(f, name, data);
+                    break;
+                case "int32":
+                    Hdf5.WriteAttributes<int>(f, name, data);
+                    break;
+                case "int64":
+                    Hdf5.WriteAttributes<long>(f, name, data);
+                    break;
+                default:
+                    Hdf5.WriteAttributes<string>(f, name, data);
+                    break;
+            }
+        }
+
+        private static List<List<object>> Split(Array list, int chunkSize)
+        {
+            var splitList = new List<List<object>>();
+            var chunkCount = (int)Math.Ceiling((double)list.Length / (double)chunkSize);
+
+            for (int c = 0; c < chunkCount; c++)
+            {
+                var skip = c * chunkSize;
+                var take = skip + chunkSize;
+                var chunk = new List<object>(chunkSize);
+
+                for (int e = skip; e < take && e < list.Length; e++)
+                {
+                    chunk.Add(list.GetValue(e));
+                }
+                splitList.Add(chunk);
+            }
+            return splitList;
+        }
         public static string[] load_attributes_from_hdf5_group(long group, string name)
         {