You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Word2Vec.cs 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using TensorFlowNET.Examples.Utility;
  8. namespace TensorFlowNET.Examples
  9. {
  10. /// <summary>
  11. /// Implement Word2Vec algorithm to compute vector representations of words.
  12. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
  13. /// </summary>
  14. public class Word2Vec : Python, IExample
  15. {
  16. public int Priority => 12;
  17. public bool Enabled { get; set; } = true;
  18. public string Name => "Word2Vec";
  19. // Training Parameters
  20. float learning_rate = 0.1f;
  21. int batch_size = 128;
  22. int num_steps = 3000000;
  23. int display_step = 10000;
  24. int eval_step = 200000;
  25. // Evaluation Parameters
  26. string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };
  27. string[] text_words;
  28. // Word2Vec Parameters
  29. int embedding_size = 200; // Dimension of the embedding vector
  30. int max_vocabulary_size = 50000; // Total number of different words in the vocabulary
  31. int min_occurrence = 10; // Remove all words that does not appears at least n times
  32. int skip_window = 3; // How many words to consider left and right
  33. int num_skips = 2; // How many times to reuse an input to generate a label
  34. int num_sampled = 64; // Number of negative examples to sample
  35. int data_index;
  36. public bool Run()
  37. {
  38. PrepareData();
  39. var graph = tf.Graph().as_default();
  40. tf.train.import_meta_graph("graph/word2vec.meta");
  41. // Initialize the variables (i.e. assign their default value)
  42. var init = tf.global_variables_initializer();
  43. with(tf.Session(graph), sess =>
  44. {
  45. sess.run(init);
  46. });
  47. return false;
  48. }
  49. // Generate training batch for the skip-gram model
  50. private void next_batch()
  51. {
  52. }
  53. public void PrepareData()
  54. {
  55. // Download graph meta
  56. var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
  57. Web.Download(url, "graph", "word2vec.meta");
  58. // Download a small chunk of Wikipedia articles collection
  59. url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
  60. Web.Download(url, "word2vec", "text8.zip");
  61. // Unzip the dataset file. Text has already been processed
  62. Compress.UnZip(@"word2vec\text8.zip", "word2vec");
  63. int wordId = 0;
  64. text_words = File.ReadAllText(@"word2vec\text8").Trim().ToLower().Split();
  65. // Build the dictionary and replace rare words with UNK token
  66. var word2id = text_words.GroupBy(x => x)
  67. .Select(x => new WordId
  68. {
  69. Word = x.Key,
  70. Occurrence = x.Count()
  71. })
  72. .Where(x => x.Occurrence >= min_occurrence) // Remove samples with less than 'min_occurrence' occurrences
  73. .OrderByDescending(x => x.Occurrence) // Retrieve the most common words
  74. .Select(x => new WordId
  75. {
  76. Word = x.Word,
  77. Id = ++wordId, // Assign an id to each word
  78. Occurrence = x.Occurrence
  79. })
  80. .ToList();
  81. // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
  82. var data = (from word in text_words
  83. join id in word2id on word equals id.Word into wi
  84. from wi2 in wi.DefaultIfEmpty()
  85. select wi2 == null ? 0 : wi2.Id).ToList();
  86. word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });
  87. print($"Words count: {text_words.Length}");
  88. print($"Unique words: {text_words.Distinct().Count()}");
  89. print($"Vocabulary size: {word2id.Count}");
  90. print($"Most common words: {string.Join(", ", word2id.Take(10))}");
  91. }
  92. private class WordId
  93. {
  94. public string Word { get; set; }
  95. public int Id { get; set; }
  96. public int Occurrence { get; set; }
  97. public override string ToString()
  98. {
  99. return Word + " " + Id + " " + Occurrence;
  100. }
  101. }
  102. }
  103. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。