Browse Source

Add double to NDArrayConverter.

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
0ee50d319e
3 changed files with 45 additions and 53 deletions
  1. +11
    -6
      src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
  2. +34
    -46
      src/TensorFlowNET.Keras/Engine/Model.Predict.cs
  3. +0
    -1
      src/TensorFlowNET.Keras/Engine/Model.cs

+ 11
- 6
src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs View File

@@ -14,7 +14,8 @@ namespace Tensorflow.NumPy
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data), TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data), TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),
TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data), TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data),
_ => throw new NotImplementedException("")
TF_DataType.TF_DOUBLE => Scalar<T>(*(double*)nd.data),
_ => throw new NotImplementedException(nameof(NDArrayConverter))
}; };


static T Scalar<T>(byte input) static T Scalar<T>(byte input)
@@ -23,7 +24,8 @@ namespace Tensorflow.NumPy
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
_ => throw new NotImplementedException(nameof(NDArrayConverter))
}; };


static T Scalar<T>(float input) static T Scalar<T>(float input)
@@ -32,7 +34,8 @@ namespace Tensorflow.NumPy
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
_ => throw new NotImplementedException(nameof(NDArrayConverter))
}; };


static T Scalar<T>(int input) static T Scalar<T>(int input)
@@ -41,7 +44,8 @@ namespace Tensorflow.NumPy
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64), TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64),
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
_ => throw new NotImplementedException(nameof(NDArrayConverter))
}; };


static T Scalar<T>(long input) static T Scalar<T>(long input)
@@ -50,7 +54,8 @@ namespace Tensorflow.NumPy
TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte), TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
TypeCode.Double => (T)Convert.ChangeType(input, TypeCode.Double),
_ => throw new NotImplementedException(nameof(NDArrayConverter))
}; };


public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
@@ -65,7 +70,7 @@ namespace Tensorflow.NumPy
T[,,,] array => Addr(array), T[,,,] array => Addr(array),
T[,,,,] array => Addr(array), T[,,,,] array => Addr(array),
T[,,,,,] array => Addr(array), T[,,,,,] array => Addr(array),
_ => throw new NotImplementedException("")
_ => throw new NotImplementedException(nameof(NDArrayConverter))
}; };


System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize); System.Buffer.MemoryCopy(nd.data.ToPointer(), addr, nd.bytesize, nd.bytesize);


+ 34
- 46
src/TensorFlowNET.Keras/Engine/Model.Predict.cs View File

@@ -1,5 +1,4 @@
using Tensorflow.NumPy;
using System;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
@@ -33,40 +32,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution StepsPerExecution = _steps_per_execution
}); });


var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = 1,
Steps = data_handler.Inferredsteps
});

Tensor batch_outputs = null;
_predict_counter.assign(0);
callbacks.on_predict_begin();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
foreach (var step in data_handler.steps())
{
callbacks.on_predict_batch_begin(step);
var tmp_batch_outputs = run_predict_step(iterator);
if (batch_outputs == null)
{
batch_outputs = tmp_batch_outputs[0];
}
else
{
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
}

var end_step = step + data_handler.StepIncrement;
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
}
GC.Collect();
}
callbacks.on_predict_end();
return batch_outputs;
return PredictInternal(data_handler, verbose);
} }


/// <summary> /// <summary>
@@ -105,23 +71,45 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution StepsPerExecution = _steps_per_execution
}); });


Tensors outputs = null;
return PredictInternal(data_handler, verbose);
}

Tensors PredictInternal(DataHandler data_handler, int verbose)
{
var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = 1,
Steps = data_handler.Inferredsteps
});

Tensor batch_outputs = null;
_predict_counter.assign(0); _predict_counter.assign(0);
// callbacks.on_predict_begin()
callbacks.on_predict_begin();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{ {
foreach(var step in data_handler.steps())
foreach (var step in data_handler.steps())
{ {
// callbacks.on_predict_batch_begin(step)
var batch_outputs = run_predict_step(iterator);
outputs = batch_outputs;
callbacks.on_predict_batch_begin(step);
var tmp_batch_outputs = run_predict_step(iterator);
if (batch_outputs == null)
{
batch_outputs = tmp_batch_outputs[0];
}
else
{
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
}

var end_step = step + data_handler.StepIncrement; var end_step = step + data_handler.StepIncrement;
// callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
} }
GC.Collect();
} }
// callbacks.on_predict_end()
return outputs;

callbacks.on_predict_end();

return batch_outputs;
} }


Tensors run_predict_step(OwnedIterator iterator) Tensors run_predict_step(OwnedIterator iterator)


+ 0
- 1
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -36,7 +36,6 @@ namespace Tensorflow.Keras.Engine
IVariableV1 _predict_counter; IVariableV1 _predict_counter;
bool _base_model_initialized; bool _base_model_initialized;
bool stop_training; bool stop_training;
DataHandler data_handler;
public OptimizerV2 Optimizer public OptimizerV2 Optimizer
{ {


Loading…
Cancel
Save