| @@ -12,7 +12,7 @@ namespace Tensorflow.Hub | |||||
| public int EpochsCompleted { get; private set; } | public int EpochsCompleted { get; private set; } | ||||
| public int IndexInEpoch { get; private set; } | public int IndexInEpoch { get; private set; } | ||||
| public MnistDataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | |||||
| public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape) | |||||
| { | { | ||||
| EpochsCompleted = 0; | EpochsCompleted = 0; | ||||
| IndexInEpoch = 0; | IndexInEpoch = 0; | ||||
| @@ -20,11 +20,11 @@ namespace Tensorflow.Hub | |||||
| NumOfExamples = images.shape[0]; | NumOfExamples = images.shape[0]; | ||||
| images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | ||||
| images.astype(dtype.as_numpy_datatype()); | |||||
| images.astype(dataType); | |||||
| images = np.multiply(images, 1.0f / 255.0f); | images = np.multiply(images, 1.0f / 255.0f); | ||||
| Data = images; | Data = images; | ||||
| labels.astype(dtype.as_numpy_datatype()); | |||||
| labels.astype(dataType); | |||||
| Labels = labels; | Labels = labels; | ||||
| } | } | ||||
| } | } | ||||
| @@ -81,7 +81,7 @@ namespace Tensorflow.Hub | |||||
| trainImages = trainImages[np.arange(validationSize, end)]; | trainImages = trainImages[np.arange(validationSize, end)]; | ||||
| trainLabels = trainLabels[np.arange(validationSize, end)]; | trainLabels = trainLabels[np.arange(validationSize, end)]; | ||||
| var dtype = setting.DtType; | |||||
| var dtype = setting.DataType; | |||||
| var reshape = setting.ReShape; | var reshape = setting.ReShape; | ||||
| var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); | var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); | ||||
| @@ -100,7 +100,7 @@ namespace Tensorflow.Hub | |||||
| { | { | ||||
| var magic = Read32(bytestream); | var magic = Read32(bytestream); | ||||
| if (magic != 2051) | if (magic != 2051) | ||||
| throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}"); | |||||
| throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}"); | |||||
| var num_images = Read32(bytestream); | var num_images = Read32(bytestream); | ||||
| num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); | num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); | ||||
| @@ -128,7 +128,7 @@ namespace Tensorflow.Hub | |||||
| { | { | ||||
| var magic = Read32(bytestream); | var magic = Read32(bytestream); | ||||
| if (magic != 2049) | if (magic != 2049) | ||||
| throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}"); | |||||
| throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}"); | |||||
| var num_items = Read32(bytestream); | var num_items = Read32(bytestream); | ||||
| num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit); | num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit); | ||||
| @@ -9,7 +9,7 @@ namespace Tensorflow.Hub | |||||
| { | { | ||||
| public string TrainDir { get; set; } | public string TrainDir { get; set; } | ||||
| public bool OneHot { get; set; } | public bool OneHot { get; set; } | ||||
| public TF_DataType DtType { get; set; } = TF_DataType.TF_FLOAT; | |||||
| public Type DataType { get; set; } = typeof(float); | |||||
| public bool ReShape { get; set; } | public bool ReShape { get; set; } | ||||
| public int ValidationSize { get; set; } = 5000; | public int ValidationSize { get; set; } = 5000; | ||||
| public int? TrainSize { get; set; } | public int? TrainSize { get; set; } | ||||
| @@ -4,6 +4,6 @@ | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||||
| <PackageReference Include="NumSharp" Version="0.10.6" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| </Project> | </Project> | ||||