From 94751b1acd341308444b72dd132297673ad7f989 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 2 Mar 2023 11:28:27 +0800 Subject: [PATCH] Fix the duplicated weights in Keras.Model. --- .../Common/CustomizedDTypeJsonConverter.cs | 2 +- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 1 + .../Operations/NnOps/RNNCell.cs | 1 + .../Engine/Layer.Layers.cs | 26 ++++++ src/TensorFlowNET.Keras/Engine/Layer.cs | 91 +++++++++++-------- src/TensorFlowNET.Keras/Engine/Model.cs | 37 ++++++-- src/TensorFlowNET.Keras/Metrics/Metric.cs | 2 +- .../Saving/SavedModel/Save.cs | 2 +- src/TensorFlowNET.Keras/Utils/layer_utils.cs | 2 +- .../SaveModel/SequentialModelLoad.cs | 5 +- 10 files changed, 116 insertions(+), 53 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs index e9086ae9..110f6b25 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs @@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Common } else { - return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType)); + return (TF_DataType)serializer.Deserialize(reader, typeof(int)); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 03629107..20a98e3d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -19,6 +19,7 @@ namespace Tensorflow.Keras List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } + List Weights { get; } Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 2b83dd1d..4e9369a8 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -71,6 +71,7 @@ namespace Tensorflow public List TrainableVariables => throw new NotImplementedException(); public List TrainableWeights => throw new NotImplementedException(); + public List Weights => throw new NotImplementedException(); public List NonTrainableWeights => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs index a2d212cb..81fc2635 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; namespace Tensorflow.Keras.Engine { @@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine public virtual Shape ComputeOutputShape(Shape input_shape) => throw new NotImplementedException(""); + + protected List _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) + { + List res = new(); + var nested_layers = _flatten_layers(false, false); + foreach (var layer in nested_layers) + { + if (layer is Layer l) + { + if (include_trainable == true && include_non_trainable == true) + { + res.AddRange(l.Variables); + } + else if (include_trainable == true && include_non_trainable == false) + { + res.AddRange(l.TrainableVariables); + } + else if(include_trainable == false && include_non_trainable == true) + { + res.AddRange(l.NonTrainableVariables); + } + } + } + return res; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index e54b939f..3934950b 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -67,10 +67,58 @@ namespace Tensorflow.Keras.Engine public bool SupportsMasking { get; set; } protected List _trainable_weights; - public virtual List TrainableVariables => _trainable_weights; + public virtual List TrainableVariables => TrainableWeights; protected List _non_trainable_weights; - public List non_trainable_variables => _non_trainable_weights; + public List NonTrainableVariables => NonTrainableWeights; + public List Variables => Weights; + + public virtual List TrainableWeights + { + get + { + if (!this.Trainable) + { + return new List(); + } + var children_weights = _gather_children_variables(true); + return children_weights.Concat(_trainable_weights).Distinct().ToList(); + } + } + + public virtual List NonTrainableWeights + { + get + { + if (!this.Trainable) + { + var children_weights = _gather_children_variables(true, true); + return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); + } + else + { + var children_weights = _gather_children_variables(include_non_trainable: true); + return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); + } + } + } + + public virtual List Weights + { + get + { + return TrainableWeights.Concat(NonTrainableWeights).ToList(); + } + set + { + if (Weights.Count() != value.Count()) throw new ValueError( + $"You called `set_weights` on layer \"{this.name}\"" + + $"with a weight list of length {len(value)}, but the layer was " + + $"expecting {len(Weights)} weights."); + foreach (var (this_w, v_w) in zip(Weights, value)) + this_w.assign(v_w, read_value: true); + } + } protected int id; public int Id => id; @@ -290,46 +338,9 @@ namespace Tensorflow.Keras.Engine public int count_params() { if (Trainable) - return layer_utils.count_params(this, weights); + return layer_utils.count_params(this, Weights); return 0; } - List ILayer.TrainableWeights - { - get - { - return _trainable_weights; - } - } - - List ILayer.NonTrainableWeights - { - get - { - return _non_trainable_weights; - } - } - - public List weights - { - get - { - var weights = new List(); - weights.AddRange(_trainable_weights); - weights.AddRange(_non_trainable_weights); - return weights; - } - set - { - if (weights.Count() != value.Count()) throw new ValueError( - $"You called `set_weights` on layer \"{this.name}\"" + - $"with a weight list of length {len(value)}, but the layer was " + - $"expecting {len(weights)} weights."); - foreach (var (this_w, v_w) in zip(weights, value)) - this_w.assign(v_w, read_value: true); - } - } - - public List Variables => weights; public virtual IKerasConfig get_config() => args; diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 2a2a3662..bbc6e829 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -89,10 +89,11 @@ namespace Tensorflow.Keras.Engine public override List Layers => _flatten_layers(recursive: false, include_self: false).ToList(); - public override List TrainableVariables + public override List TrainableWeights { get { + // skip the assertion of weights created. var variables = new List(); if (!Trainable) @@ -103,18 +104,40 @@ namespace Tensorflow.Keras.Engine foreach (var trackable_obj in _self_tracked_trackables) { if (trackable_obj.Trainable) - variables.AddRange(trackable_obj.TrainableVariables); + variables.AddRange(trackable_obj.TrainableWeights); } - foreach (var layer in _self_tracked_trackables) + variables.AddRange(_trainable_weights); + + return variables.Distinct().ToList(); + } + } + + public override List NonTrainableWeights + { + get + { + // skip the assertion of weights created. + var variables = new List(); + + foreach (var trackable_obj in _self_tracked_trackables) { - if (layer.Trainable) - variables.AddRange(layer.TrainableVariables); + variables.AddRange(trackable_obj.NonTrainableWeights); } - // variables.AddRange(_trainable_weights); + if (!Trainable) + { + var trainable_variables = new List(); + foreach (var trackable_obj in _self_tracked_trackables) + { + variables.AddRange(trackable_obj.TrainableWeights); + } + variables.AddRange(trainable_variables); + variables.AddRange(_trainable_weights); + variables.AddRange(_non_trainable_weights); + } - return variables; + return variables.Distinct().ToList(); } } diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs index 1dfc39c4..435eebd4 100644 --- a/src/TensorFlowNET.Keras/Metrics/Metric.cs +++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics public virtual void reset_states() { - foreach (var v in weights) + foreach (var v in Weights) v.assign(0); } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 60ca6332..220eae4b 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); })); - var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => + var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => { if (x is ResourceVariable or RefVariable) return (Trackable)x; else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); diff --git a/src/TensorFlowNET.Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Keras/Utils/layer_utils.cs index 3c38a6d1..07d9f685 100644 --- a/src/TensorFlowNET.Keras/Utils/layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/layer_utils.cs @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils } var trainable_count = count_params(model, model.TrainableVariables); - var non_trainable_count = count_params(model, model.non_trainable_variables); + var non_trainable_count = count_params(model, model.NonTrainableVariables); print($"Total params: {trainable_count + non_trainable_count}"); print($"Trainable params: {trainable_count}"); diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 672f8d09..1fe9e058 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -21,8 +21,8 @@ public class SequentialModelLoad [TestMethod] public void SimpleModelFromSequential() { - new SequentialModelSave().SimpleModelFromSequential(); - var model = keras.models.load_model(@"./pb_simple_sequential"); + //new SequentialModelSave().SimpleModelFromSequential(); + var model = keras.models.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); model.summary(); @@ -40,5 +40,6 @@ public class SequentialModelLoad }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + model.summary(); } }