From 9ff09c4f3d317901e4c984497bffe6bbbfc08114 Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Sun, 17 Jan 2021 10:56:16 -0600
Subject: [PATCH] skip layer without trainable weights when save_weights.

---
 .../Operations/NnOps/gen_nn_ops.cs            | 24 ++++-----
 .../Operations/nn_impl.py.cs                  |  9 ++--
 .../Variables/BaseResourceVariable.cs         | 11 ++--
 src/TensorFlowNET.Keras/Saving/hdf5_format.cs | 54 +++++++++----------
 4 files changed, 47 insertions(+), 51 deletions(-)

diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index e641ea86..e2815f81 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -318,16 +318,16 @@ namespace Tensorflow.Operations
             return _op.outputs;
         }
 
-        public static Tensor[] fused_batch_norm_v3(Tensor x,
-            Tensor scale,
-            Tensor offset,
-            IVariableV1 mean,
-            IVariableV1 variance,
-            float epsilon = 0.0001f,
-            float exponential_avg_factor = 1.0f,
-            string data_format = "NHWC",
-            bool is_training = true,
-            string name = null)
+        public static Tensors fused_batch_norm_v3(Tensor x,
+            IVariableV1 scale,
+            IVariableV1 offset,
+            IVariableV1 mean,
+            IVariableV1 variance,
+            float epsilon = 0.0001f,
+            float exponential_avg_factor = 1.0f,
+            string data_format = "NHWC",
+            bool is_training = true,
+            string name = null)
         {
             if (tf.executing_eagerly())
             {
@@ -337,8 +337,8 @@ namespace Tensorflow.Operations
                     x,
                     scale,
                     offset,
-                    mean.AsTensor(),
-                    variance.AsTensor(),
+                    mean,
+                    variance,
                     "epsilon", epsilon,
                     "exponential_avg_factor", exponential_avg_factor,
                     "data_format", data_format,
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index 7b008e4e..1da2c252 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -107,9 +107,6 @@ namespace Tensorflow
             string name = null,
             float exponential_avg_factor = 1.0f)
         {
-            x = ops.convert_to_tensor(x, name: "input");
-            var scale_tensor = ops.convert_to_tensor(scale, name: "scale");
-            var offset_tensor = ops.convert_to_tensor(offset, name: "offset");
             /*if (mean == null)
                 mean = constant_op.constant(new float[0]);
             if (variance == null)
@@ -118,11 +115,11 @@ namespace Tensorflow
             epsilon = epsilon > min_epsilon ?
                 epsilon : min_epsilon;
             var results = gen_nn_ops.fused_batch_norm_v3(x,
-                scale_tensor,
-                offset_tensor,
+                scale,
+                offset,
                 mean,
                 variance,
-                epsilon,
+                epsilon: epsilon,
                 exponential_avg_factor: exponential_avg_factor,
                 data_format: data_format,
                 is_training: is_training,
diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
index 4a30e060..41408335 100644
--- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
@@ -163,11 +163,14 @@ namespace Tensorflow
        /// <summary>
        /// </summary>
         protected Tensor read_value()
-            => tf_with(ops.name_scope("Read"), delegate
-            {
-                var value = _read_variable_op();
-                return array_ops.identity(value);
+        {
+            var value = tf_with(ops.name_scope("Read"), delegate
+            {
+                return _read_variable_op();
             });
+            return array_ops.identity(value);
+        }
+
 
         public Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
         {
diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
index f0a3d4b1..c87da45b 100644
--- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
+++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
@@ -101,26 +101,28 @@ namespace Tensorflow.Keras.Saving
                 if (success)
                     original_backend = attr.First();
             }
-            List<ILayer> filtered_layers = new List<ILayer>();
-            List<IVariableV1> weights;
+
+            var filtered_layers = new List<ILayer>();
             foreach (var layer in layers)
             {
-                weights = _legacy_weights(layer);
+                var weights = _legacy_weights(layer);
                 if (weights.Count > 0)
-                {
                     filtered_layers.append(layer);
-                }
             }
+
             string[] layer_names = load_attributes_from_hdf5_group(f, "layer_names");
             var filtered_layer_names = new List<string>();
             foreach(var name in layer_names)
             {
+                if (!filtered_layers.Select(x => x.Name).Contains(name))
+                    continue;
                 long g = H5G.open(f, name);
                 var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
                 if (weight_names.Count() > 0)
                     filtered_layer_names.Add(name);
                 H5G.close(g);
             }
+
             layer_names = filtered_layer_names.ToArray();
             if (layer_names.Length != filtered_layers.Count())
                 throw new ValueError("You are trying to load a weight file " +
@@ -133,7 +135,6 @@
                 var weight_values = new List<NDArray>();
                 long g = H5G.open(f, name);
                 var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
-                var get_Name = "";
                 foreach (var i_ in weight_names)
                 {
                     (bool success, Array result) = Hdf5.ReadDataset<float>(g, i_);
@@ -153,6 +154,7 @@
                     $"{weight_values.Count()} elements.");
                 weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
             }
+            keras.backend.batch_set_value(weight_value_tuples);
         }
         public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
         {
@@ -175,43 +177,37 @@
             Hdf5.WriteAttribute(f, "keras_version", "2.5.0");
 
             long g = 0, crDataGroup=0;
-            List<IVariableV1> weights = new List<IVariableV1>();
-            //List<NDArray> weight_values = new List<NDArray>();
-            List<string> weight_names = new List<string>();
-            foreach (var layer in layers) {
-                weight_names = new List<string>();
-                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
-                weights = _legacy_weights(layer);
-                //weight_values= keras.backend.batch_get_value(weights);
+            foreach (var layer in layers)
+            {
+                var weights = _legacy_weights(layer);
+                if (weights.Count == 0)
+                    continue;
+
+                var weight_names = new List<string>();
+                // weight_values= keras.backend.batch_get_value(weights);
                 foreach (var weight in weights)
-                {
                     weight_names.Add(weight.Name);
-                }
+
+                g = Hdf5.CreateOrOpenGroup(f, Hdf5Utils.NormalizedName(layer.Name));
                 save_attributes_to_hdf5_group(g, "weight_names", weight_names.ToArray());
-                Tensor tensor = null;
-                foreach (var (name, val) in zip(weight_names, weights)) {
-
-                    tensor = val.AsTensor();
+                foreach (var (name, val) in zip(weight_names, weights))
+                {
+                    var tensor = val.AsTensor();
                     if (name.IndexOf("/") > 1)
                     {
                         crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
                         WriteDataset(crDataGroup, name.Split('/')[1], tensor);
                         Hdf5.CloseGroup(crDataGroup);
                     }
-                    else {
+                    else
+                    {
                         WriteDataset(crDataGroup, name, tensor);
                     }
-
-                    tensor = null;
-                }
+                }
                 Hdf5.CloseGroup(g);
-                weight_names = null;
             }
-            weights = null;
-            // weight_values = null;
-
-
         }
+
         private static void save_attributes_to_hdf5_group(long f,string name ,Array data)
         {
             int num_chunks = 1;
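
The heart of this patch is one filter applied symmetrically on both paths: save_weights writes an HDF5 group only for layers that actually own weights, and loading skips any stored layer name that fails the same test, so the two lists line up pair-wise and the "You are trying to load a weight file ..." ValueError is no longer raised for models containing weightless layers (Flatten, pooling, Dropout). A minimal, self-contained C# sketch of that filtering rule follows; the Layer type and weight names below are simplified stand-ins for illustration, not the library's actual ILayer/IVariableV1 API:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    // Stand-in for a Keras layer; the real code asks _legacy_weights(layer)
    // for the layer's trainable and non-trainable variables.
    class Layer
    {
        public string Name;
        public List<string> Weights = new List<string>();
    }

    class FilterSketch
    {
        static void Main()
        {
            var layers = new List<Layer>
            {
                new Layer { Name = "dense", Weights = { "dense/kernel:0", "dense/bias:0" } },
                new Layer { Name = "flatten" },  // no weights -> no HDF5 group on save
            };

            // Save side: only layers that actually own weights are written out.
            var filtered_layers = layers.Where(l => l.Weights.Count > 0).ToList();

            // Load side: names read back from the file pass the same filter,
            // so both sequences pair up one-to-one.
            var stored_names = new[] { "dense", "flatten" };
            var filtered_names = stored_names
                .Where(n => filtered_layers.Any(l => l.Name == n))
                .ToArray();

            Console.WriteLine(filtered_names.Length == filtered_layers.Count); // True
        }
    }

In the real implementation the count check guards the zip(symbolic_weights, weight_values) pairing, which is why the filter must run on both the save and the load side.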