| @@ -293,7 +293,8 @@ namespace Tensorflow | |||||
| // c_api.TF_CloseSession(handle, tf.Status.Handle); | // c_api.TF_CloseSession(handle, tf.Status.Handle); | ||||
| if (tf.Status == null || tf.Status.Handle.IsInvalid) | if (tf.Status == null || tf.Status.Handle.IsInvalid) | ||||
| { | { | ||||
| c_api.TF_DeleteSession(handle, c_api.TF_NewStatus()); | |||||
| using var status = new Status(); | |||||
| c_api.TF_DeleteSession(handle, status.Handle); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -39,5 +39,25 @@ namespace Tensorflow.Keras.Callbacks | |||||
| { | { | ||||
| callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); | callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); | ||||
| } | } | ||||
/// <summary>Notifies every registered callback that prediction is starting.</summary>
public void on_predict_begin()
{
    foreach (var callback in callbacks)
    {
        callback.on_predict_begin();
    }
}
/// <summary>Notifies every registered callback that predict step <paramref name="step"/> is about to run.</summary>
public void on_predict_batch_begin(long step)
{
    foreach (var callback in callbacks)
    {
        callback.on_predict_batch_begin(step);
    }
}
/// <summary>Forwards the end of a predict step, with its output logs, to every registered callback.</summary>
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
    foreach (var callback in callbacks)
    {
        callback.on_predict_batch_end(end_step, logs);
    }
}
/// <summary>Notifies every registered callback that prediction has finished.</summary>
public void on_predict_end()
{
    foreach (var callback in callbacks)
    {
        callback.on_predict_end();
    }
}
| } | } | ||||
| } | } | ||||
| @@ -48,5 +48,26 @@ namespace Tensorflow.Keras.Callbacks | |||||
| history[log.Key].Add((float)log.Value); | history[log.Key].Add((float)log.Value); | ||||
| } | } | ||||
| } | } | ||||
// Resets the recorded epoch numbers and metric history when prediction starts.
// NOTE(review): this discards any history accumulated during a preceding fit();
// confirm that wiping training history on predict is intended rather than a no-op.
public void on_predict_begin()
{
    epochs = new List<int>();
    history = new Dictionary<string, List<float>>();
}
// Intentionally a no-op: History does not track per-batch prediction state.
public void on_predict_batch_begin(long step)
{
}
// Intentionally a no-op: prediction outputs are not recorded into the history.
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
}
// Intentionally a no-op: nothing to finalize for History when prediction ends.
public void on_predict_end()
{
}
| } | } | ||||
| } | } | ||||
| @@ -11,5 +11,9 @@ namespace Tensorflow.Keras.Callbacks | |||||
void on_train_batch_begin(long step);
void on_train_batch_end(long end_step, Dictionary<string, float> logs);
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs);
/// <summary>Called once before prediction starts.</summary>
void on_predict_begin();
/// <summary>Called before the predict step for <paramref name="step"/> runs.</summary>
void on_predict_batch_begin(long step);
/// <summary>
/// Called after a predict step; <paramref name="logs"/> carries the accumulated
/// batch outputs under the "outputs" key.
/// </summary>
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
/// <summary>Called once after prediction finishes.</summary>
void on_predict_end();
| } | } | ||||
| } | } | ||||
| @@ -1,5 +1,4 @@ | |||||
| using PureHDF; | |||||
| using System; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| @@ -77,5 +76,26 @@ namespace Tensorflow.Keras.Callbacks | |||||
| { | { | ||||
| } | } | ||||
// Prepares the progress bar for a fresh prediction run: first clears any state
// left over from a previous run, then lazily (re)initializes the bar.
// Call order matters — reset must precede init.
public void on_predict_begin()
{
    _reset_progbar();
    _maybe_init_progbar();
}
// Intentionally a no-op: the progress bar is only updated at batch end, not batch begin.
public void on_predict_batch_begin(long step)
{
}
// Intentionally a no-op here.
// NOTE(review): a progress logger would normally advance the bar in this hook;
// confirm whether skipping per-batch progress output during predict is deliberate.
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
{
}
// Intentionally a no-op: no final progress-bar flush is performed after predict.
public void on_predict_end()
{
}
| } | } | ||||
| } | } | ||||
| @@ -5,11 +5,70 @@ using System.Linq; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Keras.Callbacks; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| public partial class Model | public partial class Model | ||||
| { | { | ||||
/// <summary>
/// Generates output predictions for the samples in <paramref name="dataset"/>.
/// </summary>
/// <param name="dataset">Input dataset to run prediction over.</param>
/// <param name="batch_size">Batch size forwarded to the data handler; -1 keeps the dataset's own batching.</param>
/// <param name="verbose">Verbosity mode forwarded to the callbacks.</param>
/// <param name="steps">Total number of predict steps; -1 lets the data handler infer it.</param>
/// <param name="max_queue_size">Maximum generator queue size forwarded to the data handler.</param>
/// <param name="workers">Worker count forwarded to the data handler.</param>
/// <param name="use_multiprocessing">Whether the data handler may use multiprocessing.</param>
/// <returns>
/// The first output of every predict step, concatenated along axis 0;
/// null when the dataset yields no batches.
/// </returns>
public Tensors predict(IDatasetV2 dataset,
    int batch_size = -1,
    int verbose = 0,
    int steps = -1,
    int max_queue_size = 10,
    int workers = 1,
    bool use_multiprocessing = false)
{
    if (dataset == null)
        throw new ArgumentNullException(nameof(dataset));

    // Prediction is modeled as a single epoch over the dataset.
    var data_handler = new DataHandler(new DataHandlerArgs
    {
        Dataset = dataset,
        BatchSize = batch_size,
        StepsPerEpoch = steps,
        InitialEpoch = 0,
        Epochs = 1,
        MaxQueueSize = max_queue_size,
        Workers = workers,
        UseMultiprocessing = use_multiprocessing,
        Model = this,
        StepsPerExecution = _steps_per_execution
    });

    var callbacks = new CallbackList(new CallbackParams
    {
        Model = this,
        Verbose = verbose,
        Epochs = 1,
        Steps = data_handler.Inferredsteps
    });

    Tensor batch_outputs = null;
    _predict_counter.assign(0);
    callbacks.on_predict_begin();
    foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
    {
        foreach (var step in data_handler.steps())
        {
            callbacks.on_predict_batch_begin(step);
            var tmp_batch_outputs = run_predict_step(iterator);
            if (batch_outputs == null)
            {
                batch_outputs = tmp_batch_outputs[0];
            }
            else
            {
                // NOTE(review): concatenating incrementally is O(n^2) over the run;
                // accumulating batches in a list and concatenating once at the end
                // would be linear, but callbacks currently receive the running
                // accumulation, so the shape of this loop is kept as-is.
                batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
            }
            var end_step = step + data_handler.StepIncrement;
            callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
        }
        // Removed the per-epoch GC.Collect(): forcing a full collection in a hot
        // loop is an anti-pattern that stalls prediction without functional benefit.
    }
    callbacks.on_predict_end();
    return batch_outputs;
}
| /// <summary> | /// <summary> | ||||
| /// Generates output predictions for the input samples. | /// Generates output predictions for the input samples. | ||||
| /// </summary> | /// </summary> | ||||