diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs index 2d042a5d..c8c2d45f 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs @@ -14,7 +14,8 @@ namespace Tensorflow.NumPy TF_DataType.TF_FLOAT => Scalar(*(float*)nd.data), TF_DataType.TF_INT32 => Scalar(*(int*)nd.data), TF_DataType.TF_INT64 => Scalar(*(long*)nd.data), - _ => throw new NotImplementedException("") + TF_DataType.TF_DOUBLE => Scalar(*(double*)nd.data), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) }; static T Scalar(byte input) @@ -23,7 +24,8 @@ namespace Tensorflow.NumPy TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), - _ => throw new NotImplementedException("") + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) }; static T Scalar(float input) @@ -32,7 +34,8 @@ namespace Tensorflow.NumPy TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), - _ => throw new NotImplementedException("") + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) }; static T Scalar(int input) @@ -41,7 +44,8 @@ namespace Tensorflow.NumPy TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), - _ => throw new NotImplementedException("") + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) }; static T Scalar(long input) @@ -50,7 +54,8 @@ namespace Tensorflow.NumPy TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), - _ => throw new NotImplementedException("") + TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double), + _ => throw new NotImplementedException(nameof(NDArrayConverter)) }; public static unsafe Array ToMultiDimArray(NDArray nd) where T : unmanaged @@ -65,7 +70,7 @@ namespace Tensorflow.NumPy T[,,,] array => Addr(array), T[,,,,] array => Addr(array), T[,,,,,] array => Addr(array), - _ => throw new NotImplementedException("") + _ => throw new NotImplementedException(nameof(NDArrayConverter)) }; System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs index 4d5755b0..c27ea909 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -1,5 +1,4 @@ -using Tensorflow.NumPy; -using System; +using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; @@ -33,40 +32,7 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - var callbacks = new CallbackList(new CallbackParams - { - Model = this, - Verbose = verbose, - Epochs = 1, - Steps = data_handler.Inferredsteps - }); - - Tensor batch_outputs = null; - _predict_counter.assign(0); - callbacks.on_predict_begin(); - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) - { - foreach (var step in data_handler.steps()) - { - callbacks.on_predict_batch_begin(step); - var tmp_batch_outputs = run_predict_step(iterator); - if (batch_outputs == null) - { - batch_outputs = tmp_batch_outputs[0]; - } - else - { - batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0); - } - - var end_step = step + data_handler.StepIncrement; - callbacks.on_predict_batch_end(end_step, new Dictionary { { "outputs", batch_outputs } }); - } - GC.Collect(); - } - - callbacks.on_predict_end(); - return batch_outputs; + return PredictInternal(data_handler, verbose); } /// @@ -105,23 +71,45 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - Tensors outputs = null; + return PredictInternal(data_handler, verbose); + } + + Tensors PredictInternal(DataHandler data_handler, int verbose) + { + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Epochs = 1, + Steps = data_handler.Inferredsteps + }); + + Tensor batch_outputs = null; _predict_counter.assign(0); - // callbacks.on_predict_begin() + callbacks.on_predict_begin(); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { - foreach(var step in data_handler.steps()) + foreach (var step in data_handler.steps()) { - // callbacks.on_predict_batch_begin(step) - var batch_outputs = run_predict_step(iterator); - outputs = batch_outputs; + callbacks.on_predict_batch_begin(step); + var tmp_batch_outputs = run_predict_step(iterator); + if (batch_outputs == null) + { + batch_outputs = tmp_batch_outputs[0]; + } + else + { + batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0); + } + var end_step = step + data_handler.StepIncrement; - // callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs}) + callbacks.on_predict_batch_end(end_step, new Dictionary { { "outputs", batch_outputs } }); } - GC.Collect(); } - // callbacks.on_predict_end() - return outputs; + + callbacks.on_predict_end(); + + return batch_outputs; } Tensors run_predict_step(OwnedIterator iterator) diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index dfe5b05f..dd3e11a2 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -36,7 +36,6 @@ namespace Tensorflow.Keras.Engine IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; - DataHandler data_handler; public OptimizerV2 Optimizer {