From 58607b09261f2fd6b93956cafcc1f3f560da6af5 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 27 Apr 2019 21:03:52 -0500 Subject: [PATCH] word2vec in progress. --- test/TensorFlowNET.Examples/Text/Word2Vec.cs | 82 +++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/test/TensorFlowNET.Examples/Text/Word2Vec.cs b/test/TensorFlowNET.Examples/Text/Word2Vec.cs index 72dc0a80..54f831aa 100644 --- a/test/TensorFlowNET.Examples/Text/Word2Vec.cs +++ b/test/TensorFlowNET.Examples/Text/Word2Vec.cs @@ -1,7 +1,10 @@ using System; using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Text; using Tensorflow; +using TensorFlowNET.Examples.Utility; namespace TensorFlowNET.Examples { @@ -24,6 +27,17 @@ namespace TensorFlowNET.Examples // Evaluation Parameters string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" }; + string[] text_words; + + // Word2Vec Parameters + int embedding_size = 200; // Dimension of the embedding vector + int max_vocabulary_size = 50000; // Total number of different words in the vocabulary + int min_occurrence = 10; // Remove all words that do not appear at least n times + int skip_window = 3; // How many words to consider left and right + int num_skips = 2; // How many times to reuse an input to generate a label + int num_sampled = 64; // Number of negative examples to sample + + int data_index; public bool Run() { @@ -33,12 +47,78 @@ namespace TensorFlowNET.Examples tf.train.import_meta_graph("graph/word2vec.meta"); + // Initialize the variables (i.e. 
assign their default value) + var init = tf.global_variables_initializer(); + + with(tf.Session(graph), sess => + { + sess.run(init); + }); + return false; } + // Generate training batch for the skip-gram model + private void next_batch() + { + + } + public void PrepareData() { - var url = ""; + // Download graph meta + var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta"; + Web.Download(url, "graph", "word2vec.meta"); + + // Download a small chunk of Wikipedia articles collection + url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip"; + Web.Download(url, "word2vec", "text8.zip"); + // Unzip the dataset file. Text has already been processed + Compress.UnZip(@"word2vec\text8.zip", "word2vec"); + + int wordId = 0; + text_words = File.ReadAllText(@"word2vec\text8").Trim().ToLower().Split(); + // Build the dictionary and replace rare words with UNK token + var word2id = text_words.GroupBy(x => x) + .Select(x => new WordId + { + Word = x.Key, + Occurrence = x.Count() + }) + .Where(x => x.Occurrence >= min_occurrence) // Remove samples with fewer than 'min_occurrence' occurrences + .OrderByDescending(x => x.Occurrence) // Retrieve the most common words + .Select(x => new WordId + { + Word = x.Word, + Id = ++wordId, // Assign an id to each word + Occurrence = x.Occurrence + }) + .ToList(); + + // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary + var data = (from word in text_words + join id in word2id on word equals id.Word into wi + from wi2 in wi.DefaultIfEmpty() + select wi2 == null ? 
0 : wi2.Id).ToList(); + + word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) }); + + print($"Words count: {text_words.Length}"); + print($"Unique words: {text_words.Distinct().Count()}"); + print($"Vocabulary size: {word2id.Count}"); + print($"Most common words: {string.Join(", ", word2id.Take(10))}"); + } + + private class WordId + { + public string Word { get; set; } + public int Id { get; set; } + public int Occurrence { get; set; } + + public override string ToString() + { + return Word + " " + Id + " " + Occurrence; + } } } }