| @@ -27,20 +27,6 @@ namespace Tensorflow.Keras.Engine | |||||
| Dictionary<int, int> tensor_usage_count; | Dictionary<int, int> tensor_usage_count; | ||||
| public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | ||||
| public override List<IVariableV1> trainable_variables | |||||
| { | |||||
| get | |||||
| { | |||||
| var variables = new List<IVariableV1>(); | |||||
| foreach(var layer in _layers) | |||||
| { | |||||
| if (layer.Trainable) | |||||
| variables.AddRange(layer.trainable_variables); | |||||
| } | |||||
| return variables; | |||||
| } | |||||
| } | |||||
| public Functional(Tensors inputs, Tensors outputs, string name = null) | public Functional(Tensors inputs, Tensors outputs, string name = null) | ||||
| : base(new ModelArgs | : base(new ModelArgs | ||||
| { | { | ||||
| @@ -12,136 +12,10 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| protected List<Layer> _layers = new List<Layer>(); | protected List<Layer> _layers = new List<Layer>(); | ||||
| public List<Layer> Layers => _layers; | public List<Layer> Layers => _layers; | ||||
| protected Layer Dense(int units, | |||||
| Activation activation = null, | |||||
| TensorShape input_shape = null) | |||||
| { | |||||
| var layer = new Dense(new DenseArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = activation ?? tf.keras.activations.Linear, | |||||
| InputShape = input_shape | |||||
| }); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| } | |||||
| protected Layer Conv2D(int filters, | |||||
| int kernel_size, | |||||
| TensorShape strides = null, | |||||
| string padding = "valid", | |||||
| string data_format = null, | |||||
| TensorShape dilation_rate = null, | |||||
| int groups = 1, | |||||
| Activation activation = null, | |||||
| bool use_bias = true, | |||||
| IInitializer kernel_initializer = null, | |||||
| IInitializer bias_initializer = null, | |||||
| bool trainable = true, | |||||
| string name = null) | |||||
| { | |||||
| var layer = new Conv2D(new Conv2DArgs | |||||
| { | |||||
| Filters = filters, | |||||
| KernelSize = kernel_size, | |||||
| Strides = strides ?? (1, 1), | |||||
| Padding = padding, | |||||
| DataFormat = data_format, | |||||
| DilationRate = dilation_rate ?? (1, 1), | |||||
| Groups = groups, | |||||
| Activation = activation, | |||||
| UseBias = use_bias, | |||||
| KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||||
| BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||||
| Trainable = trainable, | |||||
| Name = name | |||||
| }); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| } | |||||
| protected Layer MaxPooling2D(TensorShape pool_size, | |||||
| TensorShape strides, | |||||
| string padding = "valid", | |||||
| string data_format = null, | |||||
| string name = null) | |||||
| { | |||||
| var layer = new MaxPooling2D(new MaxPooling2DArgs | |||||
| { | |||||
| PoolSize = pool_size, | |||||
| Strides = strides, | |||||
| Padding = padding, | |||||
| DataFormat = data_format, | |||||
| Name = name | |||||
| }); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| } | |||||
| protected Layer Dropout(float rate, TensorShape noise_shape = null, int? seed = null) | |||||
| { | |||||
| var layer = new Dropout(new DropoutArgs | |||||
| { | |||||
| Rate = rate, | |||||
| NoiseShape = noise_shape, | |||||
| Seed = seed | |||||
| }); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| } | |||||
| protected Layer Flatten() | |||||
| protected void StackLayers(params Layer[] layers) | |||||
| { | { | ||||
| var layer = new Flatten(new FlattenArgs()); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| } | |||||
| protected Layer LSTM(int units, | |||||
| Activation activation = null, | |||||
| Activation recurrent_activation = null, | |||||
| bool use_bias = true, | |||||
| IInitializer kernel_initializer = null, | |||||
| IInitializer recurrent_initializer = null, | |||||
| IInitializer bias_initializer = null, | |||||
| bool unit_forget_bias = true, | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f, | |||||
| int implementation = 2, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool time_major = false, | |||||
| bool unroll = false) | |||||
| { | |||||
| var layer = new LSTM(new LSTMArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = activation ?? tf.keras.activations.Tanh, | |||||
| RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, | |||||
| KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||||
| RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, | |||||
| BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||||
| Dropout = dropout, | |||||
| RecurrentDropout = recurrent_dropout, | |||||
| Implementation = implementation, | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| TimeMajor = time_major, | |||||
| Unroll = unroll | |||||
| }); | |||||
| _layers.Add(layer); | |||||
| return layer; | |||||
| _layers.AddRange(layers); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,53 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Losses; | |||||
| using Tensorflow.Keras.Optimizers; | |||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| public partial class Model | |||||
| { | |||||
| public void compile(string optimizerName, ILossFunc lossName) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||||
| { | |||||
| this.optimizer = optimizer; | |||||
| var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||||
| var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||||
| int experimental_steps_per_execution = 1; | |||||
| _configure_steps_per_execution(experimental_steps_per_execution); | |||||
| // Initialize cache attrs. | |||||
| _reset_compile_cache(); | |||||
| _is_compiled = true; | |||||
| this.loss = loss; | |||||
| } | |||||
| public void compile(string optimizerName, string lossName) | |||||
| { | |||||
| switch (optimizerName) | |||||
| { | |||||
| case "rmsprop": | |||||
| optimizer = new RMSprop(new RMSpropArgs | |||||
| { | |||||
| }); | |||||
| break; | |||||
| } | |||||
| int experimental_steps_per_execution = 1; | |||||
| _configure_steps_per_execution(experimental_steps_per_execution); | |||||
| _reset_compile_cache(); | |||||
| _is_compiled = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,56 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine.DataAdapters; | |||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| public partial class Model | |||||
| { | |||||
| /// <summary> | |||||
| /// Trains the model for a fixed number of epochs (iterations on a dataset). | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="y"></param> | |||||
| /// <param name="batch_size"></param> | |||||
| /// <param name="epochs"></param> | |||||
| /// <param name="verbose"></param> | |||||
| /// <param name="validation_split"></param> | |||||
| /// <param name="shuffle"></param> | |||||
| public void fit(NDArray x, NDArray y, | |||||
| int batch_size = -1, | |||||
| int epochs = 1, | |||||
| int verbose = 1, | |||||
| float validation_split = 0f, | |||||
| bool shuffle = true, | |||||
| int initial_epoch = 0, | |||||
| int max_queue_size = 10, | |||||
| int workers = 1, | |||||
| bool use_multiprocessing = false) | |||||
| { | |||||
| int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); | |||||
| var train_x = x[new Slice(0, train_count)]; | |||||
| var train_y = y[new Slice(0, train_count)]; | |||||
| var val_x = x[new Slice(train_count)]; | |||||
| var val_y = y[new Slice(train_count)]; | |||||
| var data_handler = new DataHandler(new DataHandlerArgs | |||||
| { | |||||
| X = train_x, | |||||
| Y = train_y, | |||||
| BatchSize = batch_size, | |||||
| InitialEpoch = initial_epoch, | |||||
| Epochs = epochs, | |||||
| Shuffle = shuffle, | |||||
| MaxQueueSize = max_queue_size, | |||||
| Workers = workers, | |||||
| UseMultiprocessing = use_multiprocessing, | |||||
| Model = this, | |||||
| StepsPerExecution = _steps_per_execution | |||||
| }); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,51 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine.DataAdapters; | |||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| public partial class Model | |||||
| { | |||||
| /// <summary> | |||||
| /// Generates output predictions for the input samples. | |||||
| /// </summary> | |||||
| /// <param name="x">Input samples</param> | |||||
| /// <param name="batch_size">Number of samples per batch</param> | |||||
| /// <param name="verbose">Verbosity mode</param> | |||||
| /// <param name="steps"> | |||||
| /// Total number of steps (batches of samples) | |||||
| /// before declaring the prediction round finished. | |||||
| /// </param> | |||||
| /// <param name="max_queue_size"></param> | |||||
| /// <param name="workers"></param> | |||||
| /// <param name="use_multiprocessing"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor predict(Tensor x, | |||||
| int batch_size = 32, | |||||
| int verbose = 0, | |||||
| int steps = -1, | |||||
| int max_queue_size = 10, | |||||
| int workers = 1, | |||||
| bool use_multiprocessing = false) | |||||
| { | |||||
| var data_handler = new DataHandler(new DataHandlerArgs | |||||
| { | |||||
| X = x, | |||||
| BatchSize = batch_size, | |||||
| StepsPerEpoch = steps, | |||||
| InitialEpoch = 0, | |||||
| Epochs = 1, | |||||
| MaxQueueSize = max_queue_size, | |||||
| Workers = workers, | |||||
| UseMultiprocessing = use_multiprocessing, | |||||
| Model = this, | |||||
| StepsPerExecution = _steps_per_execution | |||||
| }); | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using NumSharp; | using NumSharp; | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -39,84 +40,6 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||||
| { | |||||
| this.optimizer = optimizer; | |||||
| var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||||
| var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||||
| int experimental_steps_per_execution = 1; | |||||
| _configure_steps_per_execution(experimental_steps_per_execution); | |||||
| // Initialize cache attrs. | |||||
| _reset_compile_cache(); | |||||
| _is_compiled = true; | |||||
| this.loss = loss; | |||||
| } | |||||
| public void compile(string optimizerName, string lossName) | |||||
| { | |||||
| switch (optimizerName) | |||||
| { | |||||
| case "rmsprop": | |||||
| optimizer = new RMSprop(new RMSpropArgs | |||||
| { | |||||
| }); | |||||
| break; | |||||
| } | |||||
| int experimental_steps_per_execution = 1; | |||||
| _configure_steps_per_execution(experimental_steps_per_execution); | |||||
| _reset_compile_cache(); | |||||
| _is_compiled = true; | |||||
| } | |||||
| /// <summary> | |||||
| /// Trains the model for a fixed number of epochs (iterations on a dataset). | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="y"></param> | |||||
| /// <param name="batch_size"></param> | |||||
| /// <param name="epochs"></param> | |||||
| /// <param name="verbose"></param> | |||||
| /// <param name="validation_split"></param> | |||||
| /// <param name="shuffle"></param> | |||||
| public void fit(NDArray x, NDArray y, | |||||
| int batch_size = -1, | |||||
| int epochs = 1, | |||||
| int verbose = 1, | |||||
| float validation_split = 0f, | |||||
| bool shuffle = true, | |||||
| int initial_epoch = 0, | |||||
| int max_queue_size = 10, | |||||
| int workers = 1, | |||||
| bool use_multiprocessing = false) | |||||
| { | |||||
| int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); | |||||
| var train_x = x[new Slice(0, train_count)]; | |||||
| var train_y = y[new Slice(0, train_count)]; | |||||
| var val_x = x[new Slice(train_count)]; | |||||
| var val_y = y[new Slice(train_count)]; | |||||
| var data_handler = new DataHandler(new DataHandlerArgs | |||||
| { | |||||
| X = train_x, | |||||
| Y = train_y, | |||||
| BatchSize = batch_size, | |||||
| InitialEpoch = initial_epoch, | |||||
| Epochs = epochs, | |||||
| Shuffle = shuffle, | |||||
| MaxQueueSize = max_queue_size, | |||||
| Workers = workers, | |||||
| UseMultiprocessing = use_multiprocessing, | |||||
| Model = this, | |||||
| StepsPerExecution = _steps_per_execution | |||||
| }); | |||||
| } | |||||
| void _configure_steps_per_execution(int steps_per_execution) | void _configure_steps_per_execution(int steps_per_execution) | ||||
| { | { | ||||
| _steps_per_execution = tf.Variable(steps_per_execution, | _steps_per_execution = tf.Variable(steps_per_execution, | ||||
| @@ -145,48 +68,18 @@ namespace Tensorflow.Keras.Engine | |||||
| aggregation: VariableAggregation.OnlyFirstReplica); | aggregation: VariableAggregation.OnlyFirstReplica); | ||||
| } | } | ||||
| public void compile(string optimizerName, ILossFunc lossName) | |||||
| { | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | |||||
| /// Generates output predictions for the input samples. | |||||
| /// </summary> | |||||
| /// <param name="x">Input samples</param> | |||||
| /// <param name="batch_size">Number of samples per batch</param> | |||||
| /// <param name="verbose">Verbosity mode</param> | |||||
| /// <param name="steps"> | |||||
| /// Total number of steps (batches of samples) | |||||
| /// before declaring the prediction round finished. | |||||
| /// </param> | |||||
| /// <param name="max_queue_size"></param> | |||||
| /// <param name="workers"></param> | |||||
| /// <param name="use_multiprocessing"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor predict(Tensor x, | |||||
| int batch_size = 32, | |||||
| int verbose = 0, | |||||
| int steps = -1, | |||||
| int max_queue_size = 10, | |||||
| int workers = 1, | |||||
| bool use_multiprocessing = false) | |||||
| public override List<IVariableV1> trainable_variables | |||||
| { | { | ||||
| var data_handler = new DataHandler(new DataHandlerArgs | |||||
| get | |||||
| { | { | ||||
| X = x, | |||||
| BatchSize = batch_size, | |||||
| StepsPerEpoch = steps, | |||||
| InitialEpoch = 0, | |||||
| Epochs = 1, | |||||
| MaxQueueSize = max_queue_size, | |||||
| Workers = workers, | |||||
| UseMultiprocessing = use_multiprocessing, | |||||
| Model = this, | |||||
| StepsPerExecution = _steps_per_execution | |||||
| }); | |||||
| throw new NotImplementedException(""); | |||||
| var variables = new List<IVariableV1>(); | |||||
| foreach (var layer in _layers) | |||||
| { | |||||
| if (layer.Trainable) | |||||
| variables.AddRange(layer.trainable_variables); | |||||
| } | |||||
| return variables; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,6 +10,24 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| public class LayersApi | public class LayersApi | ||||
| { | { | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="filters"></param> | |||||
| /// <param name="kernel_size"></param> | |||||
| /// <param name="strides"></param> | |||||
| /// <param name="padding"></param> | |||||
| /// <param name="data_format"></param> | |||||
| /// <param name="dilation_rate"></param> | |||||
| /// <param name="groups"></param> | |||||
| /// <param name="activation">tf.keras.activations</param> | |||||
| /// <param name="use_bias"></param> | |||||
| /// <param name="kernel_initializer"></param> | |||||
| /// <param name="bias_initializer"></param> | |||||
| /// <param name="kernel_regularizer"></param> | |||||
| /// <param name="bias_regularizer"></param> | |||||
| /// <param name="activity_regularizer"></param> | |||||
| /// <returns></returns> | |||||
| public Conv2D Conv2D(int filters, | public Conv2D Conv2D(int filters, | ||||
| TensorShape kernel_size = null, | TensorShape kernel_size = null, | ||||
| TensorShape strides = null, | TensorShape strides = null, | ||||
| @@ -17,7 +35,7 @@ namespace Tensorflow.Keras.Layers | |||||
| string data_format = null, | string data_format = null, | ||||
| TensorShape dilation_rate = null, | TensorShape dilation_rate = null, | ||||
| int groups = 1, | int groups = 1, | ||||
| string activation = null, | |||||
| Activation activation = null, | |||||
| bool use_bias = true, | bool use_bias = true, | ||||
| IInitializer kernel_initializer = null, | IInitializer kernel_initializer = null, | ||||
| IInitializer bias_initializer = null, | IInitializer bias_initializer = null, | ||||
| @@ -40,20 +58,27 @@ namespace Tensorflow.Keras.Layers | |||||
| BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | ||||
| BiasRegularizer = bias_regularizer, | BiasRegularizer = bias_regularizer, | ||||
| ActivityRegularizer = activity_regularizer, | ActivityRegularizer = activity_regularizer, | ||||
| Activation = GetActivationByName(activation) | |||||
| Activation = activation ?? tf.keras.activations.Linear | |||||
| }); | }); | ||||
| public Dense Dense(int units, | public Dense Dense(int units, | ||||
| string activation = "linear", | |||||
| Activation activation = null, | |||||
| TensorShape input_shape = null) | TensorShape input_shape = null) | ||||
| => new Dense(new DenseArgs | => new Dense(new DenseArgs | ||||
| { | { | ||||
| Units = units, | Units = units, | ||||
| Activation = GetActivationByName(activation), | |||||
| Activation = activation ?? tf.keras.activations.Linear, | |||||
| InputShape = input_shape | InputShape = input_shape | ||||
| }); | }); | ||||
| public Dropout Dropout(float rate, TensorShape noise_shape = null, int? seed = null) | |||||
| => new Dropout(new DropoutArgs | |||||
| { | |||||
| Rate = rate, | |||||
| NoiseShape = noise_shape, | |||||
| Seed = seed | |||||
| }); | |||||
| /// <summary> | /// <summary> | ||||
| /// Turns positive integers (indexes) into dense vectors of fixed size. | /// Turns positive integers (indexes) into dense vectors of fixed size. | ||||
| /// This layer can only be used as the first layer in a model. | /// This layer can only be used as the first layer in a model. | ||||
| @@ -121,6 +146,42 @@ namespace Tensorflow.Keras.Layers | |||||
| Padding = padding | Padding = padding | ||||
| }); | }); | ||||
| public Layer LSTM(int units, | |||||
| Activation activation = null, | |||||
| Activation recurrent_activation = null, | |||||
| bool use_bias = true, | |||||
| IInitializer kernel_initializer = null, | |||||
| IInitializer recurrent_initializer = null, | |||||
| IInitializer bias_initializer = null, | |||||
| bool unit_forget_bias = true, | |||||
| float dropout = 0f, | |||||
| float recurrent_dropout = 0f, | |||||
| int implementation = 2, | |||||
| bool return_sequences = false, | |||||
| bool return_state = false, | |||||
| bool go_backwards = false, | |||||
| bool stateful = false, | |||||
| bool time_major = false, | |||||
| bool unroll = false) | |||||
| => new LSTM(new LSTMArgs | |||||
| { | |||||
| Units = units, | |||||
| Activation = activation ?? tf.keras.activations.Tanh, | |||||
| RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, | |||||
| KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||||
| RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, | |||||
| BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||||
| Dropout = dropout, | |||||
| RecurrentDropout = recurrent_dropout, | |||||
| Implementation = implementation, | |||||
| ReturnSequences = return_sequences, | |||||
| ReturnState = return_state, | |||||
| GoBackwards = go_backwards, | |||||
| Stateful = stateful, | |||||
| TimeMajor = time_major, | |||||
| Unroll = unroll | |||||
| }); | |||||
| public Rescaling Rescaling(float scale, | public Rescaling Rescaling(float scale, | ||||
| float offset = 0, | float offset = 0, | ||||
| TensorShape input_shape = null) | TensorShape input_shape = null) | ||||