diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowHub/MnistDataSet.cs
index 65b8ebae..e0717ccb 100644
--- a/src/TensorFlowHub/MnistDataSet.cs
+++ b/src/TensorFlowHub/MnistDataSet.cs
@@ -12,7 +12,7 @@ namespace Tensorflow.Hub
public int EpochsCompleted { get; private set; }
public int IndexInEpoch { get; private set; }
- public MnistDataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
+ public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
{
EpochsCompleted = 0;
IndexInEpoch = 0;
@@ -20,11 +20,11 @@ namespace Tensorflow.Hub
NumOfExamples = images.shape[0];
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
- images.astype(dtype.as_numpy_datatype());
+ images.astype(dataType);
images = np.multiply(images, 1.0f / 255.0f);
Data = images;
- labels.astype(dtype.as_numpy_datatype());
+ labels.astype(dataType);
Labels = labels;
}
}
diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs
index 670fb0ac..7c4ff109 100644
--- a/src/TensorFlowHub/MnistModelLoader.cs
+++ b/src/TensorFlowHub/MnistModelLoader.cs
@@ -81,7 +81,7 @@ namespace Tensorflow.Hub
trainImages = trainImages[np.arange(validationSize, end)];
trainLabels = trainLabels[np.arange(validationSize, end)];
- var dtype = setting.DtType;
+ var dtype = setting.DataType;
var reshape = setting.ReShape;
var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
@@ -100,7 +100,7 @@ namespace Tensorflow.Hub
{
var magic = Read32(bytestream);
if (magic != 2051)
- throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
+ throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}");
var num_images = Read32(bytestream);
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
@@ -128,7 +128,7 @@ namespace Tensorflow.Hub
{
var magic = Read32(bytestream);
if (magic != 2049)
- throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
+ throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}");
var num_items = Read32(bytestream);
num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
diff --git a/src/TensorFlowHub/ModelLoadSetting.cs b/src/TensorFlowHub/ModelLoadSetting.cs
index 91b4059c..89e46748 100644
--- a/src/TensorFlowHub/ModelLoadSetting.cs
+++ b/src/TensorFlowHub/ModelLoadSetting.cs
@@ -9,7 +9,7 @@ namespace Tensorflow.Hub
{
public string TrainDir { get; set; }
public bool OneHot { get; set; }
- public TF_DataType DtType { get; set; } = TF_DataType.TF_FLOAT;
+ public Type DataType { get; set; } = typeof(float);
public bool ReShape { get; set; }
public int ValidationSize { get; set; } = 5000;
public int? TrainSize { get; set; }
diff --git a/src/TensorFlowHub/TensorFlowHub.csproj b/src/TensorFlowHub/TensorFlowHub.csproj
index 9288ad1b..ccdcef5a 100644
--- a/src/TensorFlowHub/TensorFlowHub.csproj
+++ b/src/TensorFlowHub/TensorFlowHub.csproj
@@ -4,6 +4,6 @@
netstandard2.0
-
+