| @@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Common | |||
| } | |||
| else | |||
| { | |||
| return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType)); | |||
| return (TF_DataType)serializer.Deserialize(reader, typeof(int)); | |||
| } | |||
| } | |||
| } | |||
| @@ -19,6 +19,7 @@ namespace Tensorflow.Keras | |||
| List<IVariableV1> TrainableVariables { get; } | |||
| List<IVariableV1> TrainableWeights { get; } | |||
| List<IVariableV1> NonTrainableWeights { get; } | |||
| List<IVariableV1> Weights { get; } | |||
| Shape OutputShape { get; } | |||
| Shape BatchInputShape { get; } | |||
| TensorShapeConfig BuildInputShape { get; } | |||
| @@ -71,6 +71,7 @@ namespace Tensorflow | |||
| public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | |||
| public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | |||
| public List<IVariableV1> Weights => throw new NotImplementedException(); | |||
| public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | |||
| public Shape OutputShape => throw new NotImplementedException(); | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -14,5 +15,30 @@ namespace Tensorflow.Keras.Engine | |||
| public virtual Shape ComputeOutputShape(Shape input_shape) | |||
| => throw new NotImplementedException(""); | |||
| protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false) | |||
| { | |||
| List<IVariableV1> res = new(); | |||
| var nested_layers = _flatten_layers(false, false); | |||
| foreach (var layer in nested_layers) | |||
| { | |||
| if (layer is Layer l) | |||
| { | |||
| if (include_trainable == true && include_non_trainable == true) | |||
| { | |||
| res.AddRange(l.Variables); | |||
| } | |||
| else if (include_trainable == true && include_non_trainable == false) | |||
| { | |||
| res.AddRange(l.TrainableVariables); | |||
| } | |||
| else if(include_trainable == false && include_non_trainable == true) | |||
| { | |||
| res.AddRange(l.NonTrainableVariables); | |||
| } | |||
| } | |||
| } | |||
| return res; | |||
| } | |||
| } | |||
| } | |||
| @@ -67,10 +67,58 @@ namespace Tensorflow.Keras.Engine | |||
| public bool SupportsMasking { get; set; } | |||
| protected List<IVariableV1> _trainable_weights; | |||
| public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||
| public virtual List<IVariableV1> TrainableVariables => TrainableWeights; | |||
| protected List<IVariableV1> _non_trainable_weights; | |||
| public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | |||
| public List<IVariableV1> NonTrainableVariables => NonTrainableWeights; | |||
| public List<IVariableV1> Variables => Weights; | |||
| public virtual List<IVariableV1> TrainableWeights | |||
| { | |||
| get | |||
| { | |||
| if (!this.Trainable) | |||
| { | |||
| return new List<IVariableV1>(); | |||
| } | |||
| var children_weights = _gather_children_variables(true); | |||
| return children_weights.Concat(_trainable_weights).Distinct().ToList(); | |||
| } | |||
| } | |||
| public virtual List<IVariableV1> NonTrainableWeights | |||
| { | |||
| get | |||
| { | |||
| if (!this.Trainable) | |||
| { | |||
| var children_weights = _gather_children_variables(true, true); | |||
| return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList(); | |||
| } | |||
| else | |||
| { | |||
| var children_weights = _gather_children_variables(include_non_trainable: true); | |||
| return children_weights.Concat(_non_trainable_weights).Distinct().ToList(); | |||
| } | |||
| } | |||
| } | |||
| public virtual List<IVariableV1> Weights | |||
| { | |||
| get | |||
| { | |||
| return TrainableWeights.Concat(NonTrainableWeights).ToList(); | |||
| } | |||
| set | |||
| { | |||
| if (Weights.Count() != value.Count()) throw new ValueError( | |||
| $"You called `set_weights` on layer \"{this.name}\"" + | |||
| $"with a weight list of length {len(value)}, but the layer was " + | |||
| $"expecting {len(Weights)} weights."); | |||
| foreach (var (this_w, v_w) in zip(Weights, value)) | |||
| this_w.assign(v_w, read_value: true); | |||
| } | |||
| } | |||
| protected int id; | |||
| public int Id => id; | |||
| @@ -290,46 +338,9 @@ namespace Tensorflow.Keras.Engine | |||
| public int count_params() | |||
| { | |||
| if (Trainable) | |||
| return layer_utils.count_params(this, weights); | |||
| return layer_utils.count_params(this, Weights); | |||
| return 0; | |||
| } | |||
| List<IVariableV1> ILayer.TrainableWeights | |||
| { | |||
| get | |||
| { | |||
| return _trainable_weights; | |||
| } | |||
| } | |||
| List<IVariableV1> ILayer.NonTrainableWeights | |||
| { | |||
| get | |||
| { | |||
| return _non_trainable_weights; | |||
| } | |||
| } | |||
| public List<IVariableV1> weights | |||
| { | |||
| get | |||
| { | |||
| var weights = new List<IVariableV1>(); | |||
| weights.AddRange(_trainable_weights); | |||
| weights.AddRange(_non_trainable_weights); | |||
| return weights; | |||
| } | |||
| set | |||
| { | |||
| if (weights.Count() != value.Count()) throw new ValueError( | |||
| $"You called `set_weights` on layer \"{this.name}\"" + | |||
| $"with a weight list of length {len(value)}, but the layer was " + | |||
| $"expecting {len(weights)} weights."); | |||
| foreach (var (this_w, v_w) in zip(weights, value)) | |||
| this_w.assign(v_w, read_value: true); | |||
| } | |||
| } | |||
| public List<IVariableV1> Variables => weights; | |||
| public virtual IKerasConfig get_config() | |||
| => args; | |||
| @@ -89,10 +89,11 @@ namespace Tensorflow.Keras.Engine | |||
| public override List<ILayer> Layers | |||
| => _flatten_layers(recursive: false, include_self: false).ToList(); | |||
| public override List<IVariableV1> TrainableVariables | |||
| public override List<IVariableV1> TrainableWeights | |||
| { | |||
| get | |||
| { | |||
| // skip the assertion of weights created. | |||
| var variables = new List<IVariableV1>(); | |||
| if (!Trainable) | |||
| @@ -103,18 +104,40 @@ namespace Tensorflow.Keras.Engine | |||
| foreach (var trackable_obj in _self_tracked_trackables) | |||
| { | |||
| if (trackable_obj.Trainable) | |||
| variables.AddRange(trackable_obj.TrainableVariables); | |||
| variables.AddRange(trackable_obj.TrainableWeights); | |||
| } | |||
| foreach (var layer in _self_tracked_trackables) | |||
| variables.AddRange(_trainable_weights); | |||
| return variables.Distinct().ToList(); | |||
| } | |||
| } | |||
| public override List<IVariableV1> NonTrainableWeights | |||
| { | |||
| get | |||
| { | |||
| // skip the assertion of weights created. | |||
| var variables = new List<IVariableV1>(); | |||
| foreach (var trackable_obj in _self_tracked_trackables) | |||
| { | |||
| if (layer.Trainable) | |||
| variables.AddRange(layer.TrainableVariables); | |||
| variables.AddRange(trackable_obj.NonTrainableWeights); | |||
| } | |||
| // variables.AddRange(_trainable_weights); | |||
| if (!Trainable) | |||
| { | |||
| var trainable_variables = new List<IVariableV1>(); | |||
| foreach (var trackable_obj in _self_tracked_trackables) | |||
| { | |||
| variables.AddRange(trackable_obj.TrainableWeights); | |||
| } | |||
| variables.AddRange(trainable_variables); | |||
| variables.AddRange(_trainable_weights); | |||
| variables.AddRange(_non_trainable_weights); | |||
| } | |||
| return variables; | |||
| return variables.Distinct().ToList(); | |||
| } | |||
| } | |||
| @@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Metrics | |||
| public virtual void reset_states() | |||
| { | |||
| foreach (var v in weights) | |||
| foreach (var v in Weights) | |||
| v.assign(0); | |||
| } | |||
| @@ -130,7 +130,7 @@ public partial class KerasSavedModelUtils | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| })); | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x => | |||
| var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x => | |||
| { | |||
| if (x is ResourceVariable or RefVariable) return (Trackable)x; | |||
| else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer."); | |||
| @@ -104,7 +104,7 @@ namespace Tensorflow.Keras.Utils | |||
| } | |||
| var trainable_count = count_params(model, model.TrainableVariables); | |||
| var non_trainable_count = count_params(model, model.non_trainable_variables); | |||
| var non_trainable_count = count_params(model, model.NonTrainableVariables); | |||
| print($"Total params: {trainable_count + non_trainable_count}"); | |||
| print($"Trainable params: {trainable_count}"); | |||
| @@ -21,8 +21,8 @@ public class SequentialModelLoad | |||
| [TestMethod] | |||
| public void SimpleModelFromSequential() | |||
| { | |||
| new SequentialModelSave().SimpleModelFromSequential(); | |||
| var model = keras.models.load_model(@"./pb_simple_sequential"); | |||
| //new SequentialModelSave().SimpleModelFromSequential(); | |||
| var model = keras.models.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential"); | |||
| model.summary(); | |||
| @@ -40,5 +40,6 @@ public class SequentialModelLoad | |||
| }).Result; | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| model.summary(); | |||
| } | |||
| } | |||