| @@ -13,10 +13,10 @@ namespace Tensorflow.Keras | |||||
| List<INode> InboundNodes { get; } | List<INode> InboundNodes { get; } | ||||
| List<INode> OutboundNodes { get; } | List<INode> OutboundNodes { get; } | ||||
| Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false); | Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false); | ||||
| List<IVariableV1> trainable_variables { get; } | |||||
| List<IVariableV1> trainable_weights { get; } | |||||
| List<IVariableV1> non_trainable_weights { get; } | |||||
| Shape output_shape { get; } | |||||
| List<IVariableV1> TrainableVariables { get; } | |||||
| List<IVariableV1> TrainableWeights { get; } | |||||
| List<IVariableV1> NonTrainableWeights { get; } | |||||
| Shape OutputShape { get; } | |||||
| Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
| TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
| int count_params(); | int count_params(); | ||||
| @@ -67,11 +67,11 @@ namespace Tensorflow | |||||
| public bool Trainable => throw new NotImplementedException(); | public bool Trainable => throw new NotImplementedException(); | ||||
| public List<IVariableV1> trainable_variables => throw new NotImplementedException(); | |||||
| public List<IVariableV1> trainable_weights => throw new NotImplementedException(); | |||||
| public List<IVariableV1> non_trainable_weights => throw new NotImplementedException(); | |||||
| public List<IVariableV1> TrainableVariables => throw new NotImplementedException(); | |||||
| public List<IVariableV1> TrainableWeights => throw new NotImplementedException(); | |||||
| public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException(); | |||||
| public Shape output_shape => throw new NotImplementedException(); | |||||
| public Shape OutputShape => throw new NotImplementedException(); | |||||
| public Shape BatchInputShape => throw new NotImplementedException(); | public Shape BatchInputShape => throw new NotImplementedException(); | ||||
| @@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine | |||||
| }; | }; | ||||
| var node_conversion_map = new Dictionary<string, int>(); | var node_conversion_map = new Dictionary<string, int>(); | ||||
| foreach (var layer in _layers) | |||||
| foreach (var layer in _self_tracked_trackables) | |||||
| { | { | ||||
| var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; | var kept_nodes = _should_skip_first_node(layer) ? 1 : 0; | ||||
| foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| var layer_configs = new List<LayerConfig>(); | var layer_configs = new List<LayerConfig>(); | ||||
| foreach (var layer in _layers) | |||||
| foreach (var layer in _self_tracked_trackables) | |||||
| { | { | ||||
| var filtered_inbound_nodes = new List<NodeConfig>(); | var filtered_inbound_nodes = new List<NodeConfig>(); | ||||
| foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | ||||
| @@ -65,13 +65,8 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| // Keep track of the network's nodes and layers. | // Keep track of the network's nodes and layers. | ||||
| var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); | |||||
| (NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs); | |||||
| NetworkNodes = nodes; | |||||
| NodesByDepth = nodes_by_depth; | |||||
| if (_layers.Count == 0) | |||||
| _layers = layers; | |||||
| _self_tracked_trackables = layers; | |||||
| // Build self.input_names and self.output_names. | // Build self.input_names and self.output_names. | ||||
| _set_output_names(); | _set_output_names(); | ||||
| @@ -5,8 +5,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public partial class Layer | public partial class Layer | ||||
| { | { | ||||
| protected List<ILayer> _layers = new List<ILayer>(); | |||||
| public virtual List<ILayer> Layers => _layers; | |||||
| public virtual List<ILayer> Layers => _self_tracked_trackables; | |||||
| protected void StackLayers(params ILayer[] layers) | protected void StackLayers(params ILayer[] layers) | ||||
| { | { | ||||
| @@ -63,7 +63,7 @@ namespace Tensorflow.Keras.Engine | |||||
| public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
| protected List<IVariableV1> _trainable_weights; | protected List<IVariableV1> _trainable_weights; | ||||
| public virtual List<IVariableV1> trainable_variables => _trainable_weights; | |||||
| public virtual List<IVariableV1> TrainableVariables => _trainable_weights; | |||||
| protected List<IVariableV1> _non_trainable_weights; | protected List<IVariableV1> _non_trainable_weights; | ||||
| public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | public List<IVariableV1> non_trainable_variables => _non_trainable_weights; | ||||
| @@ -88,7 +88,7 @@ namespace Tensorflow.Keras.Engine | |||||
| public CallContext CallContext => callContext.Value; | public CallContext CallContext => callContext.Value; | ||||
| public Tensor[] input => inboundNodes[0].input_tensors; | public Tensor[] input => inboundNodes[0].input_tensors; | ||||
| public Dictionary<int, List<INode>> NodesByDepth { get; set; } | public Dictionary<int, List<INode>> NodesByDepth { get; set; } | ||||
| public Shape output_shape => inboundNodes[0].Outputs.shape; | |||||
| public Shape OutputShape => inboundNodes[0].Outputs.shape; | |||||
| protected List<ILayer> _self_tracked_trackables; | protected List<ILayer> _self_tracked_trackables; | ||||
| public Layer(LayerArgs args) | public Layer(LayerArgs args) | ||||
| @@ -250,7 +250,7 @@ namespace Tensorflow.Keras.Engine | |||||
| return layer_utils.count_params(this, weights); | return layer_utils.count_params(this, weights); | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| List<IVariableV1> ILayer.trainable_weights | |||||
| List<IVariableV1> ILayer.TrainableWeights | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -258,7 +258,7 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| } | } | ||||
| List<IVariableV1> ILayer.non_trainable_weights | |||||
| List<IVariableV1> ILayer.NonTrainableWeights | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Engine | |||||
| // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | ||||
| // The _minimize call does a few extra steps unnecessary in most cases, | // The _minimize call does a few extra steps unnecessary in most cases, | ||||
| // such as loss scaling and gradient clipping. | // such as loss scaling and gradient clipping. | ||||
| _minimize(tape, optimizer, loss, trainable_variables); | |||||
| _minimize(tape, optimizer, loss, TrainableVariables); | |||||
| compiled_metrics.update_state(y, y_pred); | compiled_metrics.update_state(y, y_pred); | ||||
| return metrics.Select(x => (x.Name, x.result())).ToList(); | return metrics.Select(x => (x.Name, x.result())).ToList(); | ||||
| @@ -74,7 +74,7 @@ namespace Tensorflow.Keras.Engine | |||||
| public override List<ILayer> Layers | public override List<ILayer> Layers | ||||
| => _flatten_layers(recursive: false, include_self: false).ToList(); | => _flatten_layers(recursive: false, include_self: false).ToList(); | ||||
| public override List<IVariableV1> trainable_variables | |||||
| public override List<IVariableV1> TrainableVariables | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| @@ -88,13 +88,13 @@ namespace Tensorflow.Keras.Engine | |||||
| foreach (var trackable_obj in _self_tracked_trackables) | foreach (var trackable_obj in _self_tracked_trackables) | ||||
| { | { | ||||
| if (trackable_obj.Trainable) | if (trackable_obj.Trainable) | ||||
| variables.AddRange(trackable_obj.trainable_variables); | |||||
| variables.AddRange(trackable_obj.TrainableVariables); | |||||
| } | } | ||||
| foreach (var layer in _layers) | |||||
| foreach (var layer in _self_tracked_trackables) | |||||
| { | { | ||||
| if (layer.Trainable) | if (layer.Trainable) | ||||
| variables.AddRange(layer.trainable_variables); | |||||
| variables.AddRange(layer.TrainableVariables); | |||||
| } | } | ||||
| // variables.AddRange(_trainable_weights); | // variables.AddRange(_trainable_weights); | ||||
| @@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| built = false; | built = false; | ||||
| var set_inputs = false; | var set_inputs = false; | ||||
| if (_layers.Count == 0) | |||||
| if (_self_tracked_trackables.Count == 0) | |||||
| { | { | ||||
| if (layer is InputLayer) | if (layer is InputLayer) | ||||
| { | { | ||||
| @@ -128,7 +128,7 @@ namespace Tensorflow.Keras.Engine | |||||
| void _handle_deferred_layer_dependencies(params ILayer[] layers) | void _handle_deferred_layer_dependencies(params ILayer[] layers) | ||||
| { | { | ||||
| _layers.AddRange(layers); | |||||
| _self_tracked_trackables.AddRange(layers); | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -156,12 +156,12 @@ namespace Tensorflow.Keras.Engine | |||||
| ops.init_scope(); | ops.init_scope(); | ||||
| var inputs = keras.Input(batch_input_shape: input_shape, | var inputs = keras.Input(batch_input_shape: input_shape, | ||||
| dtype: input_dtype, | dtype: input_dtype, | ||||
| name: $"{_layers[0].Name}_input"); | |||||
| name: $"{_self_tracked_trackables[0].Name}_input"); | |||||
| Tensors layer_input = inputs; | Tensors layer_input = inputs; | ||||
| Tensors layer_output = null; | Tensors layer_output = null; | ||||
| Tensors outputs = null; | Tensors outputs = null; | ||||
| List<INode> created_nodes = new List<INode>(); | List<INode> created_nodes = new List<INode>(); | ||||
| foreach (var layer in _layers) | |||||
| foreach (var layer in _self_tracked_trackables) | |||||
| { | { | ||||
| clear_previously_created_nodes(layer, _created_nodes); | clear_previously_created_nodes(layer, _created_nodes); | ||||
| layer_output = layer.Apply(layer_input); | layer_output = layer.Apply(layer_input); | ||||
| @@ -338,8 +338,8 @@ namespace Tensorflow.Keras.Saving | |||||
| public static List<IVariableV1> _legacy_weights(ILayer layer) | public static List<IVariableV1> _legacy_weights(ILayer layer) | ||||
| { | { | ||||
| var weights = layer.trainable_weights.Select(x => x).ToList(); | |||||
| weights.AddRange(layer.non_trainable_weights); | |||||
| var weights = layer.TrainableWeights.Select(x => x).ToList(); | |||||
| weights.AddRange(layer.NonTrainableWeights); | |||||
| return weights; | return weights; | ||||
| } | } | ||||
| } | } | ||||
| @@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Utils | |||||
| print(string.Join("", range(line_length).Select(x => "_"))); | print(string.Join("", range(line_length).Select(x => "_"))); | ||||
| } | } | ||||
| var trainable_count = count_params(model, model.trainable_variables); | |||||
| var trainable_count = count_params(model, model.TrainableVariables); | |||||
| var non_trainable_count = count_params(model, model.non_trainable_variables); | var non_trainable_count = count_params(model, model.non_trainable_variables); | ||||
| print($"Total params: {trainable_count + non_trainable_count}"); | print($"Total params: {trainable_count + non_trainable_count}"); | ||||
| @@ -137,7 +137,7 @@ namespace Tensorflow.Keras.Utils | |||||
| var fields = new string[] | var fields = new string[] | ||||
| { | { | ||||
| $"{name} ({layer.GetType().Name})", | $"{name} ({layer.GetType().Name})", | ||||
| $"{layer.output_shape}", | |||||
| $"{layer.OutputShape}", | |||||
| $"{layer.count_params()}" | $"{layer.count_params()}" | ||||
| }; | }; | ||||
| @@ -164,7 +164,7 @@ namespace Tensorflow.Keras.Utils | |||||
| var fields = new string[] | var fields = new string[] | ||||
| { | { | ||||
| $"{name}({layer.GetType().Name})", | $"{name}({layer.GetType().Name})", | ||||
| $"{layer.output_shape}", | |||||
| $"{layer.OutputShape}", | |||||
| $"{layer.count_params()}", | $"{layer.count_params()}", | ||||
| first_connection | first_connection | ||||
| }; | }; | ||||