You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Tokenizer.cs 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. using NumSharp;
  2. using Serilog.Debugging;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Collections.Specialized;
  6. using System.Linq;
  7. using System.Net.Sockets;
  8. using System.Text;
  9. namespace Tensorflow.Keras.Text
  10. {
  11. /// <summary>
  12. /// Text tokenization API.
  13. /// This class allows to vectorize a text corpus, by turning each text into either a sequence of integers
  14. /// (each integer being the index of a token in a dictionary) or into a vector where the coefficient for
  15. /// each token could be binary, based on word count, based on tf-idf...
  16. /// </summary>
  17. public class Tokenizer
  18. {
  19. private readonly int num_words;
  20. private readonly string filters;
  21. private readonly bool lower;
  22. private readonly char split;
  23. private readonly bool char_level;
  24. private readonly string oov_token;
  25. private readonly Func<string, IEnumerable<string>> analyzer;
  26. private int document_count = 0;
  27. private Dictionary<string, int> word_docs = new Dictionary<string, int>();
  28. private Dictionary<string, int> word_counts = new Dictionary<string, int>();
  29. public Dictionary<string, int> word_index = null;
  30. public Dictionary<int, string> index_word = null;
  31. private Dictionary<int, int> index_docs = null;
  32. public Tokenizer(
  33. int num_words = -1,
  34. string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
  35. bool lower = true,
  36. char split = ' ',
  37. bool char_level = false,
  38. string oov_token = null,
  39. Func<string, IEnumerable<string>> analyzer = null)
  40. {
  41. this.num_words = num_words;
  42. this.filters = filters;
  43. this.lower = lower;
  44. this.split = split;
  45. this.char_level = char_level;
  46. this.oov_token = oov_token;
  47. this.analyzer = analyzer;
  48. }
  49. /// <summary>
  50. /// Updates internal vocabulary based on a list of texts.
  51. /// </summary>
  52. /// <param name="texts">A list of strings, each containing one or more tokens.</param>
  53. /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
  54. public void fit_on_texts(IEnumerable<string> texts)
  55. {
  56. foreach (var text in texts)
  57. {
  58. IEnumerable<string> seq = null;
  59. document_count += 1;
  60. if (char_level)
  61. {
  62. throw new NotImplementedException("char_level == true");
  63. }
  64. else
  65. {
  66. seq = analyzer(lower ? text.ToLower() : text);
  67. }
  68. foreach (var w in seq)
  69. {
  70. var count = 0;
  71. word_counts.TryGetValue(w, out count);
  72. word_counts[w] = count + 1;
  73. }
  74. foreach (var w in new HashSet<string>(seq))
  75. {
  76. var count = 0;
  77. word_docs.TryGetValue(w, out count);
  78. word_docs[w] = count + 1;
  79. }
  80. }
  81. var wcounts = word_counts.AsEnumerable().ToList();
  82. wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order.
  83. var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
  84. sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
  85. if (num_words > 0 - 1)
  86. {
  87. sorted_voc = sorted_voc.Take<string>((oov_token == null) ? num_words : num_words + 1).ToList();
  88. }
  89. word_index = new Dictionary<string, int>(sorted_voc.Count);
  90. index_word = new Dictionary<int, string>(sorted_voc.Count);
  91. index_docs = new Dictionary<int, int>(word_docs.Count);
  92. for (int i = 0; i < sorted_voc.Count; i++)
  93. {
  94. word_index.Add(sorted_voc[i], i + 1);
  95. index_word.Add(i + 1, sorted_voc[i]);
  96. }
  97. foreach (var kv in word_docs)
  98. {
  99. var idx = -1;
  100. if (word_index.TryGetValue(kv.Key, out idx))
  101. {
  102. index_docs.Add(idx, kv.Value);
  103. }
  104. }
  105. }
  106. /// <summary>
  107. /// Updates internal vocabulary based on a list of texts.
  108. /// </summary>
  109. /// <param name="texts">A list of list of strings, each containing one token.</param>
  110. /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
  111. public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
  112. {
  113. foreach (var seq in texts)
  114. {
  115. foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
  116. {
  117. var count = 0;
  118. word_counts.TryGetValue(w, out count);
  119. word_counts[w] = count + 1;
  120. }
  121. foreach (var w in new HashSet<string>(word_counts.Keys))
  122. {
  123. var count = 0;
  124. word_docs.TryGetValue(w, out count);
  125. word_docs[w] = count + 1;
  126. }
  127. }
  128. var wcounts = word_counts.AsEnumerable().ToList();
  129. wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
  130. var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
  131. sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
  132. if (num_words > 0 - 1)
  133. {
  134. sorted_voc = sorted_voc.Take<string>((oov_token == null) ? num_words : num_words + 1).ToList();
  135. }
  136. word_index = new Dictionary<string, int>(sorted_voc.Count);
  137. index_word = new Dictionary<int, string>(sorted_voc.Count);
  138. index_docs = new Dictionary<int, int>(word_docs.Count);
  139. for (int i = 0; i < sorted_voc.Count; i++)
  140. {
  141. word_index.Add(sorted_voc[i], i + 1);
  142. index_word.Add(i + 1, sorted_voc[i]);
  143. }
  144. foreach (var kv in word_docs)
  145. {
  146. var idx = -1;
  147. if (word_index.TryGetValue(kv.Key, out idx))
  148. {
  149. index_docs.Add(idx, kv.Value);
  150. }
  151. }
  152. }
  153. /// <summary>
  154. /// Updates internal vocabulary based on a list of sequences.
  155. /// </summary>
  156. /// <param name="sequences"></param>
  157. /// <remarks>Required before using sequences_to_matrix (if fit_on_texts was never called).</remarks>
  158. public void fit_on_sequences(IEnumerable<int[]> sequences)
  159. {
  160. throw new NotImplementedException("fit_on_sequences");
  161. }
  162. /// <summary>
  163. /// Transforms each string in texts to a sequence of integers.
  164. /// </summary>
  165. /// <param name="texts"></param>
  166. /// <returns></returns>
  167. /// <remarks>Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account.</remarks>
  168. public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
  169. {
  170. return texts_to_sequences_generator(texts).ToArray();
  171. }
  172. /// <summary>
  173. /// Transforms each token in texts to a sequence of integers.
  174. /// </summary>
  175. /// <param name="texts"></param>
  176. /// <returns></returns>
  177. /// <remarks>Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account.</remarks>
  178. public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
  179. {
  180. return texts_to_sequences_generator(texts).ToArray();
  181. }
  182. public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
  183. {
  184. int oov_index = -1;
  185. var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
  186. return texts.Select(text =>
  187. {
  188. IEnumerable<string> seq = null;
  189. if (char_level)
  190. {
  191. throw new NotImplementedException("char_level == true");
  192. }
  193. else
  194. {
  195. seq = analyzer(lower ? text.ToLower() : text);
  196. }
  197. return ConvertToSequence(oov_index, seq).ToArray();
  198. });
  199. }
  200. public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
  201. {
  202. int oov_index = -1;
  203. var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
  204. return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
  205. }
  206. private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
  207. {
  208. var vect = new List<int>();
  209. foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
  210. {
  211. var i = -1;
  212. if (word_index.TryGetValue(w, out i))
  213. {
  214. if (num_words != -1 && i >= num_words)
  215. {
  216. if (oov_index != -1)
  217. {
  218. vect.Add(oov_index);
  219. }
  220. }
  221. else
  222. {
  223. vect.Add(i);
  224. }
  225. }
  226. else if (oov_index != -1)
  227. {
  228. vect.Add(oov_index);
  229. }
  230. }
  231. return vect;
  232. }
  233. /// <summary>
  234. /// Transforms each sequence into a list of text.
  235. /// </summary>
  236. /// <param name="sequences"></param>
  237. /// <returns>A list of texts(strings)</returns>
  238. /// <remarks>Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account.</remarks>
  239. public IList<string> sequences_to_texts(IEnumerable<int[]> sequences)
  240. {
  241. return sequences_to_texts_generator(sequences).ToArray();
  242. }
  243. public IEnumerable<string> sequences_to_texts_generator(IEnumerable<IList<int>> sequences)
  244. {
  245. int oov_index = -1;
  246. var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
  247. return sequences.Select(seq =>
  248. {
  249. var bldr = new StringBuilder();
  250. for (var i = 0; i < seq.Count; i++)
  251. {
  252. if (i > 0) bldr.Append(' ');
  253. string word = null;
  254. if (index_word.TryGetValue(seq[i], out word))
  255. {
  256. if (num_words != -1 && i >= num_words)
  257. {
  258. if (oov_index != -1)
  259. {
  260. bldr.Append(oov_token);
  261. }
  262. }
  263. else
  264. {
  265. bldr.Append(word);
  266. }
  267. }
  268. else if (oov_index != -1)
  269. {
  270. bldr.Append(oov_token);
  271. }
  272. }
  273. return bldr.ToString();
  274. });
  275. }
  276. /// <summary>
  277. /// Converts a list of sequences into a Numpy matrix.
  278. /// </summary>
  279. /// <param name="sequences"></param>
  280. /// <returns></returns>
  281. public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences)
  282. {
  283. throw new NotImplementedException("sequences_to_matrix");
  284. }
  285. }
  286. }