| @@ -1,7 +1,10 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using TensorFlowNET.Examples.Utility; | |||||
| namespace TensorFlowNET.Examples | namespace TensorFlowNET.Examples | ||||
| { | { | ||||
| @@ -24,6 +27,17 @@ namespace TensorFlowNET.Examples | |||||
| // Evaluation Parameters | // Evaluation Parameters | ||||
| string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" }; | string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" }; | ||||
| string[] text_words; | |||||
| // Word2Vec Parameters | |||||
| int embedding_size = 200; // Dimension of the embedding vector | |||||
| int max_vocabulary_size = 50000; // Total number of different words in the vocabulary | |||||
| int min_occurrence = 10; // Remove all words that does not appears at least n times | |||||
| int skip_window = 3; // How many words to consider left and right | |||||
| int num_skips = 2; // How many times to reuse an input to generate a label | |||||
| int num_sampled = 64; // Number of negative examples to sample | |||||
| int data_index; | |||||
| public bool Run() | public bool Run() | ||||
| { | { | ||||
| @@ -33,12 +47,78 @@ namespace TensorFlowNET.Examples | |||||
| tf.train.import_meta_graph("graph/word2vec.meta"); | tf.train.import_meta_graph("graph/word2vec.meta"); | ||||
| // Initialize the variables (i.e. assign their default value) | |||||
| var init = tf.global_variables_initializer(); | |||||
| with(tf.Session(graph), sess => | |||||
| { | |||||
| sess.run(init); | |||||
| }); | |||||
| return false; | return false; | ||||
| } | } | ||||
| // Generate training batch for the skip-gram model | |||||
| private void next_batch() | |||||
| { | |||||
| } | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| var url = ""; | |||||
| // Download graph meta | |||||
| var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta"; | |||||
| Web.Download(url, "graph", "word2vec.meta"); | |||||
| // Download a small chunk of Wikipedia articles collection | |||||
| url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip"; | |||||
| Web.Download(url, "word2vec", "text8.zip"); | |||||
| // Unzip the dataset file. Text has already been processed | |||||
| Compress.UnZip(@"word2vec\text8.zip", "word2vec"); | |||||
| int wordId = 0; | |||||
| text_words = File.ReadAllText(@"word2vec\text8").Trim().ToLower().Split(); | |||||
| // Build the dictionary and replace rare words with UNK token | |||||
| var word2id = text_words.GroupBy(x => x) | |||||
| .Select(x => new WordId | |||||
| { | |||||
| Word = x.Key, | |||||
| Occurrence = x.Count() | |||||
| }) | |||||
| .Where(x => x.Occurrence >= min_occurrence) // Remove samples with less than 'min_occurrence' occurrences | |||||
| .OrderByDescending(x => x.Occurrence) // Retrieve the most common words | |||||
| .Select(x => new WordId | |||||
| { | |||||
| Word = x.Word, | |||||
| Id = ++wordId, // Assign an id to each word | |||||
| Occurrence = x.Occurrence | |||||
| }) | |||||
| .ToList(); | |||||
| // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary | |||||
| var data = (from word in text_words | |||||
| join id in word2id on word equals id.Word into wi | |||||
| from wi2 in wi.DefaultIfEmpty() | |||||
| select wi2 == null ? 0 : wi2.Id).ToList(); | |||||
| word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) }); | |||||
| print($"Words count: {text_words.Length}"); | |||||
| print($"Unique words: {text_words.Distinct().Count()}"); | |||||
| print($"Vocabulary size: {word2id.Count}"); | |||||
| print($"Most common words: {string.Join(", ", word2id.Take(10))}"); | |||||
| } | |||||
| private class WordId | |||||
| { | |||||
| public string Word { get; set; } | |||||
| public int Id { get; set; } | |||||
| public int Occurrence { get; set; } | |||||
| public override string ToString() | |||||
| { | |||||
| return Word + " " + Id + " " + Occurrence; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||