From 03336c8b7e79c2918062910ba3691b8e2ea9b2be Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Sun, 7 Jul 2019 23:58:33 -0500
Subject: [PATCH] add Char CNN example.

---
 .../TextProcess/CnnTextClassification.cs      |  18 ++-
 .../TextProcess/DataHelpers.cs                |   2 -
 .../TextProcess/cnn_models/CharCnn.cs         | 151 ++++++++++++++++++
 .../TextProcess/cnn_models/ITextModel.cs      |   3 -
 .../TextProcess/cnn_models/WordCnn.cs         |  18 ---
 5 files changed, 166 insertions(+), 26 deletions(-)
 create mode 100644 test/TensorFlowNET.Examples/TextProcess/cnn_models/CharCnn.cs

diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index d75f279c..a100fde7 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -40,6 +40,7 @@ namespace TensorFlowNET.Examples
         float loss_value = 0;
         double max_accuracy = 0;
 
+        int alphabet_size = -1;
         int vocabulary_size = -1;
         NDArray train_x, valid_x, train_y, valid_y;
 
@@ -117,10 +118,18 @@ namespace TensorFlowNET.Examples
                 Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
 
             Console.WriteLine("Building dataset...");
+            var (x, y) = (new int[0][], new int[0]);
 
-            var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
-            vocabulary_size = len(word_dict);
-            var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
+            if(ModelName == "char_cnn")
+            {
+                (x, y, alphabet_size) = DataHelpers.build_char_dataset(TRAIN_PATH, "char_cnn", CHAR_MAX_LEN);
+            }
+            else
+            {
+                var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
+                vocabulary_size = len(word_dict);
+                (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
+            }
 
             Console.WriteLine("\tDONE ");
 
@@ -162,6 +171,9 @@ namespace TensorFlowNET.Examples
                 case "word_cnn":
                     textModel = new WordCnn(vocabulary_size, WORD_MAX_LEN, NUM_CLASS);
                     break;
+                case "char_cnn":
+                    textModel = new CharCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
+                    break;
             }
 
             return graph;
diff --git a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
index 8b5f79e2..b8e41640 100644
--- a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
@@ -55,8 +55,6 @@ namespace TensorFlowNET.Examples
 
         public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true)
         {
-            if (model != "vd_cnn")
-                throw new NotImplementedException(model);
             string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
             /*if (step == "train")
                 df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/
diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/CharCnn.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/CharCnn.cs
new file mode 100644
index 00000000..43097b5b
--- /dev/null
+++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/CharCnn.cs
@@ -0,0 +1,151 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using static Tensorflow.Python;
+
+namespace TensorFlowNET.Examples.Text
+{
+    public class CharCnn : ITextModel
+    {
+        public CharCnn(int alphabet_size, int document_max_len, int num_class)
+        {
+            var learning_rate = 0.001f;
+            var filter_sizes = new int[] { 7, 7, 3, 3, 3, 3 };
+            var num_filters = 256;
+            var kernel_initializer = tf.truncated_normal_initializer(stddev: 0.05f);
+
+            var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
+            var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
+            var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
+            var global_step = tf.Variable(0, trainable: false);
+            var keep_prob = tf.where(is_training, 0.5f, 1.0f);
+
+            var x_one_hot = tf.one_hot(x, alphabet_size);
+            var x_expanded = tf.expand_dims(x_one_hot, -1);
+
+            // ============= Convolutional Layers =============
+            Tensor pool1 = null, pool2 = null;
+            Tensor conv3 = null, conv4 = null, conv5 = null, conv6 = null;
+            Tensor h_pool = null;
+
+            with(tf.name_scope("conv-maxpool-1"), delegate
+            {
+                var conv1 = tf.layers.conv2d(x_expanded,
+                    filters: num_filters,
+                    kernel_size: new[] { filter_sizes[0], alphabet_size },
+                    kernel_initializer: kernel_initializer,
+                    activation: tf.nn.relu());
+
+                pool1 = tf.layers.max_pooling2d(conv1,
+                    pool_size: new[] { 3, 1 },
+                    strides: new[] { 3, 1 });
+                pool1 = tf.transpose(pool1, new[] { 0, 1, 3, 2 });
+            });
+
+            with(tf.name_scope("conv-maxpool-2"), delegate
+            {
+                var conv2 = tf.layers.conv2d(pool1,
+                    filters: num_filters,
+                    kernel_size: new[] {filter_sizes[1], num_filters },
+                    kernel_initializer: kernel_initializer,
+                    activation: tf.nn.relu());
+
+                pool2 = tf.layers.max_pooling2d(conv2,
+                    pool_size: new[] { 3, 1 },
+                    strides: new[] { 3, 1 });
+                pool2 = tf.transpose(pool2, new[] { 0, 1, 3, 2 });
+            });
+
+            with(tf.name_scope("conv-3"), delegate
+            {
+                conv3 = tf.layers.conv2d(pool2,
+                    filters: num_filters,
+                    kernel_size: new[] { filter_sizes[2], num_filters },
+                    kernel_initializer: kernel_initializer,
+                    activation: tf.nn.relu());
+                conv3 = tf.transpose(conv3, new[] { 0, 1, 3, 2 });
+            });
+
+            with(tf.name_scope("conv-4"), delegate
+            {
+                conv4 = tf.layers.conv2d(conv3,
+                    filters: num_filters,
+                    kernel_size: new[] { filter_sizes[3], num_filters },
+                    kernel_initializer: kernel_initializer,
+                    activation: tf.nn.relu());
+                conv4 = tf.transpose(conv4, new[] { 0, 1, 3, 2 });
+            });
+
+            with(tf.name_scope("conv-5"), delegate
+            {
+                conv5 = tf.layers.conv2d(conv4,
+                    filters: num_filters,
+                    kernel_size: new[] { filter_sizes[4], num_filters },
+                    kernel_initializer: kernel_initializer,
+                    activation: tf.nn.relu());
+                conv5 = tf.transpose(conv5, new[] { 0, 1, 3, 2 });
+            });
+
+            with(tf.name_scope("conv-maxpool-6"), delegate
+            {
+                conv6 = tf.layers.conv2d(conv5,
+                    filters: num_filters,
+                    kernel_size: new[] { filter_sizes[5], num_filters },
+                    kernel_initializer: kernel_initializer,
+                    activation: tf.nn.relu());
+
+                var pool6 = tf.layers.max_pooling2d(conv6,
+                    pool_size: new[] { 3, 1 },
+                    strides: new[] { 3, 1 });
+                pool6 = tf.transpose(pool6, new[] { 0, 2, 1, 3 });
+
+                h_pool = tf.reshape(pool6, new[] { -1, 34 * num_filters });
+            });
+
+            // ============= Fully Connected Layers =============
+            Tensor fc1_out = null, fc2_out = null;
+            Tensor logits = null;
+            Tensor predictions = null;
+
+            with(tf.name_scope("fc-1"), delegate
+            {
+                fc1_out = tf.layers.dense(h_pool,
+                    1024,
+                    activation: tf.nn.relu(),
+                    kernel_initializer: kernel_initializer);
+            });
+
+            with(tf.name_scope("fc-2"), delegate
+            {
+                fc2_out = tf.layers.dense(fc1_out,
+                    1024,
+                    activation: tf.nn.relu(),
+                    kernel_initializer: kernel_initializer);
+            });
+
+            with(tf.name_scope("fc-3"), delegate
+            {
+                logits = tf.layers.dense(fc2_out,
+                    num_class,
+                    kernel_initializer: kernel_initializer);
+                predictions = tf.argmax(logits, -1, output_type: tf.int32);
+            });
+
+            with(tf.name_scope("loss"), delegate
+            {
+                var y_one_hot = tf.one_hot(y, num_class);
+                var loss = tf.reduce_mean(
+                    tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));
+                var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
+            });
+
+            with(tf.name_scope("accuracy"), delegate
+            {
+                var correct_predictions = tf.equal(predictions, y);
+                var accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32), name: "accuracy");
+            });
+        }
+    }
+}
diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs
index adf23c23..ab4bb372 100644
--- a/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/ITextModel.cs
@@ -7,8 +7,5 @@ namespace TensorFlowNET.Examples.Text
 {
     interface ITextModel
     {
-        Tensor is_training { get; }
-        Tensor x { get;}
-        Tensor y { get; }
     }
 }
diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs
index 33e2fac5..796c1c1a 100644
--- a/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/WordCnn.cs
@@ -9,24 +9,6 @@ namespace TensorFlowNET.Examples.Text
 {
     public class WordCnn : ITextModel
     {
-        private int embedding_size;
-        private int[] filter_sizes;
-        private int[] num_filters;
-        private int[] num_blocks;
-        private float learning_rate;
-        private IInitializer cnn_initializer;
-        private IInitializer fc_initializer;
-        public Tensor x { get; private set; }
-        public Tensor y { get; private set; }
-        public Tensor is_training { get; private set; }
-        private RefVariable global_step;
-        private RefVariable embeddings;
-        private Tensor x_emb;
-        private Tensor x_expanded;
-        private Tensor logits;
-        private Tensor predictions;
-        private Tensor loss;
-
        public WordCnn(int vocabulary_size, int document_max_len, int num_class)
         {
             var embedding_size = 128;
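
Usage note: with this patch, the character-level model is selected through the
example's ModelName switch shown in the CnnTextClassification.cs hunks above.
A minimal driver sketch, assuming CnnTextClassification follows the repo's
IExample pattern with a public ModelName member and a Run() entry point (the
driver itself is illustrative, not part of this patch):

    // Hypothetical driver code, under the assumptions stated above.
    var example = new CnnTextClassification();
    example.ModelName = "char_cnn"; // routes dataset building to build_char_dataset
                                    // and graph construction to the new CharCnn model
    example.Run();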
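Data-prep note: build_char_dataset quantizes each document against the alphabet
string in DataHelpers.cs, producing the int[][] that feeds the one-hot encoding
in CharCnn. A self-contained sketch of that mapping; the helper name and the
0-as-padding/unknown convention are illustrative assumptions, not necessarily
the repo's exact code:

    // Map each character to alphabet index + 1; 0 is reserved for padding and
    // for characters outside the alphabet. Each document is truncated or
    // zero-padded to document_max_len (CHAR_MAX_LEN at the call site).
    static int[] QuantizeDocument(string text, string alphabet, int documentMaxLen)
    {
        var encoded = new int[documentMaxLen];   // int[] defaults to 0 = padding
        var lowered = text.ToLowerInvariant();
        for (int i = 0; i < documentMaxLen && i < lowered.Length; i++)
        {
            int idx = alphabet.IndexOf(lowered[i]);
            encoded[i] = idx >= 0 ? idx + 1 : 0; // unknown chars map to 0
        }
        return encoded;
    }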
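Shape note: the hard-coded 34 * num_filters flatten size in CharCnn works out
if CHAR_MAX_LEN is 1014, as in the reference char-CNN implementations this
example follows (an assumption; the constant lives outside this diff). With
valid convolutions and /3 max-pooling over the sequence axis:
1014 - 7 + 1 = 1008, /3 = 336 (conv-maxpool-1); 336 - 7 + 1 = 330, /3 = 110
(conv-maxpool-2); 110 - 3 + 1 = 108, 108 - 3 + 1 = 106, 106 - 3 + 1 = 104
(conv-3 through conv-5); 104 - 3 + 1 = 102, /3 = 34 (conv-maxpool-6).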