@@ -18,6 +18,9 @@ namespace TensorFlowNET.Examples
         public int Priority => 8;
         public bool Enabled { get; set; } = true;
         public string Name => "K-means Clustering";
+        public int DataSize = 5000;
+        public int TestSize = 5000;
+        public int BatchSize = 100;
 
         Datasets mnist;
         NDArray full_data_x;
@@ -45,7 +48,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
             full_data_x = mnist.train.images;
         }
     }
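The three fields above expose sizes that were previously hard-coded. A minimal usage sketch, assuming the class is named KMeansClustering and follows the examples' Run() convention (the values are illustrative, not the new defaults):

    // Hypothetical caller: cap the MNIST load so the example finishes quickly.
    var kmeans = new KMeansClustering
    {
        Enabled = true,
        DataSize = 1000,  // forwarded to read_data_sets as validation_size
        TestSize = 500,   // forwarded to read_data_sets as test_size
        BatchSize = 50
    };
    kmeans.Run();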
@@ -21,8 +21,10 @@ namespace TensorFlowNET.Examples
         public string Name => "Logistic Regression";
 
         private float learning_rate = 0.01f;
-        private int training_epochs = 10;
-        private int batch_size = 100;
+        public int TrainingEpochs = 10;
+        public int DataSize = 5000;
+        public int TestSize = 5000;
+        public int BatchSize = 100;
         private int display_step = 1;
 
         Datasets mnist;
@@ -57,14 +59,14 @@ namespace TensorFlowNET.Examples
             sess.run(init);
 
             // Training cycle
-            foreach (var epoch in range(training_epochs))
+            foreach (var epoch in range(TrainingEpochs))
             {
                 var avg_cost = 0.0f;
-                var total_batch = mnist.train.num_examples / batch_size;
+                var total_batch = mnist.train.num_examples / BatchSize;
                 // Loop over all batches
                 foreach (var i in range(total_batch))
                 {
-                    var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
+                    var (batch_xs, batch_ys) = mnist.train.next_batch(BatchSize);
                     // Run optimization op (backprop) and cost op (to get loss value)
                     var result = sess.run(new object[] { optimizer, cost },
                         new FeedItem(x, batch_xs),
@@ -96,7 +98,7 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
         }
 
         public void SaveModel(Session sess)
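Because DataSize is forwarded as validation_size, and read_data_sets (changed below) now decodes validation_size * 2 training rows before splitting them in half, mnist.train ends up with DataSize examples, assuming the non-validation remainder is kept for training as in the stock MNIST helper. With the unit-test settings DataSize = 500 and BatchSize = 100, total_batch = 500 / 100 = 5 iterations per epoch.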
@@ -19,6 +19,8 @@ namespace TensorFlowNET.Examples
         public string Name => "Nearest Neighbor";
         Datasets mnist;
         NDArray Xtr, Ytr, Xte, Yte;
+        public int DataSize = 5000;
+        public int TestBatchSize = 200;
 
         public bool Run()
         {
@@ -62,10 +64,10 @@ namespace TensorFlowNET.Examples
         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);
             // In this example, we limit mnist data
-            (Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates)
-            (Xte, Yte) = mnist.test.next_batch(200); // 200 for testing
+            (Xtr, Ytr) = mnist.train.next_batch(DataSize); // DataSize samples for training (nn candidates)
+            (Xte, Yte) = mnist.test.next_batch(TestBatchSize); // TestBatchSize samples for testing
         }
     }
 }
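DataSize plays a double role here: it sets validation_size in read_data_sets, which under the train_size = validation_size * 2 rule below leaves exactly DataSize training rows, and it is also the next_batch count, so the nearest-neighbor candidate set is the whole reduced training split. test_size keeps its default of 5000, and only TestBatchSize samples are drawn from it.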
@@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
         private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
 
-        public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
+        public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
         {
             string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
             /*if (step == "train")
@@ -25,10 +25,11 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 char_dict[c.ToString()] = char_dict.Count;
 
             var contents = File.ReadAllLines(TRAIN_PATH);
-            var x = new int[contents.Length][];
-            var y = new int[contents.Length];
-            for (int i = 0; i < contents.Length; i++)
+            // Clamp so a limit larger than the file cannot overrun contents[].
+            var size = limit == null ? contents.Length : Math.Min(contents.Length, limit.Value);
+            var x = new int[size][];
+            var y = new int[size];
+            for (int i = 0; i < size; i++)
             {
                 string[] parts = contents[i].ToLower().Split(",\"").ToArray();
                 string content = parts[2];
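A sketch of how the new limit parameter might be called on its own; the document_max_len value here is illustrative (the caller in this change passes CHAR_MAX_LEN):

    // Hypothetical smoke test: decode only the first 100 CSV rows.
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset(
        step: "train",
        model: "vdcnn",
        document_max_len: 1014,  // illustrative; TextClassificationTrain uses CHAR_MAX_LEN
        limit: 100);
    Console.WriteLine($"{x.Length} documents, alphabet size {alphabet_size}");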
@@ -17,6 +17,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         public int Priority => 100;
         public bool Enabled { get; set; } = false;
         public string Name => "Text Classification";
+        public int? DataLimit = null;
         private string dataDir = "text_classification";
         private string dataFileName = "dbpedia_csv.tar.gz";
@@ -28,7 +29,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         {
             PrepareData();
             Console.WriteLine("Building dataset...");
-            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
+            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit);
 
             var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
@@ -21,23 +21,26 @@ namespace TensorFlowNET.Examples.Utility
             TF_DataType dtype = TF_DataType.TF_FLOAT,
             bool reshape = true,
             int validation_size = 5000,
+            int test_size = 5000,
             string source_url = DEFAULT_SOURCE_URL)
         {
+            var train_size = validation_size * 2;
+
             Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
             Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
-            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));
+            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size);
 
             Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
             Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
-            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);
+            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size);
 
             Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
             Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
-            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));
+            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size);
 
             Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
             Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
-            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);
+            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit: test_size);
 
             int end = train_images.shape[0];
             var validation_images = train_images[np.arange(validation_size)];
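The train_size = validation_size * 2 line is what makes the new knobs consistent: the first validation_size decoded rows become the validation set (see the np.arange slice above) and the remaining validation_size rows are left as training data, so callers receive exactly validation_size training examples. The full archives are still downloaded; only decoding is truncated.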
@@ -52,14 +55,15 @@ namespace TensorFlowNET.Examples.Utility
             return new Datasets(train, validation, test);
         }
 
-        public static NDArray extract_images(string file)
+        public static NDArray extract_images(string file, int? limit = null)
         {
             using (var bytestream = new FileStream(file, FileMode.Open))
             {
                 var magic = _read32(bytestream);
                 if (magic != 2051)
                     throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
                 var num_images = _read32(bytestream);
+                num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
                 var rows = _read32(bytestream);
                 var cols = _read32(bytestream);
                 var buf = new byte[rows * cols * num_images];
@@ -70,7 +74,7 @@ namespace TensorFlowNET.Examples.Utility
             }
         }
 
-        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
+        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
         {
             using (var bytestream = new FileStream(file, FileMode.Open))
             {
@@ -78,6 +82,7 @@ namespace TensorFlowNET.Examples.Utility
                 if (magic != 2049)
                     throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
                 var num_items = _read32(bytestream);
+                num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
                 var buf = new byte[num_items];
                 bytestream.Read(buf, 0, buf.Length);
                 var labels = np.frombuffer(buf, np.uint8);
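In both extract_images and extract_labels the limit is clamped with Math.Min against the record count declared in the IDX header, so a limit larger than the file simply reads everything. And because num_images and num_items size the buf allocation and the bytestream.Read call, truncation also cuts the bytes actually read: limit: 1000 on the 60,000-image training file reads 1000 × 28 × 28 ≈ 0.78 MB instead of roughly 47 MB.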
@@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
         [TestMethod]
         public void LogisticRegression()
         {
-            new LogisticRegression() { Enabled = true }.Run();
+            new LogisticRegression() { Enabled = true, TrainingEpochs = 10, DataSize = 500, TestSize = 500 }.Run();
         }
 
         [Ignore]
@@ -78,14 +78,14 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
         [TestMethod]
         public void NearestNeighbor()
         {
-            new NearestNeighbor() { Enabled = true }.Run();
+            new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();
         }
 
         [Ignore]
         [TestMethod]
         public void TextClassificationTrain()
         {
-            new TextClassificationTrain() { Enabled = true }.Run();
+            new TextClassificationTrain() { Enabled = true, DataLimit = 100 }.Run();
         }
 
         [Ignore]