Browse Source

Fix the duplicated weights in Keras.Model.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
94751b1acd
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
10 changed files with 116 additions and 53 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  3. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  4. +26
    -0
      src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
  5. +51
    -40
      src/TensorFlowNET.Keras/Engine/Layer.cs
  6. +30
    -7
      src/TensorFlowNET.Keras/Engine/Model.cs
  7. +1
    -1
      src/TensorFlowNET.Keras/Metrics/Metric.cs
  8. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  9. +1
    -1
      src/TensorFlowNET.Keras/Utils/layer_utils.cs
  10. +3
    -2
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs View File

@@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Common
} }
else else
{ {
return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType));
return (TF_DataType)serializer.Deserialize(reader, typeof(int));
} }
} }
} }


+ 1
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -19,6 +19,7 @@ namespace Tensorflow.Keras
List<IVariableV1> TrainableVariables { get; } List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; } List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; } List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; }
Shape OutputShape { get; } Shape OutputShape { get; }
Shape BatchInputShape { get; } Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; } TensorShapeConfig BuildInputShape { get; }


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -71,6 +71,7 @@ namespace Tensorflow


public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
public List<IVariableV1> Weights => throw new NotImplementedException();
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();


public Shape OutputShape => throw new NotImplementedException(); public Shape OutputShape => throw new NotImplementedException();


+ 26
- 0
src/TensorFlowNET.Keras/Engine/Layer.Layers.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine


public virtual Shape ComputeOutputShape(Shape input_shape) public virtual Shape ComputeOutputShape(Shape input_shape)
=> throw new NotImplementedException(""); => throw new NotImplementedException("");

/// <summary>
/// Collects variables from the layers returned by <c>_flatten_layers(false, false)</c>
/// (presumably the direct children, excluding this layer itself; confirm against
/// <c>_flatten_layers</c>).
/// </summary>
/// <param name="include_trainable">Include the children's trainable variables.</param>
/// <param name="include_non_trainable">Include the children's non-trainable variables.</param>
/// <returns>
/// A new list: each child's <c>Variables</c> when both flags are set, its
/// <c>TrainableVariables</c> or <c>NonTrainableVariables</c> when only one is set,
/// and an empty list when neither is set. No de-duplication is done here.
/// </returns>
protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false)
{
    var gathered = new List<IVariableV1>();
    foreach (var child in _flatten_layers(false, false))
    {
        // Only concrete Layer instances expose the properties read below.
        if (child is not Layer sublayer)
        {
            continue;
        }

        if (include_trainable && include_non_trainable)
        {
            gathered.AddRange(sublayer.Variables);
        }
        else if (include_trainable)
        {
            gathered.AddRange(sublayer.TrainableVariables);
        }
        else if (include_non_trainable)
        {
            gathered.AddRange(sublayer.NonTrainableVariables);
        }
    }
    return gathered;
}
} }
} }

+ 51
- 40
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -67,10 +67,58 @@ namespace Tensorflow.Keras.Engine
public bool SupportsMasking { get; set; } public bool SupportsMasking { get; set; }
protected List<IVariableV1> _trainable_weights; protected List<IVariableV1> _trainable_weights;


public virtual List<IVariableV1> TrainableVariables => _trainable_weights;
public virtual List<IVariableV1> TrainableVariables => TrainableWeights;


protected List<IVariableV1> _non_trainable_weights; protected List<IVariableV1> _non_trainable_weights;
public List<IVariableV1> non_trainable_variables => _non_trainable_weights;
public List<IVariableV1> NonTrainableVariables => NonTrainableWeights;
public List<IVariableV1> Variables => Weights;

/// <summary>
/// Trainable variables of this layer and its children.
/// Returns an empty list when <c>Trainable</c> is false; otherwise the children's
/// trainable variables followed by this layer's own <c>_trainable_weights</c>,
/// with duplicates removed.
/// </summary>
public virtual List<IVariableV1> TrainableWeights
    => Trainable
        ? _gather_children_variables(include_trainable: true)
            .Concat(_trainable_weights)
            .Distinct()
            .ToList()
        : new List<IVariableV1>();

/// <summary>
/// Non-trainable variables of this layer and its children, with duplicates removed.
/// When <c>Trainable</c> is false, every variable is reported here: all of the
/// children's variables, then this layer's <c>_trainable_weights</c> and
/// <c>_non_trainable_weights</c>. Otherwise only the children's non-trainable
/// variables and this layer's <c>_non_trainable_weights</c> are returned.
/// </summary>
public virtual List<IVariableV1> NonTrainableWeights
{
    get
    {
        IEnumerable<IVariableV1> combined;
        if (Trainable)
        {
            combined = _gather_children_variables(include_non_trainable: true)
                .Concat(_non_trainable_weights);
        }
        else
        {
            // A non-trainable layer counts its trainable weights as non-trainable too.
            combined = _gather_children_variables(true, true)
                .Concat(_trainable_weights)
                .Concat(_non_trainable_weights);
        }
        return combined.Distinct().ToList();
    }
}

/// <summary>
/// All variables of this layer: <c>TrainableWeights</c> followed by
/// <c>NonTrainableWeights</c>. The getter builds a new list on every access.
/// </summary>
/// <remarks>
/// The setter assigns each supplied value to the variable at the same position
/// in the current weight list.
/// </remarks>
/// <exception cref="ValueError">
/// Thrown by the setter when the supplied list's length differs from the number
/// of weights the layer currently has.
/// </exception>
public virtual List<IVariableV1> Weights
{
    get
    {
        return TrainableWeights.Concat(NonTrainableWeights).ToList();
    }
    set
    {
        // Evaluate the getter once: each access re-gathers the child-layer
        // variables and allocates fresh lists.
        var current = Weights;
        if (current.Count != value.Count)
        {
            throw new ValueError(
                $"You called `set_weights` on layer \"{this.name}\" " +
                $"with a weight list of length {value.Count}, but the layer was " +
                $"expecting {current.Count} weights.");
        }

        foreach (var (this_w, v_w) in zip(current, value))
        {
            this_w.assign(v_w, read_value: true);
        }
    }
}


protected int id; protected int id;
public int Id => id; public int Id => id;
@@ -290,46 +338,9 @@ namespace Tensorflow.Keras.Engine
public int count_params() public int count_params()
{ {
if (Trainable) if (Trainable)
return layer_utils.count_params(this, weights);
return layer_utils.count_params(this, Weights);
return 0; return 0;
} }
List<IVariableV1> ILayer.TrainableWeights
{
get
{
return _trainable_weights;
}
}

List<IVariableV1> ILayer.NonTrainableWeights
{
get
{
return _non_trainable_weights;
}
}

public List<IVariableV1> weights
{
get
{
var weights = new List<IVariableV1>();
weights.AddRange(_trainable_weights);
weights.AddRange(_non_trainable_weights);
return weights;
}
set
{
if (weights.Count() != value.Count()) throw new ValueError(
$"You called `set_weights` on layer \"{this.name}\"" +
$"with a weight list of length {len(value)}, but the layer was " +
$"expecting {len(weights)} weights.");
foreach (var (this_w, v_w) in zip(weights, value))
this_w.assign(v_w, read_value: true);
}
}

public List<IVariableV1> Variables => weights;


public virtual IKerasConfig get_config() public virtual IKerasConfig get_config()
=> args; => args;


+ 30
- 7
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -89,10 +89,11 @@ namespace Tensorflow.Keras.Engine
public override List<ILayer> Layers public override List<ILayer> Layers
=> _flatten_layers(recursive: false, include_self: false).ToList(); => _flatten_layers(recursive: false, include_self: false).ToList();


public override List<IVariableV1> TrainableVariables
public override List<IVariableV1> TrainableWeights
{ {
get get
{ {
// skip the assertion of weights created.
var variables = new List<IVariableV1>(); var variables = new List<IVariableV1>();


if (!Trainable) if (!Trainable)
@@ -103,18 +104,40 @@ namespace Tensorflow.Keras.Engine
foreach (var trackable_obj in _self_tracked_trackables) foreach (var trackable_obj in _self_tracked_trackables)
{ {
if (trackable_obj.Trainable) if (trackable_obj.Trainable)
variables.AddRange(trackable_obj.TrainableVariables);
variables.AddRange(trackable_obj.TrainableWeights);
} }


foreach (var layer in _self_tracked_trackables)
variables.AddRange(_trainable_weights);

return variables.Distinct().ToList();
}
}

public override List<IVariableV1> NonTrainableWeights
{
get
{
// skip the assertion of weights created.
var variables = new List<IVariableV1>();

foreach (var trackable_obj in _self_tracked_trackables)
{ {
if (layer.Trainable)
variables.AddRange(layer.TrainableVariables);
variables.AddRange(trackable_obj.NonTrainableWeights);
} }


// variables.AddRange(_trainable_weights);
if (!Trainable)
{
var trainable_variables = new List<IVariableV1>();
foreach (var trackable_obj in _self_tracked_trackables)
{
variables.AddRange(trackable_obj.TrainableWeights);
}
variables.AddRange(trainable_variables);
variables.AddRange(_trainable_weights);
variables.AddRange(_non_trainable_weights);
}


return variables;
return variables.Distinct().ToList();
} }
} }




+ 1
- 1
src/TensorFlowNET.Keras/Metrics/Metric.cs View File

@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics


public virtual void reset_states() public virtual void reset_states()
{ {
foreach (var v in weights)
foreach (var v in Weights)
v.assign(0); v.assign(0);
} }




+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
})); }));
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x =>
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x =>
{ {
if (x is ResourceVariable or RefVariable) return (Trackable)x; if (x is ResourceVariable or RefVariable) return (Trackable)x;
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");


+ 1
- 1
src/TensorFlowNET.Keras/Utils/layer_utils.cs View File

@@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils
} }


var trainable_count = count_params(model, model.TrainableVariables); var trainable_count = count_params(model, model.TrainableVariables);
var non_trainable_count = count_params(model, model.non_trainable_variables);
var non_trainable_count = count_params(model, model.NonTrainableVariables);


print($"Total params: {trainable_count + non_trainable_count}"); print($"Total params: {trainable_count + non_trainable_count}");
print($"Trainable params: {trainable_count}"); print($"Trainable params: {trainable_count}");


+ 3
- 2
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -21,8 +21,8 @@ public class SequentialModelLoad
[TestMethod] [TestMethod]
public void SimpleModelFromSequential() public void SimpleModelFromSequential()
{ {
new SequentialModelSave().SimpleModelFromSequential();
var model = keras.models.load_model(@"./pb_simple_sequential");
//new SequentialModelSave().SimpleModelFromSequential();
var model = keras.models.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential");


model.summary(); model.summary();


@@ -40,5 +40,6 @@ public class SequentialModelLoad
}).Result; }).Result;


model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.summary();
} }
} }

Loading…
Cancel
Save