| @@ -15,6 +15,16 @@ namespace Tensorflow.Hub | |||||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
| public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false) | |||||
| { | |||||
| var loader = new MnistModelLoader(); | |||||
| return await loader.LoadAsync(new ModelLoadSetting | |||||
| { | |||||
| TrainDir = trainDir, | |||||
| OneHot = oneHot | |||||
| }); | |||||
| } | |||||
| public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | ||||
| { | { | ||||
| if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) | if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) | ||||