From f1e7795c4ade5ba453f9ef97253763a5ad2e270c Mon Sep 17 00:00:00 2001 From: dataangel Date: Wed, 6 Jan 2021 19:50:48 +0800 Subject: [PATCH] update:Keras --- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 2 + .../Operations/NnOps/RNNCell.cs | 2 + src/TensorFlowNET.Keras/Engine/Layer.cs | 15 ++ .../Engine/Model.Training.cs | 56 ++++++ src/TensorFlowNET.Keras/Saving/fdf5_format.cs | 179 ++++++++++++++++++ .../Tensorflow.Keras.csproj | 9 +- 6 files changed, 258 insertions(+), 5 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Engine/Model.Training.cs create mode 100644 src/TensorFlowNET.Keras/Saving/fdf5_format.cs diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index d6bbf11a..4a48ba79 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -13,6 +13,8 @@ namespace Tensorflow.Keras List OutboundNodes { get; } Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false); List trainable_variables { get; } + List trainable_weights { get; } + List non_trainable_weights { get; } TensorShape output_shape { get; } int count_params(); LayerArgs get_config(); diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index b2e1566b..edd81133 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -67,6 +67,8 @@ namespace Tensorflow public bool Trainable => throw new NotImplementedException(); public List trainable_variables => throw new NotImplementedException(); + public List trainable_weights => throw new NotImplementedException(); + public List non_trainable_weights => throw new NotImplementedException(); public TensorShape output_shape => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 8daf60a2..3e78f422 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -239,6 +239,21 @@ namespace Tensorflow.Keras.Engine return layer_utils.count_params(this, weights); return 0; } + List ILayer.trainable_weights + { + get + { + return trainable_weights; + } + } + + List ILayer.non_trainable_weights + { + get + { + return non_trainable_weights; + } + } public List weights { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs new file mode 100644 index 00000000..4be5cc0d --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs @@ -0,0 +1,56 @@ +using System; +using System.Collections.Generic; +using System.Text; +using HDF.PInvoke; +using HDF5CSharp; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + private long fileId = -1; + private long f = -1; + public void load_weights(string filepath ="",bool by_name= false, bool skip_mismatch=false, object options = null) + { + long root = Hdf5.OpenFile(filepath, true); + + long fileId = root; + //try + //{ + + bool msuccess = Hdf5.GroupExists(fileId, "model_weights"); + bool lsuccess = Hdf5.GroupExists(fileId, "layer_names"); + + if (!lsuccess && msuccess) + { + f = H5G.open(fileId, "model_weights"); + + } + if (by_name) + { + //fdf5_format.load_weights_from_hdf5_group_by_name(); + } + else + { + fdf5_format.load_weights_from_hdf5_group(f, this); + } + H5G.close(f); + //} + //catch (Exception ex) + //{ + // if (fileId != -1) + // { + // Hdf5.CloseFile(fileId); + // } + // if (f != -1) + // { + // H5G.close(f); + // } + // throw new Exception(ex.ToString()); + //} + } + + } +} + diff --git a/src/TensorFlowNET.Keras/Saving/fdf5_format.cs b/src/TensorFlowNET.Keras/Saving/fdf5_format.cs new file mode 100644 index 00000000..3a9d2438 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/fdf5_format.cs @@ -0,0 +1,179 @@ +using System; +using System.Collections.Generic; +using System.Text; +using HDF.PInvoke; +using NumSharp; +using Tensorflow.Keras.Engine; +using HDF5CSharp; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Saving +{ + public class fdf5_format + { + + public static void load_model_from_hdf5(string filepath = "", Dictionary custom_objects = null, bool compile = false) + { + long root = Hdf5.OpenFile(filepath,true); + load_model_from_hdf5(root, custom_objects, compile); + } + public static void load_model_from_hdf5(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + //long fileId = filepath; + //try + //{ + // groupId = H5G.open(fileId, "/"); + // (bool success, string[] attrId) = Hdf5.ReadStringAttributes(groupId, "model_config", ""); + // H5G.close(groupId); + // if (success == true) { + // Console.WriteLine(attrId[0]); + // } + //} + //catch (Exception ex) + //{ + // if (filepath != -1) { + // Hdf5.CloseFile(filepath); + // } + // if (groupId != -1) { + // H5G.close(groupId); + // } + // throw new Exception(ex.ToString()); + //} + + } + public static void save_model_to_hdf5(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void preprocess_weights_for_loading(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void _convert_rnn_weights(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void save_optimizer_weights_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void save_weights_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void load_weights_from_hdf5_group(long f=-1,Model model=null) + { + string original_keras_version = "1"; + string original_backend = null; + if (Hdf5.AttributeExists(f, "keras_version")) + { + (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "keras_version", ""); + if (success) + { + original_keras_version = attr[0]; + } + } + if (Hdf5.AttributeExists(f, "backend")) + { + (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, "backend", ""); + if (success) + { + original_backend = attr[0]; + } + } + List filtered_layers = new List(); + List weights; + foreach (var layer in model.Layers) + { + weights = _legacy_weights(layer); + if (weights.Count>0) + { + filtered_layers.append(layer); + } + } + string[] layer_names = load_attributes_from_hdf5_group(f,"layer_names"); + List weight_values=new List(); + foreach (var i in filtered_layers) { + long g = H5G.open(f, i.Name); + string[] weight_names = null; + if (g != -1) + { + weight_names = load_attributes_from_hdf5_group(g, "weight_names"); + } + if (weight_names != null) + { + foreach (var i_ in weight_names) { + (bool success, Array result) = Hdf5.ReadDataset(g, i_); + // + weight_values.Add(np.array(result)); + } + } + H5G.close(g); + } + + } + public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void load_weights_from_hdf5_group_by_name(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static void save_attributes_to_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + public static string[] load_attributes_from_hdf5_group(long f = -1, string name = "") + { + if (Hdf5.AttributeExists(f, name)) + { + (bool success, string[] attr) = Hdf5.ReadStringAttributes(f, name, ""); + if (success) + { + return attr; + } + } + return null; + } + public static void load_attributes_from_hdf5_group(long filepath = -1, Dictionary custom_objects = null, bool compile = false) + { + + } + + public static List _legacy_weights(ILayer layer) + { + + List weights= new List(); + if (layer.trainable_weights.Count != 0) + { + Tensor[] trainable_weights = Array.ConvertAll(layer.trainable_weights.ToArray(), s => s.AsTensor()); + Tensor[] non_trainable_weights =null; + if (layer.non_trainable_weights.Count != 0) + { + non_trainable_weights = Array.ConvertAll(layer.non_trainable_weights.ToArray(), s => s.AsTensor()); + } + foreach (var i in trainable_weights) { + if (non_trainable_weights != null) + { + foreach (var i_ in non_trainable_weights) + { + weights.Add(i + i_); + } + } + else { + weights.Add(i); + }; + + + } + } + return weights; + } + } +} + diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 14c5719d..81e24e7e 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -24,7 +24,7 @@ Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. SciSharp STACK true - tensorflow, keras, deep learning, machine learning, scisharp + tensorflow, keras, deep learning, machine learning true Git true @@ -44,16 +44,15 @@ Keras is an API designed for human beings, not machines. Keras follows best prac + + + - - - -