|
|
|
@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
/// <param name="callbacks"></param> |
|
|
|
/// <param name="verbose"></param> |
|
|
|
/// <param name="validation_split"></param> |
|
|
|
/// <param name="validation_data"></param> |
|
|
|
/// <param name="shuffle"></param> |
|
|
|
public ICallback fit(NDArray x, NDArray y, |
|
|
|
int batch_size = -1, |
|
|
|
@@ -29,6 +30,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
int verbose = 1, |
|
|
|
List<ICallback> callbacks = null, |
|
|
|
float validation_split = 0f, |
|
|
|
(NDArray val_x, NDArray val_y)? validation_data = null, |
|
|
|
bool shuffle = true, |
|
|
|
int initial_epoch = 0, |
|
|
|
int max_queue_size = 10, |
|
|
|
@@ -40,11 +42,17 @@ namespace Tensorflow.Keras.Engine |
|
|
|
throw new InvalidArgumentError( |
|
|
|
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); |
|
|
|
} |
|
|
|
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); |
|
|
|
var train_x = x[new Slice(0, train_count)]; |
|
|
|
var train_y = y[new Slice(0, train_count)]; |
|
|
|
var val_x = x[new Slice(train_count)]; |
|
|
|
var val_y = y[new Slice(train_count)]; |
|
|
|
|
|
|
|
var train_x = x; |
|
|
|
var train_y = y; |
|
|
|
|
|
|
|
if (validation_split != 0f && validation_data == null) |
|
|
|
{ |
|
|
|
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); |
|
|
|
train_x = x[new Slice(0, train_count)]; |
|
|
|
train_y = y[new Slice(0, train_count)]; |
|
|
|
validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]); |
|
|
|
} |
|
|
|
|
|
|
|
var data_handler = new DataHandler(new DataHandlerArgs |
|
|
|
{ |
|
|
|
@@ -61,7 +69,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
StepsPerExecution = _steps_per_execution |
|
|
|
}); |
|
|
|
|
|
|
|
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, |
|
|
|
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, |
|
|
|
train_step_func: train_step_function); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -71,6 +79,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
int verbose = 1, |
|
|
|
List<ICallback> callbacks = null, |
|
|
|
float validation_split = 0f, |
|
|
|
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null, |
|
|
|
bool shuffle = true, |
|
|
|
int initial_epoch = 0, |
|
|
|
int max_queue_size = 10, |
|
|
|
@@ -85,12 +94,19 @@ namespace Tensorflow.Keras.Engine |
|
|
|
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); |
|
|
|
} |
|
|
|
} |
|
|
|
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); |
|
|
|
|
|
|
|
var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor); |
|
|
|
var train_y = y[new Slice(0, train_count)]; |
|
|
|
var val_x = x.Select(x => x[new Slice(train_count)] as Tensor); |
|
|
|
var val_y = y[new Slice(train_count)]; |
|
|
|
|
|
|
|
var train_x = x; |
|
|
|
var train_y = y; |
|
|
|
if (validation_split != 0f && validation_data == null) |
|
|
|
{ |
|
|
|
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); |
|
|
|
train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); |
|
|
|
train_y = y[new Slice(0, train_count)]; |
|
|
|
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); |
|
|
|
var val_y = y[new Slice(train_count)]; |
|
|
|
validation_data = (val_x, val_y); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
var data_handler = new DataHandler(new DataHandlerArgs |
|
|
|
{ |
|
|
|
@@ -110,29 +126,30 @@ namespace Tensorflow.Keras.Engine |
|
|
|
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || |
|
|
|
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) |
|
|
|
{ |
|
|
|
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, |
|
|
|
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, |
|
|
|
train_step_func: train_step_multi_inputs_function); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, |
|
|
|
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, |
|
|
|
train_step_func: train_step_function); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public History fit(IDatasetV2 dataset, |
|
|
|
IDatasetV2 validation_data = null, |
|
|
|
int batch_size = -1, |
|
|
|
int epochs = 1, |
|
|
|
int verbose = 1, |
|
|
|
List<ICallback> callbacks = null, |
|
|
|
float validation_split = 0f, |
|
|
|
//float validation_split = 0f, |
|
|
|
IDatasetV2 validation_data = null, |
|
|
|
bool shuffle = true, |
|
|
|
int initial_epoch = 0, |
|
|
|
int max_queue_size = 10, |
|
|
|
int workers = 1, |
|
|
|
bool use_multiprocessing = false) |
|
|
|
{ |
|
|
|
|
|
|
|
var data_handler = new DataHandler(new DataHandlerArgs |
|
|
|
{ |
|
|
|
Dataset = dataset, |
|
|
|
@@ -146,7 +163,10 @@ namespace Tensorflow.Keras.Engine |
|
|
|
Model = this, |
|
|
|
StepsPerExecution = _steps_per_execution |
|
|
|
}); |
|
|
|
foreach( var (x,y) in dataset) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, |
|
|
|
train_step_func: train_step_function); |
|
|
|
} |
|
|
|
@@ -178,11 +198,13 @@ namespace Tensorflow.Keras.Engine |
|
|
|
callbacks.on_epoch_begin(epoch); |
|
|
|
// data_handler.catch_stop_iteration(); |
|
|
|
var logs = new Dictionary<string, float>(); |
|
|
|
long End_step = 0; |
|
|
|
foreach (var step in data_handler.steps()) |
|
|
|
{ |
|
|
|
callbacks.on_train_batch_begin(step); |
|
|
|
logs = train_step_func(data_handler, iterator); |
|
|
|
var end_step = step + data_handler.StepIncrement; |
|
|
|
End_step = end_step; |
|
|
|
callbacks.on_train_batch_end(end_step, logs); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -193,6 +215,123 @@ namespace Tensorflow.Keras.Engine |
|
|
|
{ |
|
|
|
logs["val_" + log.Key] = log.Value; |
|
|
|
} |
|
|
|
callbacks.on_train_batch_end(End_step, logs); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
callbacks.on_epoch_end(epoch, logs); |
|
|
|
|
|
|
|
GC.Collect(); |
|
|
|
GC.WaitForPendingFinalizers(); |
|
|
|
} |
|
|
|
|
|
|
|
return callbacks.History; |
|
|
|
} |
|
|
|
|
|
|
|
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data, |
|
|
|
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) |
|
|
|
{ |
|
|
|
stop_training = false; |
|
|
|
_train_counter.assign(0); |
|
|
|
var callbacks = new CallbackList(new CallbackParams |
|
|
|
{ |
|
|
|
Model = this, |
|
|
|
Verbose = verbose, |
|
|
|
Epochs = epochs, |
|
|
|
Steps = data_handler.Inferredsteps |
|
|
|
}); |
|
|
|
|
|
|
|
if (callbackList != null) |
|
|
|
{ |
|
|
|
foreach (var callback in callbackList) |
|
|
|
callbacks.callbacks.add(callback); |
|
|
|
} |
|
|
|
|
|
|
|
callbacks.on_train_begin(); |
|
|
|
|
|
|
|
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
|
|
|
{ |
|
|
|
reset_metrics(); |
|
|
|
callbacks.on_epoch_begin(epoch); |
|
|
|
// data_handler.catch_stop_iteration(); |
|
|
|
var logs = new Dictionary<string, float>(); |
|
|
|
long End_step = 0; |
|
|
|
foreach (var step in data_handler.steps()) |
|
|
|
{ |
|
|
|
callbacks.on_train_batch_begin(step); |
|
|
|
logs = train_step_func(data_handler, iterator); |
|
|
|
var end_step = step + data_handler.StepIncrement; |
|
|
|
End_step = end_step; |
|
|
|
callbacks.on_train_batch_end(end_step, logs); |
|
|
|
} |
|
|
|
|
|
|
|
if (validation_data != null) |
|
|
|
{ |
|
|
|
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen |
|
|
|
// so we need to pass a is_val parameter to stop on_test_batch_end |
|
|
|
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); |
|
|
|
foreach (var log in val_logs) |
|
|
|
{ |
|
|
|
logs["val_" + log.Key] = log.Value; |
|
|
|
} |
|
|
|
// because after evaluate, logs add some new log which we need to print |
|
|
|
callbacks.on_train_batch_end(End_step, logs); |
|
|
|
} |
|
|
|
|
|
|
|
callbacks.on_epoch_end(epoch, logs); |
|
|
|
|
|
|
|
GC.Collect(); |
|
|
|
GC.WaitForPendingFinalizers(); |
|
|
|
} |
|
|
|
|
|
|
|
return callbacks.History; |
|
|
|
} |
|
|
|
|
|
|
|
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data, |
|
|
|
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) |
|
|
|
{ |
|
|
|
stop_training = false; |
|
|
|
_train_counter.assign(0); |
|
|
|
var callbacks = new CallbackList(new CallbackParams |
|
|
|
{ |
|
|
|
Model = this, |
|
|
|
Verbose = verbose, |
|
|
|
Epochs = epochs, |
|
|
|
Steps = data_handler.Inferredsteps |
|
|
|
}); |
|
|
|
|
|
|
|
if (callbackList != null) |
|
|
|
{ |
|
|
|
foreach (var callback in callbackList) |
|
|
|
callbacks.callbacks.add(callback); |
|
|
|
} |
|
|
|
|
|
|
|
callbacks.on_train_begin(); |
|
|
|
|
|
|
|
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
|
|
|
{ |
|
|
|
reset_metrics(); |
|
|
|
callbacks.on_epoch_begin(epoch); |
|
|
|
// data_handler.catch_stop_iteration(); |
|
|
|
var logs = new Dictionary<string, float>(); |
|
|
|
long End_step = 0; |
|
|
|
foreach (var step in data_handler.steps()) |
|
|
|
{ |
|
|
|
callbacks.on_train_batch_begin(step); |
|
|
|
logs = train_step_func(data_handler, iterator); |
|
|
|
var end_step = step + data_handler.StepIncrement; |
|
|
|
End_step = end_step; |
|
|
|
callbacks.on_train_batch_end(end_step, logs); |
|
|
|
} |
|
|
|
|
|
|
|
if (validation_data != null) |
|
|
|
{ |
|
|
|
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2); |
|
|
|
foreach (var log in val_logs) |
|
|
|
{ |
|
|
|
logs["val_" + log.Key] = log.Value; |
|
|
|
callbacks.on_train_batch_end(End_step, logs); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
callbacks.on_epoch_end(epoch, logs); |
|
|
|
|