
move word cnn to separate class.

tags/v0.9
Oceania2018 · 6 years ago
commit 771a82861f
7 changed files with 138 additions and 414 deletions
  1. docs/source/NeuralNetwork.md  (+1, -1)
  2. test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs  (+25, -97)
  3. test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs  (+0, -298)
  4. test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs  (+2, -2)
  5. test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs  (+2, -3)
  6. test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs  (+104, -0)
  7. test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs  (+4, -13)

docs/source/NeuralNetwork.md  (+1, -1)

@@ -1,4 +1,4 @@
-# Neural Network
+# Chapter. Neural Network


In this chapter, we'll learn how to build a graph of a neural network model. The key advantage of a neural network compared to the Linear Classifier is that it can separate data which is not linearly separable. We'll implement this model to classify hand-written digit images from the MNIST dataset.
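
As a rough, hypothetical sketch (not part of this commit) of the claim above, in the TensorFlow.NET style used elsewhere in this diff, a minimal two-layer MNIST classifier could look like the following; the activation parameter on tf.layers.dense is an assumption here, and dropping the hidden layer leaves a purely linear classifier.

// Hypothetical sketch, not part of this commit: a two-layer network for 28x28 MNIST digits.
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784), name: "x");
var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");

var hidden = tf.layers.dense(x, 128, activation: tf.nn.relu()); // activation overload assumed
var logits = tf.layers.dense(hidden, 10); // 10 digit classes

var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
var loss = tf.reduce_mean(sscel);
var optimizer = tf.train.AdamOptimizer(0.001f).minimize(loss);

var predictions = tf.argmax(logits, -1, output_type: tf.int32);
var accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, y), TF_DataType.TF_FLOAT), name: "accuracy");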




test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs  (+25, -97)

@@ -9,6 +9,7 @@ using Newtonsoft.Json;
using NumSharp;
using Tensorflow;
using Tensorflow.Sessions;
+using TensorFlowNET.Examples.Text;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;


@@ -24,24 +25,27 @@ namespace TensorFlowNET.Examples
public int? DataLimit = null;
public bool IsImportingGraph { get; set; } = false;

-private const string dataDir = "word_cnn";
-private string dataFileName = "dbpedia_csv.tar.gz";
+const string dataDir = "cnn_text";
+string dataFileName = "dbpedia_csv.tar.gz";

-private const string TRAIN_PATH = "word_cnn/dbpedia_csv/train.csv";
-private const string TEST_PATH = "word_cnn/dbpedia_csv/test.csv";
-private const int NUM_CLASS = 14;
-private const int BATCH_SIZE = 64;
-private const int NUM_EPOCHS = 10;
-private const int WORD_MAX_LEN = 100;
-private const int CHAR_MAX_LEN = 1014;
-protected float loss_value = 0;
+string TRAIN_PATH = $"{dataDir}/dbpedia_csv/train.csv";
+string TEST_PATH = $"{dataDir}/dbpedia_csv/test.csv";
+int NUM_CLASS = 14;
+int BATCH_SIZE = 64;
+int NUM_EPOCHS = 10;
+int WORD_MAX_LEN = 100;
+int CHAR_MAX_LEN = 1014;
+float loss_value = 0;
double max_accuracy = 0;

-int vocabulary_size = 50000;
+int vocabulary_size = -1;
NDArray train_x, valid_x, train_y, valid_y;

+ITextModel textModel;
+public string ModelName = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
+
public bool Run()
{
PrepareData();
@@ -68,7 +72,7 @@ namespace TensorFlowNET.Examples
return (train_x, valid_x, train_y, valid_y);
}

-private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
+private void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
int i = 0;
var label_keys = labels.Keys.ToArray();
@@ -114,10 +118,8 @@ namespace TensorFlowNET.Examples


Console.WriteLine("Building dataset..."); Console.WriteLine("Building dataset...");


int alphabet_size = 0;

var word_dict = DataHelpers.build_word_dict(TRAIN_PATH); var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
//vocabulary_size = len(word_dict);
vocabulary_size = len(word_dict);
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN); var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);


Console.WriteLine("\tDONE "); Console.WriteLine("\tDONE ");
@@ -155,83 +157,19 @@ namespace TensorFlowNET.Examples
{
var graph = tf.Graph().as_default();

-var embedding_size = 128;
-var learning_rate = 0.001f;
-var filter_sizes = new int[3, 4, 5];
-var num_filters = 100;
-var document_max_len = 100;
-
-var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
-var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
-var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
-var global_step = tf.Variable(0, trainable: false);
-var keep_prob = tf.where(is_training, 0.5f, 1.0f);
-Tensor x_emb = null;
-
-with(tf.name_scope("embedding"), scope =>
-{
-var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
-var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
-x_emb = tf.nn.embedding_lookup(embeddings, x);
-x_emb = tf.expand_dims(x_emb, -1);
-});
-
-var pooled_outputs = new List<Tensor>();
-for (int len = 0; len < filter_sizes.Rank; len++)
+switch (ModelName)
{
-int filter_size = filter_sizes.GetLength(len);
-var conv = tf.layers.conv2d(
-x_emb,
-filters: num_filters,
-kernel_size: new int[] { filter_size, embedding_size },
-strides: new int[] { 1, 1 },
-padding: "VALID",
-activation: tf.nn.relu());
-
-var pool = tf.layers.max_pooling2d(
-conv,
-pool_size: new[] { document_max_len - filter_size + 1, 1 },
-strides: new[] { 1, 1 },
-padding: "VALID");
-
-pooled_outputs.Add(pool);
+case "word_cnn":
+textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS);
+break;
}

-var h_pool = tf.concat(pooled_outputs, 3);
-var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
-Tensor h_drop = null;
-with(tf.name_scope("dropout"), delegate
-{
-h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
-});
-
-Tensor logits = null;
-Tensor predictions = null;
-with(tf.name_scope("output"), delegate
-{
-logits = tf.layers.dense(h_drop, NUM_CLASS);
-predictions = tf.argmax(logits, -1, output_type: tf.int32);
-});
-
-with(tf.name_scope("loss"), delegate
-{
-var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
-var loss = tf.reduce_mean(sscel);
-var adam = tf.train.AdamOptimizer(learning_rate);
-var optimizer = adam.minimize(loss, global_step: global_step);
-});
-
-with(tf.name_scope("accuracy"), delegate
-{
-var correct_predictions = tf.equal(predictions, y);
-var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
-});
-
return graph;
}


-private bool Train(Session sess, Graph graph)
+public void Train(Session sess)
{
+var graph = tf.get_default_graph();
var stopwatch = Stopwatch.StartNew();

sess.run(tf.global_variables_initializer());
@@ -263,10 +201,7 @@ namespace TensorFlowNET.Examples
loss_value = result[2];
var step = (int)result[1];
if (step % 10 == 0)
-{
-var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
-Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
-}
+Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value.ToString("0.0000")}.");

if (step % 100 == 0)
{
@@ -289,7 +224,7 @@ namespace TensorFlowNET.Examples


var valid_accuracy = sum_accuracy / cnt;

-print($"\nValidation Accuracy = {valid_accuracy}\n");
+print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n");

// Save model
if (valid_accuracy > max_accuracy)
@@ -300,13 +235,6 @@ namespace TensorFlowNET.Examples
}
}
}

-return max_accuracy > 0.9;
-}
-
-public void Train(Session sess)
-{
-Train(sess, sess.graph);
}

public void Predict(Session sess)


test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs  (+0, -298, file removed)

@@ -1,298 +0,0 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using NumSharp;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.Sessions;
using TensorFlowNET.Examples.Text.cnn_models;
using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
{
/// <summary>
/// https://github.com/dongjun-Lee/text-classification-models-tf
/// </summary>
public class TextClassificationTrain : IExample
{
public bool Enabled { get; set; } = false;
public string Name => "Text Classification";
public int? DataLimit = null;
public bool IsImportingGraph { get; set; } = true;
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia

private string dataDir = "text_classification";
private string dataFileName = "dbpedia_csv.tar.gz";

public string model_name = "word_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn

private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string SUBSET_PATH = "text_classification/dbpedia_csv/dbpedia_6400.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

private const int NUM_CLASS = 14;
private const int BATCH_SIZE = 64;
private const int NUM_EPOCHS = 10;
private const int WORD_MAX_LEN = 100;
private const int CHAR_MAX_LEN = 1014;
protected float loss_value = 0;

public bool Run()
{
PrepareData();
var graph = tf.Graph().as_default();
return with(tf.Session(graph), sess =>
{
if (IsImportingGraph)
return RunWithImportedGraph(sess, graph);
else
return RunWithBuiltGraph(sess, graph);
});
}

protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
{
var stopwatch = Stopwatch.StartNew();
Console.WriteLine("Building dataset...");
var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;
int[][] x = null;
int[] y = null;
int alphabet_size = 0;
int vocabulary_size = 0;

if (model_name == "vd_cnn")
(x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle:!UseSubset);
else
{
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
vocabulary_size = len(word_dict);
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
}

Console.WriteLine("\tDONE ");

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
Console.WriteLine("Training set size: " + train_x.len);
Console.WriteLine("Test set size: " + valid_x.len);

Console.WriteLine("Import graph...");
var meta_file = model_name + ".meta";
tf.train.import_meta_graph(Path.Join("graph", meta_file));
Console.WriteLine("\tDONE " + stopwatch.Elapsed);

sess.run(tf.global_variables_initializer());
var saver = tf.train.Saver(tf.global_variables());
var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
double max_accuracy = 0;

Tensor is_training = graph.OperationByName("is_training");
Tensor model_x = graph.OperationByName("x");
Tensor model_y = graph.OperationByName("y");
Tensor loss = graph.OperationByName("loss/Mean"); // word_cnn
Operation optimizer = graph.OperationByName("loss/Adam"); // word_cnn
Tensor global_step = graph.OperationByName("Variable");
Tensor accuracy = graph.OperationByName("accuracy/accuracy");
stopwatch = Stopwatch.StartNew();
int i = 0;
foreach (var (x_batch, y_batch, total) in train_batches)
{
i++;
var train_feed_dict = new FeedDict
{
[model_x] = x_batch,
[model_y] = y_batch,
[is_training] = true,
};
//Console.WriteLine("x: " + x_batch.ToString() + "\n");
//Console.WriteLine("y: " + y_batch.ToString());
// original python:
//_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
loss_value = result[2];
var step = (int)result[1];
if (step % 10 == 0)
{
var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
}

if (step % 100 == 0)
{
// # Test accuracy with validation data for each epoch.
var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
var (sum_accuracy, cnt) = (0.0f, 0);
foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
{
var valid_feed_dict = new FeedDict
{
[model_x] = valid_x_batch,
[model_y] = valid_y_batch,
[is_training] = false
};
var result1 = sess.run(accuracy, valid_feed_dict);
float accuracy_value = result1;
sum_accuracy += accuracy_value;
cnt += 1;
}

var valid_accuracy = sum_accuracy / cnt;

print($"\nValidation Accuracy = {valid_accuracy}\n");
// # Save model
if (valid_accuracy > max_accuracy)
{
max_accuracy = valid_accuracy;
// saver.save(sess, $"{dataDir}/{model_name}.ckpt", global_step: step.ToString());
print("Model is saved.\n");
}
}
}

return false;
}

protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
{
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

ITextClassificationModel model = null;
switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
{
case "word_cnn":
case "char_cnn":
case "word_rnn":
case "att_rnn":
case "rcnn":
throw new NotImplementedException();
break;
case "vd_cnn":
model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
break;
}
// todo train the model
return false;
}
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
{
Console.WriteLine("Splitting in Training and Testing data...");
int len = x.shape[0];
//int classes = y.Data<int>().Distinct().Count();
//int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size), new Slice()];
var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size)];
Console.WriteLine("\tDONE");
return (train_x, valid_x, train_y, valid_y);
}

private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
int i = 0;
var label_keys = labels.Keys.ToArray();
while (i < shuffled_x.Length)
{
var key = label_keys[random.Next(label_keys.Length)];
var set = labels[key];
var index = set.First();
if (set.Count == 0)
{
labels.Remove(key); // remove the set as it is empty
label_keys = labels.Keys.ToArray();
}
shuffled_x[i] = x[index];
shuffled_y[i] = y[index];
i++;
}
}

private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
var total_batches = num_batches_per_epoch * num_epochs;
foreach (var epoch in range(num_epochs))
{
foreach (var batch_num in range(num_batches_per_epoch))
{
var start_index = batch_num * batch_size;
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
if (end_index <= start_index)
break;
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
}
}
}

public void PrepareData()
{
if (UseSubset)
{
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
}
else
{
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
}

if (IsImportingGraph)
{
// download graph meta data
var meta_file = model_name + ".meta";
var meta_path = Path.Combine("graph", meta_file);
if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
{
// delete old cached file which contains errors
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
}
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);
}
}
public Graph ImportGraph()
{
throw new NotImplementedException();
}
public Graph BuildGraph()
{
throw new NotImplementedException();
}
public void Train(Session sess)
{
throw new NotImplementedException();
}
public void Predict(Session sess)
{
throw new NotImplementedException();
}
public void Test(Session sess)
{
throw new NotImplementedException();
}
}
}

test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextClassificationModel.cs → test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs  (+2, -2, renamed)

@@ -3,9 +3,9 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow;

-namespace TensorFlowNET.Examples.Text.cnn_models
+namespace TensorFlowNET.Examples.Text
{
-interface ITextClassificationModel
+interface ITextModel
{
Tensor is_training { get; }
Tensor x { get;}

test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs  (+2, -3)

@@ -3,12 +3,11 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
-using TensorFlowNET.Examples.Text.cnn_models;
using static Tensorflow.Python;

-namespace TensorFlowNET.Examples.TextClassification
+namespace TensorFlowNET.Examples.Text
{
-public class VdCnn : ITextClassificationModel
+public class VdCnn : ITextModel
{
private int embedding_size;
private int[] filter_sizes;


test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs  (+104, -0, new file)

@@ -0,0 +1,104 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.Text
{
public class WordCnn : ITextModel
{
private int embedding_size;
private int[] filter_sizes;
private int[] num_filters;
private int[] num_blocks;
private float learning_rate;
private IInitializer cnn_initializer;
private IInitializer fc_initializer;
public Tensor x { get; private set; }
public Tensor y { get; private set; }
public Tensor is_training { get; private set; }
private RefVariable global_step;
private RefVariable embeddings;
private Tensor x_emb;
private Tensor x_expanded;
private Tensor logits;
private Tensor predictions;
private Tensor loss;

public WordCnn(int vocabulary_size, int document_max_len, int num_class)
{
var embedding_size = 128;
var learning_rate = 0.001f;
var filter_sizes = new int[3, 4, 5];
var num_filters = 100;

var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
var global_step = tf.Variable(0, trainable: false);
var keep_prob = tf.where(is_training, 0.5f, 1.0f);
Tensor x_emb = null;

with(tf.name_scope("embedding"), scope =>
{
var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
x_emb = tf.nn.embedding_lookup(embeddings, x);
x_emb = tf.expand_dims(x_emb, -1);
});

var pooled_outputs = new List<Tensor>();
for (int len = 0; len < filter_sizes.Rank; len++)
{
int filter_size = filter_sizes.GetLength(len);
var conv = tf.layers.conv2d(
x_emb,
filters: num_filters,
kernel_size: new int[] { filter_size, embedding_size },
strides: new int[] { 1, 1 },
padding: "VALID",
activation: tf.nn.relu());

var pool = tf.layers.max_pooling2d(
conv,
pool_size: new[] { document_max_len - filter_size + 1, 1 },
strides: new[] { 1, 1 },
padding: "VALID");

pooled_outputs.Add(pool);
}

var h_pool = tf.concat(pooled_outputs, 3);
var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
Tensor h_drop = null;
with(tf.name_scope("dropout"), delegate
{
h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
});

Tensor logits = null;
Tensor predictions = null;
with(tf.name_scope("output"), delegate
{
logits = tf.layers.dense(h_drop, num_class);
predictions = tf.argmax(logits, -1, output_type: tf.int32);
});

with(tf.name_scope("loss"), delegate
{
var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
var loss = tf.reduce_mean(sscel);
var adam = tf.train.AdamOptimizer(learning_rate);
var optimizer = adam.minimize(loss, global_step: global_step);
});

with(tf.name_scope("accuracy"), delegate
{
var correct_predictions = tf.equal(predictions, y);
var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
});
}
}
}

test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs  (+4, -13)

@@ -83,22 +83,13 @@ namespace TensorFlowNET.ExamplesTests
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
}

-[Ignore]
[TestMethod]
-public void TextClassificationTrain()
-{
-tf.Graph().as_default();
-new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
-}
+public void WordCnnTextClassification()
+=> new CnnTextClassification { Enabled = true, ModelName = "word_cnn", DataLimit =100 }.Run();

[TestMethod]
-public void CnnTextClassificationTrain()
-{
-tf.Graph().as_default();
-new CnnTextClassification() { Enabled = true, IsImportingGraph = false }.Run();
-}
+public void CharCnnTextClassification()
+=> new CnnTextClassification { Enabled = true, ModelName = "char_cnn", DataLimit = 100 }.Run();

[Ignore]
[TestMethod]

