using Tensorflow.NumPy; using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using System.Diagnostics; using Tensorflow.Keras.Callbacks; using System.Data; namespace Tensorflow.Keras.Engine { public partial class Model { /// /// Trains the model for a fixed number of epochs (iterations on a dataset). /// /// /// /// /// /// /// /// public ICallback fit(NDArray x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, float validation_split = 0f, bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, bool use_multiprocessing = false) { if (x.dims[0] != y.dims[0]) { throw new InvalidArgumentError( $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); } int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); var train_x = x[new Slice(0, train_count)]; var train_y = y[new Slice(0, train_count)]; var val_x = x[new Slice(train_count)]; var val_y = y[new Slice(train_count)]; var data_handler = new DataHandler(new DataHandlerArgs { X = train_x, Y = train_y, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, Shuffle = shuffle, MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, Model = this, StepsPerExecution = _steps_per_execution }); return FitInternal(data_handler, epochs, verbose); } public History fit(IDatasetV2 dataset, IDatasetV2 validation_data = null, int batch_size = -1, int epochs = 1, int verbose = 1, float validation_split = 0f, bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, bool use_multiprocessing = false) { var data_handler = new DataHandler(new DataHandlerArgs { Dataset = dataset, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, Shuffle = shuffle, MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, Model = this, StepsPerExecution = _steps_per_execution }); return FitInternal(data_handler, epochs, verbose, validation_data: validation_data); } History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null) { stop_training = false; _train_counter.assign(0); var callbacks = new CallbackList(new CallbackParams { Model = this, Verbose = verbose, Epochs = epochs, Steps = data_handler.Inferredsteps }); callbacks.on_train_begin(); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); callbacks.on_epoch_begin(epoch); // data_handler.catch_stop_iteration(); var logs = new Dictionary(); foreach (var step in data_handler.steps()) { callbacks.on_train_batch_begin(step); logs = train_step_function(data_handler, iterator); var end_step = step + data_handler.StepIncrement; callbacks.on_train_batch_end(end_step, logs); } if (validation_data != null) { var val_logs = evaluate(validation_data); foreach(var log in val_logs) { logs["val_" + log.Key] = log.Value; } } callbacks.on_epoch_end(epoch, logs); GC.Collect(); GC.WaitForPendingFinalizers(); } return callbacks.History; } } }