diff --git a/docs/source/NeuralNetwork.md b/docs/source/NeuralNetwork.md index 9da93405..1d46111b 100644 --- a/docs/source/NeuralNetwork.md +++ b/docs/source/NeuralNetwork.md @@ -1,4 +1,4 @@ -# Neural Network +# Chapter. Neural Network In this chapter, we'll learn how to build a graph of neural network model. The key advantage of neural network compared to Linear Classifier is that it can separate data which it not linearly separable. We'll implement this model to classify hand-written digits images from the MNIST dataset. diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index 1819f740..d75f279c 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -9,6 +9,7 @@ using Newtonsoft.Json; using NumSharp; using Tensorflow; using Tensorflow.Sessions; +using TensorFlowNET.Examples.Text; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; @@ -24,24 +25,27 @@ namespace TensorFlowNET.Examples public int? DataLimit = null; public bool IsImportingGraph { get; set; } = false; - private const string dataDir = "word_cnn"; - private string dataFileName = "dbpedia_csv.tar.gz"; + const string dataDir = "cnn_text"; + string dataFileName = "dbpedia_csv.tar.gz"; - private const string TRAIN_PATH = "word_cnn/dbpedia_csv/train.csv"; - private const string TEST_PATH = "word_cnn/dbpedia_csv/test.csv"; + string TRAIN_PATH = $"{dataDir}/dbpedia_csv/train.csv"; + string TEST_PATH = $"{dataDir}/dbpedia_csv/test.csv"; - private const int NUM_CLASS = 14; - private const int BATCH_SIZE = 64; - private const int NUM_EPOCHS = 10; - private const int WORD_MAX_LEN = 100; - private const int CHAR_MAX_LEN = 1014; + int NUM_CLASS = 14; + int BATCH_SIZE = 64; + int NUM_EPOCHS = 10; + int WORD_MAX_LEN = 100; + int CHAR_MAX_LEN = 1014; - protected float loss_value = 0; + float loss_value = 0; double max_accuracy = 0; - int vocabulary_size = 50000; + int vocabulary_size = -1; NDArray train_x, valid_x, train_y, valid_y; + ITextModel textModel; + public string ModelName = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn + public bool Run() { PrepareData(); @@ -68,7 +72,7 @@ namespace TensorFlowNET.Examples return (train_x, valid_x, train_y, valid_y); } - private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary> labels) + private void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary> labels) { int i = 0; var label_keys = labels.Keys.ToArray(); @@ -114,10 +118,8 @@ namespace TensorFlowNET.Examples Console.WriteLine("Building dataset..."); - int alphabet_size = 0; - var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); - //vocabulary_size = len(word_dict); + vocabulary_size = len(word_dict); var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); Console.WriteLine("\tDONE "); @@ -155,83 +157,19 @@ namespace TensorFlowNET.Examples { var graph = tf.Graph().as_default(); - var embedding_size = 128; - var learning_rate = 0.001f; - var filter_sizes = new int[3, 4, 5]; - var num_filters = 100; - var document_max_len = 100; - - var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); - var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); - var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training"); - var global_step = tf.Variable(0, trainable: false); - var keep_prob = tf.where(is_training, 0.5f, 1.0f); - Tensor x_emb = null; - - with(tf.name_scope("embedding"), scope => - { - var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size }); - var embeddings = tf.get_variable("embeddings", initializer: init_embeddings); - x_emb = tf.nn.embedding_lookup(embeddings, x); - x_emb = tf.expand_dims(x_emb, -1); - }); - - var pooled_outputs = new List(); - for (int len = 0; len < filter_sizes.Rank; len++) + switch (ModelName) { - int filter_size = filter_sizes.GetLength(len); - var conv = tf.layers.conv2d( - x_emb, - filters: num_filters, - kernel_size: new int[] { filter_size, embedding_size }, - strides: new int[] { 1, 1 }, - padding: "VALID", - activation: tf.nn.relu()); - - var pool = tf.layers.max_pooling2d( - conv, - pool_size: new[] { document_max_len - filter_size + 1, 1 }, - strides: new[] { 1, 1 }, - padding: "VALID"); - - pooled_outputs.Add(pool); + case "word_cnn": + textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS); + break; } - var h_pool = tf.concat(pooled_outputs, 3); - var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); - Tensor h_drop = null; - with(tf.name_scope("dropout"), delegate - { - h_drop = tf.nn.dropout(h_pool_flat, keep_prob); - }); - - Tensor logits = null; - Tensor predictions = null; - with(tf.name_scope("output"), delegate - { - logits = tf.layers.dense(h_drop, NUM_CLASS); - predictions = tf.argmax(logits, -1, output_type: tf.int32); - }); - - with(tf.name_scope("loss"), delegate - { - var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y); - var loss = tf.reduce_mean(sscel); - var adam = tf.train.AdamOptimizer(learning_rate); - var optimizer = adam.minimize(loss, global_step: global_step); - }); - - with(tf.name_scope("accuracy"), delegate - { - var correct_predictions = tf.equal(predictions, y); - var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy"); - }); - return graph; } - private bool Train(Session sess, Graph graph) + public void Train(Session sess) { + var graph = tf.get_default_graph(); var stopwatch = Stopwatch.StartNew(); sess.run(tf.global_variables_initializer()); @@ -263,10 +201,7 @@ namespace TensorFlowNET.Examples loss_value = result[2]; var step = (int)result[1]; if (step % 10 == 0) - { - var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total); - Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}"); - } + Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value.ToString("0.0000")}."); if (step % 100 == 0) { @@ -289,7 +224,7 @@ namespace TensorFlowNET.Examples var valid_accuracy = sum_accuracy / cnt; - print($"\nValidation Accuracy = {valid_accuracy}\n"); + print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n"); // Save model if (valid_accuracy > max_accuracy) @@ -300,13 +235,6 @@ namespace TensorFlowNET.Examples } } } - - return max_accuracy > 0.9; - } - - public void Train(Session sess) - { - Train(sess, sess.graph); } public void Predict(Session sess) diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs deleted file mode 100644 index 800cd5a3..00000000 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ /dev/null @@ -1,298 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Linq; -using System.Text; -using NumSharp; -using Tensorflow; -using Tensorflow.Keras.Engine; -using Tensorflow.Sessions; -using TensorFlowNET.Examples.Text.cnn_models; -using TensorFlowNET.Examples.TextClassification; -using TensorFlowNET.Examples.Utility; -using static Tensorflow.Python; - -namespace TensorFlowNET.Examples -{ - /// - /// https://github.com/dongjun-Lee/text-classification-models-tf - /// - public class TextClassificationTrain : IExample - { - public bool Enabled { get; set; } = false; - public string Name => "Text Classification"; - public int? DataLimit = null; - public bool IsImportingGraph { get; set; } = true; - public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia - - private string dataDir = "text_classification"; - private string dataFileName = "dbpedia_csv.tar.gz"; - - public string model_name = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn - - private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv"; - private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv"; - private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv"; - - private const int NUM_CLASS = 14; - private const int BATCH_SIZE = 64; - private const int NUM_EPOCHS = 10; - private const int WORD_MAX_LEN = 100; - private const int CHAR_MAX_LEN = 1014; - - protected float loss_value = 0; - - public bool Run() - { - PrepareData(); - var graph = tf.Graph().as_default(); - return with(tf.Session(graph), sess => - { - if (IsImportingGraph) - return RunWithImportedGraph(sess, graph); - else - return RunWithBuiltGraph(sess, graph); - }); - } - - protected virtual bool RunWithImportedGraph(Session sess, Graph graph) - { - var stopwatch = Stopwatch.StartNew(); - Console.WriteLine("Building dataset..."); - var path = UseSubset ? SUBSET_PATH : TRAIN_PATH; - int[][] x = null; - int[] y = null; - int alphabet_size = 0; - int vocabulary_size = 0; - - if (model_name == "vd_cnn") - (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset); - else - { - var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); - vocabulary_size = len(word_dict); - (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); - } - - Console.WriteLine("\tDONE "); - - var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); - Console.WriteLine("Training set size: " + train_x.len); - Console.WriteLine("Test set size: " + valid_x.len); - - Console.WriteLine("Import graph..."); - var meta_file = model_name + ".meta"; - tf.train.import_meta_graph(Path.Join("graph", meta_file)); - Console.WriteLine("\tDONE " + stopwatch.Elapsed); - - sess.run(tf.global_variables_initializer()); - var saver = tf.train.Saver(tf.global_variables()); - - var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS); - var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1; - double max_accuracy = 0; - - Tensor is_training = graph.OperationByName("is_training"); - Tensor model_x = graph.OperationByName("x"); - Tensor model_y = graph.OperationByName("y"); - Tensor loss = graph.OperationByName("loss/Mean"); // word_cnn - Operation optimizer = graph.OperationByName("loss/Adam"); // word_cnn - Tensor global_step = graph.OperationByName("Variable"); - Tensor accuracy = graph.OperationByName("accuracy/accuracy"); - stopwatch = Stopwatch.StartNew(); - int i = 0; - foreach (var (x_batch, y_batch, total) in train_batches) - { - i++; - var train_feed_dict = new FeedDict - { - [model_x] = x_batch, - [model_y] = y_batch, - [is_training] = true, - }; - //Console.WriteLine("x: " + x_batch.ToString() + "\n"); - //Console.WriteLine("y: " + y_batch.ToString()); - // original python: - //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict) - var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); - loss_value = result[2]; - var step = (int)result[1]; - if (step % 10 == 0) - { - var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total); - Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}"); - } - - if (step % 100 == 0) - { - // # Test accuracy with validation data for each epoch. - var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1); - var (sum_accuracy, cnt) = (0.0f, 0); - foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches) - { - var valid_feed_dict = new FeedDict - { - [model_x] = valid_x_batch, - [model_y] = valid_y_batch, - [is_training] = false - }; - var result1 = sess.run(accuracy, valid_feed_dict); - float accuracy_value = result1; - sum_accuracy += accuracy_value; - cnt += 1; - } - - var valid_accuracy = sum_accuracy / cnt; - - print($"\nValidation Accuracy = {valid_accuracy}\n"); - - // # Save model - if (valid_accuracy > max_accuracy) - { - max_accuracy = valid_accuracy; - // saver.save(sess, $"{dataDir}/{model_name}.ckpt", global_step: step.ToString()); - print("Model is saved.\n"); - } - } - } - - return false; - } - - protected virtual bool RunWithBuiltGraph(Session session, Graph graph) - { - Console.WriteLine("Building dataset..."); - var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit); - - var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); - - ITextClassificationModel model = null; - switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn - { - case "word_cnn": - case "char_cnn": - case "word_rnn": - case "att_rnn": - case "rcnn": - throw new NotImplementedException(); - break; - case "vd_cnn": - model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); - break; - } - // todo train the model - return false; - } - - // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here - private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f) - { - Console.WriteLine("Splitting in Training and Testing data..."); - int len = x.shape[0]; - //int classes = y.Data().Distinct().Count(); - //int samples = len / classes; - int train_size = (int)Math.Round(len * (1 - test_size)); - var train_x = x[new Slice(stop: train_size), new Slice()]; - var valid_x = x[new Slice(start: train_size), new Slice()]; - var train_y = y[new Slice(stop: train_size)]; - var valid_y = y[new Slice(start: train_size)]; - Console.WriteLine("\tDONE"); - return (train_x, valid_x, train_y, valid_y); - } - - private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary> labels) - { - int i = 0; - var label_keys = labels.Keys.ToArray(); - while (i < shuffled_x.Length) - { - var key = label_keys[random.Next(label_keys.Length)]; - var set = labels[key]; - var index = set.First(); - if (set.Count == 0) - { - labels.Remove(key); // remove the set as it is empty - label_keys = labels.Keys.ToArray(); - } - shuffled_x[i] = x[index]; - shuffled_y[i] = y[index]; - i++; - } - } - - private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs) - { - var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1; - var total_batches = num_batches_per_epoch * num_epochs; - foreach (var epoch in range(num_epochs)) - { - foreach (var batch_num in range(num_batches_per_epoch)) - { - var start_index = batch_num * batch_size; - var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs)); - if (end_index <= start_index) - break; - yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches); - } - } - } - - public void PrepareData() - { - if (UseSubset) - { - var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip"; - Web.Download(url, dataDir, "dbpedia_subset.zip"); - Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); - } - else - { - string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; - Web.Download(url, dataDir, dataFileName); - Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); - } - - if (IsImportingGraph) - { - // download graph meta data - var meta_file = model_name + ".meta"; - var meta_path = Path.Combine("graph", meta_file); - if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11)) - { - // delete old cached file which contains errors - Console.WriteLine("Discarding cached file: " + meta_path); - File.Delete(meta_path); - } - var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; - Web.Download(url, "graph", meta_file); - } - } - - public Graph ImportGraph() - { - throw new NotImplementedException(); - } - - public Graph BuildGraph() - { - throw new NotImplementedException(); - } - - public void Train(Session sess) - { - throw new NotImplementedException(); - } - - public void Predict(Session sess) - { - throw new NotImplementedException(); - } - - public void Test(Session sess) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextClassificationModel.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs similarity index 68% rename from test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextClassificationModel.cs rename to test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs index 942f2e04..adf23c23 100644 --- a/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextClassificationModel.cs +++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs @@ -3,9 +3,9 @@ using System.Collections.Generic; using System.Text; using Tensorflow; -namespace TensorFlowNET.Examples.Text.cnn_models +namespace TensorFlowNET.Examples.Text { - interface ITextClassificationModel + interface ITextModel { Tensor is_training { get; } Tensor x { get;} diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs index f4f430a5..06852faf 100644 --- a/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs +++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs @@ -3,12 +3,11 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow; -using TensorFlowNET.Examples.Text.cnn_models; using static Tensorflow.Python; -namespace TensorFlowNET.Examples.TextClassification +namespace TensorFlowNET.Examples.Text { - public class VdCnn : ITextClassificationModel + public class VdCnn : ITextModel { private int embedding_size; private int[] filter_sizes; diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs new file mode 100644 index 00000000..33e2fac5 --- /dev/null +++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs @@ -0,0 +1,104 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow; +using static Tensorflow.Python; + +namespace TensorFlowNET.Examples.Text +{ + public class WordCnn : ITextModel + { + private int embedding_size; + private int[] filter_sizes; + private int[] num_filters; + private int[] num_blocks; + private float learning_rate; + private IInitializer cnn_initializer; + private IInitializer fc_initializer; + public Tensor x { get; private set; } + public Tensor y { get; private set; } + public Tensor is_training { get; private set; } + private RefVariable global_step; + private RefVariable embeddings; + private Tensor x_emb; + private Tensor x_expanded; + private Tensor logits; + private Tensor predictions; + private Tensor loss; + + public WordCnn(int vocabulary_size, int document_max_len, int num_class) + { + var embedding_size = 128; + var learning_rate = 0.001f; + var filter_sizes = new int[3, 4, 5]; + var num_filters = 100; + + var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); + var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); + var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training"); + var global_step = tf.Variable(0, trainable: false); + var keep_prob = tf.where(is_training, 0.5f, 1.0f); + Tensor x_emb = null; + + with(tf.name_scope("embedding"), scope => + { + var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size }); + var embeddings = tf.get_variable("embeddings", initializer: init_embeddings); + x_emb = tf.nn.embedding_lookup(embeddings, x); + x_emb = tf.expand_dims(x_emb, -1); + }); + + var pooled_outputs = new List(); + for (int len = 0; len < filter_sizes.Rank; len++) + { + int filter_size = filter_sizes.GetLength(len); + var conv = tf.layers.conv2d( + x_emb, + filters: num_filters, + kernel_size: new int[] { filter_size, embedding_size }, + strides: new int[] { 1, 1 }, + padding: "VALID", + activation: tf.nn.relu()); + + var pool = tf.layers.max_pooling2d( + conv, + pool_size: new[] { document_max_len - filter_size + 1, 1 }, + strides: new[] { 1, 1 }, + padding: "VALID"); + + pooled_outputs.Add(pool); + } + + var h_pool = tf.concat(pooled_outputs, 3); + var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank)); + Tensor h_drop = null; + with(tf.name_scope("dropout"), delegate + { + h_drop = tf.nn.dropout(h_pool_flat, keep_prob); + }); + + Tensor logits = null; + Tensor predictions = null; + with(tf.name_scope("output"), delegate + { + logits = tf.layers.dense(h_drop, num_class); + predictions = tf.argmax(logits, -1, output_type: tf.int32); + }); + + with(tf.name_scope("loss"), delegate + { + var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y); + var loss = tf.reduce_mean(sscel); + var adam = tf.train.AdamOptimizer(learning_rate); + var optimizer = adam.minimize(loss, global_step: global_step); + }); + + with(tf.name_scope("accuracy"), delegate + { + var correct_predictions = tf.equal(predictions, y); + var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy"); + }); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index a267324e..4c71d7e2 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -83,22 +83,13 @@ namespace TensorFlowNET.ExamplesTests new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); } - [Ignore] [TestMethod] - public void TextClassificationTrain() - { - tf.Graph().as_default(); - new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run(); - } - + public void WordCnnTextClassification() + => new CnnTextClassification { Enabled = true, ModelName = "word_cnn", DataLimit =100 }.Run(); [TestMethod] - public void CnnTextClassificationTrain() - { - tf.Graph().as_default(); - new CnnTextClassification() { Enabled = true, IsImportingGraph = false }.Run(); - } - + public void CharCnnTextClassification() + => new CnnTextClassification { Enabled = true, ModelName = "char_cnn", DataLimit = 100 }.Run(); [Ignore] [TestMethod]