| @@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Common | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType)); | |||||
| return (TF_DataType)serializer.Deserialize(reader, typeof(int)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -19,6 +19,7 @@ namespace Tensorflow.Keras | |||||
| List<IVariableV1> TrainableVariables { get; } | List<IVariableV1> TrainableVariables { get; } | ||||
| List<IVariableV1> TrainableWeights { get; } | List<IVariableV1> TrainableWeights { get; } | ||||
| List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
| List<IVariableV1> Weights { get; } | |||||
| Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
| Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
| TensorShapeConfig BuildInputShape { get; } | TensorShapeConfig BuildInputShape { get; } | ||||
| @@ -71,6 +71,7 @@ namespace Tensorflow | |||||
| public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | ||||
| public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | ||||
| public List<IVariableV1> Weights => throw new NotImplementedException(); | |||||
| public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | ||||
| public Shape OutputShape => throw new NotImplementedException(); | public Shape OutputShape => throw new NotImplementedException(); | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine | |||||
| public virtual Shape ComputeOutputShape(Shape input_shape) | public virtual Shape ComputeOutputShape(Shape input_shape) | ||||
| => throw new NotImplementedException(""); | => throw new NotImplementedException(""); | ||||
| protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) | |||||
| { | |||||
| List<IVariableV1> res = new(); | |||||
| var nested_layers = _flatten_layers(false, false); | |||||
| foreach (var layer in nested_layers) | |||||
| { | |||||
| if (layer is Layer l) | |||||
| { | |||||
| if (include_trainable == true && include_non_trainable == true) | |||||
| { | |||||
| res.AddRange(l.Variables); | |||||
| } | |||||
| else if (include_trainable == true && include_non_trainable == false) | |||||
| { | |||||
| res.AddRange(l.TrainableVariables); | |||||
| } | |||||
| else if(include_trainable == false && include_non_trainable == true) | |||||
| { | |||||
| res.AddRange(l.NonTrainableVariables); | |||||
| } | |||||
| } | |||||
| } | |||||
| return res; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -67,10 +67,58 @@ namespace Tensorflow.Keras.Engine | |||||
| public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
| protected List<IVariableV1> _trainable_weights; | protected List<IVariableV1> _trainable_weights; | ||||
| public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||||
| public virtual List<IVariableV1> TrainableVariables => TrainableWeights; | |||||
| protected List<IVariableV1> _non_trainable_weights; | protected List<IVariableV1> _non_trainable_weights; | ||||
| public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | |||||
| public List<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||||
| public List<IVariableV1> Variables => Weights; | |||||
| public virtual List<IVariableV1> TrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!this.Trainable) | |||||
| { | |||||
| return new List<IVariableV1>(); | |||||
| } | |||||
| var children_weights = _gather_children_variables(true); | |||||
| return children_weights.Concat(_trainable_weights).Distinct().ToList(); | |||||
| } | |||||
| } | |||||
| public virtual List<IVariableV1> NonTrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| if (!this.Trainable) | |||||
| { | |||||
| var children_weights = _gather_children_variables(true, true); | |||||
| return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); | |||||
| } | |||||
| else | |||||
| { | |||||
| var children_weights = _gather_children_variables(include_non_trainable: true); | |||||
| return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); | |||||
| } | |||||
| } | |||||
| } | |||||
| public virtual List<IVariableV1> Weights | |||||
| { | |||||
| get | |||||
| { | |||||
| return TrainableWeights.Concat(NonTrainableWeights).ToList(); | |||||
| } | |||||
| set | |||||
| { | |||||
| if (Weights.Count() != value.Count()) throw new ValueError( | |||||
| $"You called `set_weights` on layer \"{this.name}\"" + | |||||
| $"with a weight list of length {len(value)}, but the layer was " + | |||||
| $"expecting {len(Weights)} weights."); | |||||
| foreach (var (this_w, v_w) in zip(Weights, value)) | |||||
| this_w.assign(v_w, read_value: true); | |||||
| } | |||||
| } | |||||
| protected int id; | protected int id; | ||||
| public int Id => id; | public int Id => id; | ||||
| @@ -290,46 +338,9 @@ namespace Tensorflow.Keras.Engine | |||||
| public int count_params() | public int count_params() | ||||
| { | { | ||||
| if (Trainable) | if (Trainable) | ||||
| return layer_utils.count_params(this, weights); | |||||
| return layer_utils.count_params(this, Weights); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| List<IVariableV1> ILayer.TrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| return _trainable_weights; | |||||
| } | |||||
| } | |||||
| List<IVariableV1> ILayer.NonTrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| return _non_trainable_weights; | |||||
| } | |||||
| } | |||||
| public List<IVariableV1> weights | |||||
| { | |||||
| get | |||||
| { | |||||
| var weights = new List<IVariableV1>(); | |||||
| weights.AddRange(_trainable_weights); | |||||
| weights.AddRange(_non_trainable_weights); | |||||
| return weights; | |||||
| } | |||||
| set | |||||
| { | |||||
| if (weights.Count() != value.Count()) throw new ValueError( | |||||
| $"You called `set_weights` on layer \"{this.name}\"" + | |||||
| $"with a weight list of length {len(value)}, but the layer was " + | |||||
| $"expecting {len(weights)} weights."); | |||||
| foreach (var (this_w, v_w) in zip(weights, value)) | |||||
| this_w.assign(v_w, read_value: true); | |||||
| } | |||||
| } | |||||
| public List<IVariableV1> Variables => weights; | |||||
| public virtual IKerasConfig get_config() | public virtual IKerasConfig get_config() | ||||
| => args; | => args; | ||||
| @@ -89,10 +89,11 @@ namespace Tensorflow.Keras.Engine | |||||
| public override List<ILayer> Layers | public override List<ILayer> Layers | ||||
| => _flatten_layers(recursive: false, include_self: false).ToList(); | => _flatten_layers(recursive: false, include_self: false).ToList(); | ||||
| public override List<IVariableV1> TrainableVariables | |||||
| public override List<IVariableV1> TrainableWeights | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| // skip the assertion of weights created. | |||||
| var variables = new List<IVariableV1>(); | var variables = new List<IVariableV1>(); | ||||
| if (!Trainable) | if (!Trainable) | ||||
| @@ -103,18 +104,40 @@ namespace Tensorflow.Keras.Engine | |||||
| foreach (var trackable_obj in _self_tracked_trackables) | foreach (var trackable_obj in _self_tracked_trackables) | ||||
| { | { | ||||
| if (trackable_obj.Trainable) | if (trackable_obj.Trainable) | ||||
| variables.AddRange(trackable_obj.TrainableVariables); | |||||
| variables.AddRange(trackable_obj.TrainableWeights); | |||||
| } | } | ||||
| foreach (var layer in _self_tracked_trackables) | |||||
| variables.AddRange(_trainable_weights); | |||||
| return variables.Distinct().ToList(); | |||||
| } | |||||
| } | |||||
| public override List<IVariableV1> NonTrainableWeights | |||||
| { | |||||
| get | |||||
| { | |||||
| // skip the assertion of weights created. | |||||
| var variables = new List<IVariableV1>(); | |||||
| foreach (var trackable_obj in _self_tracked_trackables) | |||||
| { | { | ||||
| if (layer.Trainable) | |||||
| variables.AddRange(layer.TrainableVariables); | |||||
| variables.AddRange(trackable_obj.NonTrainableWeights); | |||||
| } | } | ||||
| // variables.AddRange(_trainable_weights); | |||||
| if (!Trainable) | |||||
| { | |||||
| var trainable_variables = new List<IVariableV1>(); | |||||
| foreach (var trackable_obj in _self_tracked_trackables) | |||||
| { | |||||
| variables.AddRange(trackable_obj.TrainableWeights); | |||||
| } | |||||
| variables.AddRange(trainable_variables); | |||||
| variables.AddRange(_trainable_weights); | |||||
| variables.AddRange(_non_trainable_weights); | |||||
| } | |||||
| return variables; | |||||
| return variables.Distinct().ToList(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics | |||||
| public virtual void reset_states() | public virtual void reset_states() | ||||
| { | { | ||||
| foreach (var v in weights) | |||||
| foreach (var v in Weights) | |||||
| v.assign(0); | v.assign(0); | ||||
| } | } | ||||
| @@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils | |||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
| })); | })); | ||||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => | |||||
| { | { | ||||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | if (x is ResourceVariable or RefVariable) return (Trackable)x; | ||||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | ||||
| @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils | |||||
| } | } | ||||
| var trainable_count = count_params(model, model.TrainableVariables); | var trainable_count = count_params(model, model.TrainableVariables); | ||||
| var non_trainable_count = count_params(model, model.non_trainable_variables); | |||||
| var non_trainable_count = count_params(model, model.NonTrainableVariables); | |||||
| print($"Total params: {trainable_count + non_trainable_count}"); | print($"Total params: {trainable_count + non_trainable_count}"); | ||||
| print($"Trainable params: {trainable_count}"); | print($"Trainable params: {trainable_count}"); | ||||
| @@ -21,8 +21,8 @@ public class SequentialModelLoad | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleModelFromSequential() | public void SimpleModelFromSequential() | ||||
| { | { | ||||
| new SequentialModelSave().SimpleModelFromSequential(); | |||||
| var model = keras.models.load_model(@"./pb_simple_sequential"); | |||||
| //new SequentialModelSave().SimpleModelFromSequential(); | |||||
| var model = keras.models.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); | |||||
| model.summary(); | model.summary(); | ||||
| @@ -40,5 +40,6 @@ public class SequentialModelLoad | |||||
| }).Result; | }).Result; | ||||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | ||||
| model.summary(); | |||||
| } | } | ||||
| } | } | ||||