From e742105b3e6a1239bf78adfe5b6c868521a1599d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 31 Oct 2020 08:49:32 -0500 Subject: [PATCH] consolidate layer api. --- .../Keras/Engine/Functional.cs | 14 -- .../Keras/Engine/Layer.Layers.cs | 130 +----------------- .../Keras/Engine/Model.Compile.cs | 53 +++++++ .../Keras/Engine/Model.Fit.cs | 56 ++++++++ .../Keras/Engine/Model.Predict.cs | 51 +++++++ src/TensorFlowNET.Core/Keras/Engine/Model.cs | 129 ++--------------- .../Keras/Layers/LayersApi.cs | 71 +++++++++- 7 files changed, 239 insertions(+), 265 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Model.Predict.cs diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs index 0a1c9464..b1223665 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -27,20 +27,6 @@ namespace Tensorflow.Keras.Engine Dictionary tensor_usage_count; public Dictionary TensorUsageCount => tensor_usage_count; - public override List trainable_variables - { - get - { - var variables = new List(); - foreach(var layer in _layers) - { - if (layer.Trainable) - variables.AddRange(layer.trainable_variables); - } - return variables; - } - } - public Functional(Tensors inputs, Tensors outputs, string name = null) : base(new ModelArgs { diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs index 6c13e56a..dd809f83 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs @@ -12,136 +12,10 @@ namespace Tensorflow.Keras.Engine { protected List _layers = new List(); public List Layers => _layers; - - protected Layer Dense(int units, - Activation activation = null, - TensorShape input_shape = null) - { - var layer = new Dense(new DenseArgs - { - Units = units, - Activation = activation ?? tf.keras.activations.Linear, - InputShape = input_shape - }); - - _layers.Add(layer); - return layer; - } - - protected Layer Conv2D(int filters, - int kernel_size, - TensorShape strides = null, - string padding = "valid", - string data_format = null, - TensorShape dilation_rate = null, - int groups = 1, - Activation activation = null, - bool use_bias = true, - IInitializer kernel_initializer = null, - IInitializer bias_initializer = null, - bool trainable = true, - string name = null) - { - var layer = new Conv2D(new Conv2DArgs - { - Filters = filters, - KernelSize = kernel_size, - Strides = strides ?? (1, 1), - Padding = padding, - DataFormat = data_format, - DilationRate = dilation_rate ?? (1, 1), - Groups = groups, - Activation = activation, - UseBias = use_bias, - KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, - BiasInitializer = bias_initializer ?? tf.zeros_initializer, - Trainable = trainable, - Name = name - }); - - _layers.Add(layer); - return layer; - } - - protected Layer MaxPooling2D(TensorShape pool_size, - TensorShape strides, - string padding = "valid", - string data_format = null, - string name = null) - { - var layer = new MaxPooling2D(new MaxPooling2DArgs - { - PoolSize = pool_size, - Strides = strides, - Padding = padding, - DataFormat = data_format, - Name = name - }); - - _layers.Add(layer); - return layer; - } - - protected Layer Dropout(float rate, TensorShape noise_shape = null, int? seed = null) - { - var layer = new Dropout(new DropoutArgs - { - Rate = rate, - NoiseShape = noise_shape, - Seed = seed - }); - - _layers.Add(layer); - return layer; - } - protected Layer Flatten() + protected void StackLayers(params Layer[] layers) { - var layer = new Flatten(new FlattenArgs()); - - _layers.Add(layer); - return layer; - } - - protected Layer LSTM(int units, - Activation activation = null, - Activation recurrent_activation = null, - bool use_bias = true, - IInitializer kernel_initializer = null, - IInitializer recurrent_initializer = null, - IInitializer bias_initializer = null, - bool unit_forget_bias = true, - float dropout = 0f, - float recurrent_dropout = 0f, - int implementation = 2, - bool return_sequences = false, - bool return_state = false, - bool go_backwards = false, - bool stateful = false, - bool time_major = false, - bool unroll = false) - { - var layer = new LSTM(new LSTMArgs - { - Units = units, - Activation = activation ?? tf.keras.activations.Tanh, - RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, - KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, - RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, - BiasInitializer = bias_initializer ?? tf.zeros_initializer, - Dropout = dropout, - RecurrentDropout = recurrent_dropout, - Implementation = implementation, - ReturnSequences = return_sequences, - ReturnState = return_state, - GoBackwards = go_backwards, - Stateful = stateful, - TimeMajor = time_major, - Unroll = unroll - }); - - _layers.Add(layer); - return layer; + _layers.AddRange(layers); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs new file mode 100644 index 00000000..d9e4a0e2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Optimizers; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + public void compile(string optimizerName, ILossFunc lossName) + { + throw new NotImplementedException(""); + } + + public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) + { + this.optimizer = optimizer; + var compiled_loss = new LossesContainer(loss, output_names: output_names); + var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + // Initialize cache attrs. + _reset_compile_cache(); + _is_compiled = true; + this.loss = loss; + } + + public void compile(string optimizerName, string lossName) + { + switch (optimizerName) + { + case "rmsprop": + optimizer = new RMSprop(new RMSpropArgs + { + + }); + break; + } + + int experimental_steps_per_execution = 1; + _configure_steps_per_execution(experimental_steps_per_execution); + + _reset_compile_cache(); + + _is_compiled = true; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs new file mode 100644 index 00000000..a768a52b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs @@ -0,0 +1,56 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine.DataAdapters; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + /// + /// Trains the model for a fixed number of epochs (iterations on a dataset). + /// + /// + /// + /// + /// + /// + /// + /// + public void fit(NDArray x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + float validation_split = 0f, + bool shuffle = true, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); + var train_x = x[new Slice(0, train_count)]; + var train_y = y[new Slice(0, train_count)]; + var val_x = x[new Slice(train_count)]; + var val_y = y[new Slice(train_count)]; + + var data_handler = new DataHandler(new DataHandlerArgs + { + X = train_x, + Y = train_y, + BatchSize = batch_size, + InitialEpoch = initial_epoch, + Epochs = epochs, + Shuffle = shuffle, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Predict.cs new file mode 100644 index 00000000..61188697 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Predict.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine.DataAdapters; +using Tensorflow.Keras.Utils; + +namespace Tensorflow.Keras.Engine +{ + public partial class Model + { + /// + /// Generates output predictions for the input samples. + /// + /// Input samples + /// Number of samples per batch + /// Verbosity mode + /// + /// Total number of steps (batches of samples) + /// before declaring the prediction round finished. + /// + /// + /// + /// + /// + public Tensor predict(Tensor x, + int batch_size = 32, + int verbose = 0, + int steps = -1, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + var data_handler = new DataHandler(new DataHandlerArgs + { + X = x, + BatchSize = batch_size, + StepsPerEpoch = steps, + InitialEpoch = 0, + Epochs = 1, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index 7d356dce..7dd2a4e7 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; using NumSharp; +using System.Collections.Generic; namespace Tensorflow.Keras.Engine { @@ -39,84 +40,6 @@ namespace Tensorflow.Keras.Engine } - public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) - { - this.optimizer = optimizer; - var compiled_loss = new LossesContainer(loss, output_names: output_names); - var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); - - int experimental_steps_per_execution = 1; - _configure_steps_per_execution(experimental_steps_per_execution); - - // Initialize cache attrs. - _reset_compile_cache(); - _is_compiled = true; - this.loss = loss; - } - - public void compile(string optimizerName, string lossName) - { - switch (optimizerName) - { - case "rmsprop": - optimizer = new RMSprop(new RMSpropArgs - { - - }); - break; - } - - int experimental_steps_per_execution = 1; - _configure_steps_per_execution(experimental_steps_per_execution); - - _reset_compile_cache(); - - _is_compiled = true; - } - - /// - /// Trains the model for a fixed number of epochs (iterations on a dataset). - /// - /// - /// - /// - /// - /// - /// - /// - public void fit(NDArray x, NDArray y, - int batch_size = -1, - int epochs = 1, - int verbose = 1, - float validation_split = 0f, - bool shuffle = true, - int initial_epoch = 0, - int max_queue_size = 10, - int workers = 1, - bool use_multiprocessing = false) - { - int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); - var train_x = x[new Slice(0, train_count)]; - var train_y = y[new Slice(0, train_count)]; - var val_x = x[new Slice(train_count)]; - var val_y = y[new Slice(train_count)]; - - var data_handler = new DataHandler(new DataHandlerArgs - { - X = train_x, - Y = train_y, - BatchSize = batch_size, - InitialEpoch = initial_epoch, - Epochs = epochs, - Shuffle = shuffle, - MaxQueueSize = max_queue_size, - Workers = workers, - UseMultiprocessing = use_multiprocessing, - Model = this, - StepsPerExecution = _steps_per_execution - }); - } - void _configure_steps_per_execution(int steps_per_execution) { _steps_per_execution = tf.Variable(steps_per_execution, @@ -145,48 +68,18 @@ namespace Tensorflow.Keras.Engine aggregation: VariableAggregation.OnlyFirstReplica); } - public void compile(string optimizerName, ILossFunc lossName) - { - throw new NotImplementedException(""); - } - - /// - /// Generates output predictions for the input samples. - /// - /// Input samples - /// Number of samples per batch - /// Verbosity mode - /// - /// Total number of steps (batches of samples) - /// before declaring the prediction round finished. - /// - /// - /// - /// - /// - public Tensor predict(Tensor x, - int batch_size = 32, - int verbose = 0, - int steps = -1, - int max_queue_size = 10, - int workers = 1, - bool use_multiprocessing = false) + public override List trainable_variables { - var data_handler = new DataHandler(new DataHandlerArgs + get { - X = x, - BatchSize = batch_size, - StepsPerEpoch = steps, - InitialEpoch = 0, - Epochs = 1, - MaxQueueSize = max_queue_size, - Workers = workers, - UseMultiprocessing = use_multiprocessing, - Model = this, - StepsPerExecution = _steps_per_execution - }); - - throw new NotImplementedException(""); + var variables = new List(); + foreach (var layer in _layers) + { + if (layer.Trainable) + variables.AddRange(layer.trainable_variables); + } + return variables; + } } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs index 98c45e15..9bd75578 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs @@ -10,6 +10,24 @@ namespace Tensorflow.Keras.Layers { public class LayersApi { + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// tf.keras.activations + /// + /// + /// + /// + /// + /// + /// public Conv2D Conv2D(int filters, TensorShape kernel_size = null, TensorShape strides = null, @@ -17,7 +35,7 @@ namespace Tensorflow.Keras.Layers string data_format = null, TensorShape dilation_rate = null, int groups = 1, - string activation = null, + Activation activation = null, bool use_bias = true, IInitializer kernel_initializer = null, IInitializer bias_initializer = null, @@ -40,20 +58,27 @@ namespace Tensorflow.Keras.Layers BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, BiasRegularizer = bias_regularizer, ActivityRegularizer = activity_regularizer, - Activation = GetActivationByName(activation) + Activation = activation ?? tf.keras.activations.Linear }); - public Dense Dense(int units, - string activation = "linear", + Activation activation = null, TensorShape input_shape = null) => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName(activation), + Activation = activation ?? tf.keras.activations.Linear, InputShape = input_shape }); + public Dropout Dropout(float rate, TensorShape noise_shape = null, int? seed = null) + => new Dropout(new DropoutArgs + { + Rate = rate, + NoiseShape = noise_shape, + Seed = seed + }); + /// /// Turns positive integers (indexes) into dense vectors of fixed size. /// This layer can only be used as the first layer in a model. @@ -121,6 +146,42 @@ namespace Tensorflow.Keras.Layers Padding = padding }); + public Layer LSTM(int units, + Activation activation = null, + Activation recurrent_activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer recurrent_initializer = null, + IInitializer bias_initializer = null, + bool unit_forget_bias = true, + float dropout = 0f, + float recurrent_dropout = 0f, + int implementation = 2, + bool return_sequences = false, + bool return_state = false, + bool go_backwards = false, + bool stateful = false, + bool time_major = false, + bool unroll = false) + => new LSTM(new LSTMArgs + { + Units = units, + Activation = activation ?? tf.keras.activations.Tanh, + RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, + KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, + RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, + BiasInitializer = bias_initializer ?? tf.zeros_initializer, + Dropout = dropout, + RecurrentDropout = recurrent_dropout, + Implementation = implementation, + ReturnSequences = return_sequences, + ReturnState = return_state, + GoBackwards = go_backwards, + Stateful = stateful, + TimeMajor = time_major, + Unroll = unroll + }); + public Rescaling Rescaling(float scale, float offset = 0, TensorShape input_shape = null)