unit test suite: added lightweight version of all examplestags/v0.9
| @@ -46,7 +46,7 @@ namespace Tensorflow | |||||
| catch (Exception ex) | catch (Exception ex) | ||||
| { | { | ||||
| Console.WriteLine(ex.ToString()); | Console.WriteLine(ex.ToString()); | ||||
| throw ex; | |||||
| throw; | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| @@ -65,7 +65,7 @@ namespace Tensorflow | |||||
| catch (Exception ex) | catch (Exception ex) | ||||
| { | { | ||||
| Console.WriteLine(ex.ToString()); | Console.WriteLine(ex.ToString()); | ||||
| throw ex; | |||||
| throw; | |||||
| } | } | ||||
| finally | finally | ||||
| { | { | ||||
| @@ -83,10 +83,8 @@ namespace Tensorflow | |||||
| } | } | ||||
| catch (Exception ex) | catch (Exception ex) | ||||
| { | { | ||||
| Console.WriteLine(ex.ToString()); | |||||
| #if DEBUG | |||||
| Debugger.Break(); | |||||
| #endif | |||||
| Console.WriteLine(ex.ToString()); | |||||
| throw; | |||||
| return default(TOut); | return default(TOut); | ||||
| } | } | ||||
| finally | finally | ||||
| @@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples | |||||
| public class BasicEagerApi : IExample | public class BasicEagerApi : IExample | ||||
| { | { | ||||
| public int Priority => 100; | public int Priority => 100; | ||||
| public bool Enabled => false; | |||||
| public bool Enabled { get; set; } = false; | |||||
| public string Name => "Basic Eager"; | public string Name => "Basic Eager"; | ||||
| private Tensor a, b, c, d; | private Tensor a, b, c, d; | ||||
| @@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples | |||||
| /// </summary> | /// </summary> | ||||
| public class BasicOperations : Python, IExample | public class BasicOperations : Python, IExample | ||||
| { | { | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public int Priority => 2; | public int Priority => 2; | ||||
| public string Name => "Basic Operations"; | public string Name => "Basic Operations"; | ||||
| @@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples | |||||
| public class HelloWorld : Python, IExample | public class HelloWorld : Python, IExample | ||||
| { | { | ||||
| public int Priority => 1; | public int Priority => 1; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "Hello World"; | public string Name => "Hello World"; | ||||
| public bool Run() | public bool Run() | ||||
| @@ -17,7 +17,7 @@ namespace TensorFlowNET.Examples | |||||
| /// <summary> | /// <summary> | ||||
| /// True to run example | /// True to run example | ||||
| /// </summary> | /// </summary> | ||||
| bool Enabled { get; } | |||||
| bool Enabled { get; set; } | |||||
| string Name { get; } | string Name { get; } | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples | |||||
| public class ImageRecognition : Python, IExample | public class ImageRecognition : Python, IExample | ||||
| { | { | ||||
| public int Priority => 7; | public int Priority => 7; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "Image Recognition"; | public string Name => "Image Recognition"; | ||||
| string dir = "ImageRecognition"; | string dir = "ImageRecognition"; | ||||
| @@ -19,7 +19,7 @@ namespace TensorFlowNET.Examples | |||||
| /// </summary> | /// </summary> | ||||
| public class InceptionArchGoogLeNet : Python, IExample | public class InceptionArchGoogLeNet : Python, IExample | ||||
| { | { | ||||
| public bool Enabled => false; | |||||
| public bool Enabled { get; set; } = false; | |||||
| public int Priority => 100; | public int Priority => 100; | ||||
| public string Name => "Inception Arch GoogLeNet"; | public string Name => "Inception Arch GoogLeNet"; | ||||
| @@ -16,13 +16,17 @@ namespace TensorFlowNET.Examples | |||||
| public class KMeansClustering : Python, IExample | public class KMeansClustering : Python, IExample | ||||
| { | { | ||||
| public int Priority => 8; | public int Priority => 8; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "K-means Clustering"; | public string Name => "K-means Clustering"; | ||||
| public int? train_size = null; | |||||
| public int validation_size = 5000; | |||||
| public int? test_size = null; | |||||
| public int batch_size = 1024; // The number of samples per batch | |||||
| Datasets mnist; | Datasets mnist; | ||||
| NDArray full_data_x; | NDArray full_data_x; | ||||
| int num_steps = 50; // Total steps to train | int num_steps = 50; // Total steps to train | ||||
| int batch_size = 1024; // The number of samples per batch | |||||
| int k = 25; // The number of clusters | int k = 25; // The number of clusters | ||||
| int num_classes = 10; // The 10 digits | int num_classes = 10; // The 10 digits | ||||
| int num_features = 784; // Each image is 28x28 pixels | int num_features = 784; // Each image is 28x28 pixels | ||||
| @@ -45,7 +49,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size); | |||||
| full_data_x = mnist.train.images; | full_data_x = mnist.train.images; | ||||
| } | } | ||||
| } | } | ||||
| @@ -13,16 +13,16 @@ namespace TensorFlowNET.Examples | |||||
| public class LinearRegression : Python, IExample | public class LinearRegression : Python, IExample | ||||
| { | { | ||||
| public int Priority => 3; | public int Priority => 3; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "Linear Regression"; | public string Name => "Linear Regression"; | ||||
| NumPyRandom rng = np.random; | |||||
| public int training_epochs = 1000; | |||||
| // Parameters | // Parameters | ||||
| float learning_rate = 0.01f; | float learning_rate = 0.01f; | ||||
| int training_epochs = 1000; | |||||
| int display_step = 50; | int display_step = 50; | ||||
| NumPyRandom rng = np.random; | |||||
| NDArray train_X, train_Y; | NDArray train_X, train_Y; | ||||
| int n_samples; | int n_samples; | ||||
| @@ -17,12 +17,16 @@ namespace TensorFlowNET.Examples | |||||
| public class LogisticRegression : Python, IExample | public class LogisticRegression : Python, IExample | ||||
| { | { | ||||
| public int Priority => 4; | public int Priority => 4; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "Logistic Regression"; | public string Name => "Logistic Regression"; | ||||
| public int training_epochs = 10; | |||||
| public int? train_size = null; | |||||
| public int validation_size = 5000; | |||||
| public int? test_size = null; | |||||
| public int batch_size = 100; | |||||
| private float learning_rate = 0.01f; | private float learning_rate = 0.01f; | ||||
| private int training_epochs = 10; | |||||
| private int batch_size = 100; | |||||
| private int display_step = 1; | private int display_step = 1; | ||||
| Datasets mnist; | Datasets mnist; | ||||
| @@ -96,7 +100,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size); | |||||
| } | } | ||||
| public void SaveModel(Session sess) | public void SaveModel(Session sess) | ||||
| @@ -139,7 +143,7 @@ namespace TensorFlowNET.Examples | |||||
| if (results.argmax() == (batch_ys[0] as NDArray).argmax()) | if (results.argmax() == (batch_ys[0] as NDArray).argmax()) | ||||
| print("predicted OK!"); | print("predicted OK!"); | ||||
| else | else | ||||
| throw new ValueError("predict error, maybe 90% accuracy"); | |||||
| throw new ValueError("predict error, should be 90% accuracy"); | |||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,7 +10,7 @@ namespace TensorFlowNET.Examples | |||||
| public class MetaGraph : Python, IExample | public class MetaGraph : Python, IExample | ||||
| { | { | ||||
| public int Priority => 100; | public int Priority => 100; | ||||
| public bool Enabled => false; | |||||
| public bool Enabled { get; set; } = false; | |||||
| public string Name => "Meta Graph"; | public string Name => "Meta Graph"; | ||||
| public bool Run() | public bool Run() | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples | |||||
| public class NaiveBayesClassifier : Python, IExample | public class NaiveBayesClassifier : Python, IExample | ||||
| { | { | ||||
| public int Priority => 6; | public int Priority => 6; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "Naive Bayes Classifier"; | public string Name => "Naive Bayes Classifier"; | ||||
| public Normal dist { get; set; } | public Normal dist { get; set; } | ||||
| @@ -11,7 +11,7 @@ namespace TensorFlowNET.Examples | |||||
| public class NamedEntityRecognition : Python, IExample | public class NamedEntityRecognition : Python, IExample | ||||
| { | { | ||||
| public int Priority => 100; | public int Priority => 100; | ||||
| public bool Enabled => false; | |||||
| public bool Enabled { get; set; } = false; | |||||
| public string Name => "NER"; | public string Name => "NER"; | ||||
| public bool Run() | public bool Run() | ||||
| @@ -15,10 +15,13 @@ namespace TensorFlowNET.Examples | |||||
| public class NearestNeighbor : Python, IExample | public class NearestNeighbor : Python, IExample | ||||
| { | { | ||||
| public int Priority => 5; | public int Priority => 5; | ||||
| public bool Enabled => true; | |||||
| public bool Enabled { get; set; } = true; | |||||
| public string Name => "Nearest Neighbor"; | public string Name => "Nearest Neighbor"; | ||||
| Datasets mnist; | Datasets mnist; | ||||
| NDArray Xtr, Ytr, Xte, Yte; | NDArray Xtr, Ytr, Xte, Yte; | ||||
| public int? TrainSize = null; | |||||
| public int ValidationSize = 5000; | |||||
| public int? TestSize = null; | |||||
| public bool Run() | public bool Run() | ||||
| { | { | ||||
| @@ -62,10 +65,10 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); | |||||
| // In this example, we limit mnist data | // In this example, we limit mnist data | ||||
| (Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates) | |||||
| (Xte, Yte) = mnist.test.next_batch(200); // 200 for testing | |||||
| (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | |||||
| (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
| private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; | private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; | ||||
| private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; | private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; | ||||
| public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len) | |||||
| public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit=null) | |||||
| { | { | ||||
| string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} "; | string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} "; | ||||
| /*if (step == "train") | /*if (step == "train") | ||||
| @@ -25,10 +25,11 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
| char_dict[c.ToString()] = char_dict.Count; | char_dict[c.ToString()] = char_dict.Count; | ||||
| var contents = File.ReadAllLines(TRAIN_PATH); | var contents = File.ReadAllLines(TRAIN_PATH); | ||||
| var x = new int[contents.Length][]; | |||||
| var y = new int[contents.Length]; | |||||
| for (int i = 0; i < contents.Length; i++) | |||||
| var size = limit == null ? contents.Length : limit.Value; | |||||
| var x = new int[size][]; | |||||
| var y = new int[size]; | |||||
| for (int i = 0; i < size; i++) | |||||
| { | { | ||||
| string[] parts = contents[i].ToLower().Split(",\"").ToArray(); | string[] parts = contents[i].ToLower().Split(",\"").ToArray(); | ||||
| string content = parts[2]; | string content = parts[2]; | ||||
| @@ -15,8 +15,9 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
| public class TextClassificationTrain : Python, IExample | public class TextClassificationTrain : Python, IExample | ||||
| { | { | ||||
| public int Priority => 100; | public int Priority => 100; | ||||
| public bool Enabled => false; | |||||
| public bool Enabled { get; set; }= false; | |||||
| public string Name => "Text Classification"; | public string Name => "Text Classification"; | ||||
| public int? DataLimit = null; | |||||
| private string dataDir = "text_classification"; | private string dataDir = "text_classification"; | ||||
| private string dataFileName = "dbpedia_csv.tar.gz"; | private string dataFileName = "dbpedia_csv.tar.gz"; | ||||
| @@ -28,7 +29,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
| { | { | ||||
| PrepareData(); | PrepareData(); | ||||
| Console.WriteLine("Building dataset..."); | Console.WriteLine("Building dataset..."); | ||||
| var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN); | |||||
| var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit); | |||||
| var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); | var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); | ||||
| @@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples | |||||
| public class TextClassificationWithMovieReviews : Python, IExample | public class TextClassificationWithMovieReviews : Python, IExample | ||||
| { | { | ||||
| public int Priority => 9; | public int Priority => 9; | ||||
| public bool Enabled => false; | |||||
| public bool Enabled { get; set; } = false; | |||||
| public string Name => "Movie Reviews"; | public string Name => "Movie Reviews"; | ||||
| string dir = "text_classification_with_movie_reviews"; | string dir = "text_classification_with_movie_reviews"; | ||||
| @@ -15,29 +15,33 @@ namespace TensorFlowNET.Examples.Utility | |||||
| private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | ||||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
| public static Datasets read_data_sets(string train_dir, | public static Datasets read_data_sets(string train_dir, | ||||
| bool one_hot = false, | bool one_hot = false, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| bool reshape = true, | bool reshape = true, | ||||
| int validation_size = 5000, | int validation_size = 5000, | ||||
| int? train_size = null, | |||||
| int? test_size = null, | |||||
| string source_url = DEFAULT_SOURCE_URL) | string source_url = DEFAULT_SOURCE_URL) | ||||
| { | { | ||||
| if (train_size!=null && validation_size >= train_size) | |||||
| throw new ArgumentException("Validation set should be smaller than training set"); | |||||
| Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | ||||
| Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | ||||
| var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0])); | |||||
| var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size); | |||||
| Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS); | Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS); | ||||
| Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir); | Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir); | ||||
| var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot); | |||||
| var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size); | |||||
| Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES); | Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES); | ||||
| Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir); | Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir); | ||||
| var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0])); | |||||
| var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size); | |||||
| Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS); | Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS); | ||||
| Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir); | Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir); | ||||
| var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot); | |||||
| var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit:test_size); | |||||
| int end = train_images.shape[0]; | int end = train_images.shape[0]; | ||||
| var validation_images = train_images[np.arange(validation_size)]; | var validation_images = train_images[np.arange(validation_size)]; | ||||
| @@ -52,14 +56,15 @@ namespace TensorFlowNET.Examples.Utility | |||||
| return new Datasets(train, validation, test); | return new Datasets(train, validation, test); | ||||
| } | } | ||||
| public static NDArray extract_images(string file) | |||||
| public static NDArray extract_images(string file, int? limit=null) | |||||
| { | { | ||||
| using (var bytestream = new FileStream(file, FileMode.Open)) | using (var bytestream = new FileStream(file, FileMode.Open)) | ||||
| { | { | ||||
| var magic = _read32(bytestream); | var magic = _read32(bytestream); | ||||
| if (magic != 2051) | if (magic != 2051) | ||||
| throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}"); | throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}"); | ||||
| var num_images = _read32(bytestream); | |||||
| var num_images = _read32(bytestream); | |||||
| num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); | |||||
| var rows = _read32(bytestream); | var rows = _read32(bytestream); | ||||
| var cols = _read32(bytestream); | var cols = _read32(bytestream); | ||||
| var buf = new byte[rows * cols * num_images]; | var buf = new byte[rows * cols * num_images]; | ||||
| @@ -70,7 +75,7 @@ namespace TensorFlowNET.Examples.Utility | |||||
| } | } | ||||
| } | } | ||||
| public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10) | |||||
| public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) | |||||
| { | { | ||||
| using (var bytestream = new FileStream(file, FileMode.Open)) | using (var bytestream = new FileStream(file, FileMode.Open)) | ||||
| { | { | ||||
| @@ -78,6 +83,7 @@ namespace TensorFlowNET.Examples.Utility | |||||
| if (magic != 2049) | if (magic != 2049) | ||||
| throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}"); | throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}"); | ||||
| var num_items = _read32(bytestream); | var num_items = _read32(bytestream); | ||||
| num_items = limit == null ? num_items : Math.Min(num_items,(uint) limit); | |||||
| var buf = new byte[num_items]; | var buf = new byte[num_items]; | ||||
| bytestream.Read(buf, 0, buf.Length); | bytestream.Read(buf, 0, buf.Length); | ||||
| var labels = np.frombuffer(buf, np.uint8); | var labels = np.frombuffer(buf, np.uint8); | ||||
| @@ -0,0 +1,99 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using TensorFlowNET.Examples; | |||||
| using TensorFlowNET.Examples.CnnTextClassification; | |||||
| namespace TensorFlowNET.UnitTest.ExamplesTests | |||||
| { | |||||
| [TestClass] | |||||
| public class ExamplesTest | |||||
| { | |||||
| [TestMethod] | |||||
| public void BasicOperations() | |||||
| { | |||||
| new BasicOperations() { Enabled = true }.Run(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void HelloWorld() | |||||
| { | |||||
| new HelloWorld() { Enabled = true }.Run(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void ImageRecognition() | |||||
| { | |||||
| new HelloWorld() { Enabled = true }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void InceptionArchGoogLeNet() | |||||
| { | |||||
| new InceptionArchGoogLeNet() { Enabled = true }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void KMeansClustering() | |||||
| { | |||||
| new KMeansClustering() { Enabled = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void LinearRegression() | |||||
| { | |||||
| new LinearRegression() { Enabled = true }.Run(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void LogisticRegression() | |||||
| { | |||||
| new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void MetaGraph() | |||||
| { | |||||
| new MetaGraph() { Enabled = true }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void NaiveBayesClassifier() | |||||
| { | |||||
| new NaiveBayesClassifier() { Enabled = true }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void NamedEntityRecognition() | |||||
| { | |||||
| new NamedEntityRecognition() { Enabled = true }.Run(); | |||||
| } | |||||
| [TestMethod] | |||||
| public void NearestNeighbor() | |||||
| { | |||||
| new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void TextClassificationTrain() | |||||
| { | |||||
| new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); | |||||
| } | |||||
| [Ignore] | |||||
| [TestMethod] | |||||
| public void TextClassificationWithMovieReviews() | |||||
| { | |||||
| new TextClassificationWithMovieReviews() { Enabled = true }.Run(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -23,6 +23,7 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| <ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| </Project> | </Project> | ||||