You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Tokenizer.cs 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. using Tensorflow.NumPy;
  2. using Serilog.Debugging;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Collections.Specialized;
  6. using System.Data.SqlTypes;
  7. using System.Linq;
  8. using System.Net.Sockets;
  9. using System.Text;
  10. namespace Tensorflow.Keras.Text
  11. {
  12. /// <summary>
  13. /// Text tokenization API.
  14. /// This class allows to vectorize a text corpus, by turning each text into either a sequence of integers
  15. /// (each integer being the index of a token in a dictionary) or into a vector where the coefficient for
  16. /// each token could be binary, based on word count, based on tf-idf...
  17. /// </summary>
  18. /// <remarks>
  19. /// This code is a fairly straight port of the Python code for Keras text preprocessing found at:
  20. /// https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/text.py
  21. /// </remarks>
  22. public class Tokenizer
  23. {
  24. private readonly int num_words;
  25. private readonly string filters;
  26. private readonly bool lower;
  27. private readonly char split;
  28. private readonly bool char_level;
  29. private readonly string oov_token;
  30. private readonly Func<string, IEnumerable<string>> analyzer;
  31. private int document_count = 0;
  32. private Dictionary<string, int> word_docs = new Dictionary<string, int>();
  33. private Dictionary<string, int> word_counts = new Dictionary<string, int>();
  34. public Dictionary<string, int> word_index = null;
  35. public Dictionary<int, string> index_word = null;
  36. private Dictionary<int, int> index_docs = null;
  37. public Tokenizer(
  38. int num_words = -1,
  39. string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
  40. bool lower = true,
  41. char split = ' ',
  42. bool char_level = false,
  43. string oov_token = null,
  44. Func<string, IEnumerable<string>> analyzer = null)
  45. {
  46. this.num_words = num_words;
  47. this.filters = filters;
  48. this.lower = lower;
  49. this.split = split;
  50. this.char_level = char_level;
  51. this.oov_token = oov_token;
  52. this.analyzer = analyzer != null ? analyzer : (text) => TextApi.text_to_word_sequence(text, filters, lower, split);
  53. }
  54. /// <summary>
  55. /// Updates internal vocabulary based on a list of texts.
  56. /// </summary>
  57. /// <param name="texts">A list of strings, each containing one or more tokens.</param>
  58. /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
  59. public void fit_on_texts(IEnumerable<string> texts)
  60. {
  61. foreach (var text in texts)
  62. {
  63. IEnumerable<string> seq = null;
  64. document_count += 1;
  65. if (char_level)
  66. {
  67. throw new NotImplementedException("char_level == true");
  68. }
  69. else
  70. {
  71. seq = analyzer(lower ? text.ToLower() : text);
  72. }
  73. foreach (var w in seq)
  74. {
  75. var count = 0;
  76. word_counts.TryGetValue(w, out count);
  77. word_counts[w] = count + 1;
  78. }
  79. foreach (var w in new HashSet<string>(seq))
  80. {
  81. var count = 0;
  82. word_docs.TryGetValue(w, out count);
  83. word_docs[w] = count + 1;
  84. }
  85. }
  86. var wcounts = word_counts.AsEnumerable().ToList();
  87. wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order.
  88. var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
  89. sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
  90. if (num_words > 0 - 1)
  91. {
  92. sorted_voc = sorted_voc.Take<string>((oov_token == null) ? num_words : num_words + 1).ToList();
  93. }
  94. word_index = new Dictionary<string, int>(sorted_voc.Count);
  95. index_word = new Dictionary<int, string>(sorted_voc.Count);
  96. index_docs = new Dictionary<int, int>(word_docs.Count);
  97. for (int i = 0; i < sorted_voc.Count; i++)
  98. {
  99. word_index.Add(sorted_voc[i], i + 1);
  100. index_word.Add(i + 1, sorted_voc[i]);
  101. }
  102. foreach (var kv in word_docs)
  103. {
  104. var idx = -1;
  105. if (word_index.TryGetValue(kv.Key, out idx))
  106. {
  107. index_docs.Add(idx, kv.Value);
  108. }
  109. }
  110. }
  111. /// <summary>
  112. /// Updates internal vocabulary based on a list of texts.
  113. /// </summary>
  114. /// <param name="texts">A list of list of strings, each containing one token.</param>
  115. /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
  116. public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
  117. {
  118. foreach (var seq in texts)
  119. {
  120. foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
  121. {
  122. var count = 0;
  123. word_counts.TryGetValue(w, out count);
  124. word_counts[w] = count + 1;
  125. }
  126. foreach (var w in new HashSet<string>(word_counts.Keys))
  127. {
  128. var count = 0;
  129. word_docs.TryGetValue(w, out count);
  130. word_docs[w] = count + 1;
  131. }
  132. }
  133. var wcounts = word_counts.AsEnumerable().ToList();
  134. wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
  135. var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
  136. sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
  137. if (num_words > 0 - 1)
  138. {
  139. sorted_voc = sorted_voc.Take<string>((oov_token == null) ? num_words : num_words + 1).ToList();
  140. }
  141. word_index = new Dictionary<string, int>(sorted_voc.Count);
  142. index_word = new Dictionary<int, string>(sorted_voc.Count);
  143. index_docs = new Dictionary<int, int>(word_docs.Count);
  144. for (int i = 0; i < sorted_voc.Count; i++)
  145. {
  146. word_index.Add(sorted_voc[i], i + 1);
  147. index_word.Add(i + 1, sorted_voc[i]);
  148. }
  149. foreach (var kv in word_docs)
  150. {
  151. var idx = -1;
  152. if (word_index.TryGetValue(kv.Key, out idx))
  153. {
  154. index_docs.Add(idx, kv.Value);
  155. }
  156. }
  157. }
  158. /// <summary>
  159. /// Updates internal vocabulary based on a list of sequences.
  160. /// </summary>
  161. /// <param name="sequences"></param>
  162. /// <remarks>Required before using sequences_to_matrix (if fit_on_texts was never called).</remarks>
  163. public void fit_on_sequences(IEnumerable<int[]> sequences)
  164. {
  165. throw new NotImplementedException("fit_on_sequences");
  166. }
  167. /// <summary>
  168. /// Transforms each string in texts to a sequence of integers.
  169. /// </summary>
  170. /// <param name="texts"></param>
  171. /// <returns></returns>
  172. /// <remarks>Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account.</remarks>
  173. public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
  174. {
  175. return texts_to_sequences_generator(texts).ToArray();
  176. }
  177. /// <summary>
  178. /// Transforms each token in texts to a sequence of integers.
  179. /// </summary>
  180. /// <param name="texts"></param>
  181. /// <returns></returns>
  182. /// <remarks>Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account.</remarks>
  183. public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
  184. {
  185. return texts_to_sequences_generator(texts).ToArray();
  186. }
  187. public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
  188. {
  189. int oov_index = -1;
  190. var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
  191. return texts.Select(text =>
  192. {
  193. IEnumerable<string> seq = null;
  194. if (char_level)
  195. {
  196. throw new NotImplementedException("char_level == true");
  197. }
  198. else
  199. {
  200. seq = analyzer(lower ? text.ToLower() : text);
  201. }
  202. return ConvertToSequence(oov_index, seq).ToArray();
  203. });
  204. }
  205. public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
  206. {
  207. int oov_index = -1;
  208. var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
  209. return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
  210. }
  211. private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
  212. {
  213. var vect = new List<int>();
  214. foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
  215. {
  216. var i = -1;
  217. if (word_index.TryGetValue(w, out i))
  218. {
  219. if (num_words != -1 && i >= num_words)
  220. {
  221. if (oov_index != -1)
  222. {
  223. vect.Add(oov_index);
  224. }
  225. }
  226. else
  227. {
  228. vect.Add(i);
  229. }
  230. }
  231. else if (oov_index != -1)
  232. {
  233. vect.Add(oov_index);
  234. }
  235. }
  236. return vect;
  237. }
  238. /// <summary>
  239. /// Transforms each sequence into a list of text.
  240. /// </summary>
  241. /// <param name="sequences"></param>
  242. /// <returns>A list of texts(strings)</returns>
  243. /// <remarks>Only top num_words-1 most frequent words will be taken into account.Only words known by the tokenizer will be taken into account.</remarks>
  244. public IList<string> sequences_to_texts(IEnumerable<int[]> sequences)
  245. {
  246. return sequences_to_texts_generator(sequences).ToArray();
  247. }
  248. public IEnumerable<string> sequences_to_texts_generator(IEnumerable<IList<int>> sequences)
  249. {
  250. int oov_index = -1;
  251. var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
  252. return sequences.Select(seq =>
  253. {
  254. var bldr = new StringBuilder();
  255. for (var i = 0; i < seq.Count; i++)
  256. {
  257. if (i > 0) bldr.Append(' ');
  258. string word = null;
  259. if (index_word.TryGetValue(seq[i], out word))
  260. {
  261. if (num_words != -1 && i >= num_words)
  262. {
  263. if (oov_index != -1)
  264. {
  265. bldr.Append(oov_token);
  266. }
  267. }
  268. else
  269. {
  270. bldr.Append(word);
  271. }
  272. }
  273. else if (oov_index != -1)
  274. {
  275. bldr.Append(oov_token);
  276. }
  277. }
  278. return bldr.ToString();
  279. });
  280. }
  281. /// <summary>
  282. /// Convert a list of texts to a Numpy matrix.
  283. /// </summary>
  284. /// <param name="texts">A sequence of strings containing one or more tokens.</param>
  285. /// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
  286. /// <returns></returns>
  287. public NDArray texts_to_matrix(IEnumerable<string> texts, string mode = "binary")
  288. {
  289. return sequences_to_matrix(texts_to_sequences(texts), mode);
  290. }
  291. /// <summary>
  292. /// Convert a list of texts to a Numpy matrix.
  293. /// </summary>
  294. /// <param name="texts">A sequence of lists of strings, each containing one token.</param>
  295. /// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
  296. /// <returns></returns>
  297. public NDArray texts_to_matrix(IEnumerable<IList<string>> texts, string mode = "binary")
  298. {
  299. return sequences_to_matrix(texts_to_sequences(texts), mode);
  300. }
  301. /// <summary>
  302. /// Converts a list of sequences into a Numpy matrix.
  303. /// </summary>
  304. /// <param name="sequences">A sequence of lists of integers, encoding tokens.</param>
  305. /// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
  306. /// <returns></returns>
  307. public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences, string mode = "binary")
  308. {
  309. if (!modes.Contains(mode)) throw new InvalidArgumentError($"Unknown vectorization mode: {mode}");
  310. var word_count = 0;
  311. if (num_words == -1)
  312. {
  313. if (word_index != null)
  314. {
  315. word_count = word_index.Count + 1;
  316. }
  317. else
  318. {
  319. throw new InvalidOperationException("Specifya dimension ('num_words' arugment), or fit on some text data first.");
  320. }
  321. }
  322. else
  323. {
  324. word_count = num_words;
  325. }
  326. if (mode == "tfidf" && this.document_count == 0)
  327. {
  328. throw new InvalidOperationException("Fit the Tokenizer on some text data before using the 'tfidf' mode.");
  329. }
  330. var x = np.zeros((sequences.Count(), word_count));
  331. for (int i = 0; i < sequences.Count(); i++)
  332. {
  333. var seq = sequences.ElementAt(i);
  334. if (seq == null || seq.Count == 0)
  335. continue;
  336. var counts = new Dictionary<int, int>();
  337. var seq_length = seq.Count;
  338. foreach (var j in seq)
  339. {
  340. if (j >= word_count)
  341. continue;
  342. var count = 0;
  343. counts.TryGetValue(j, out count);
  344. counts[j] = count + 1;
  345. }
  346. if (mode == "count")
  347. {
  348. foreach (var kv in counts)
  349. {
  350. var j = kv.Key;
  351. var c = kv.Value + 0.0;
  352. x[i, j] = c;
  353. }
  354. }
  355. else if (mode == "freq")
  356. {
  357. foreach (var kv in counts)
  358. {
  359. var j = kv.Key;
  360. var c = kv.Value + 0.0;
  361. x[i, j] = ((double)c) / seq_length;
  362. }
  363. }
  364. else if (mode == "binary")
  365. {
  366. foreach (var kv in counts)
  367. {
  368. var j = kv.Key;
  369. // var c = kv.Value + 0.0;
  370. x[i, j] = 1.0;
  371. }
  372. }
  373. else if (mode == "tfidf")
  374. {
  375. foreach (var kv in counts)
  376. {
  377. var j = kv.Key;
  378. var c = kv.Value + 0.0;
  379. var id = 0;
  380. var _ = index_docs.TryGetValue(j, out id);
  381. var tf = 1.0 + np.log(c);
  382. var idf = np.log(1.0 + document_count / (1 + id));
  383. x[i, j] = tf * idf;
  384. }
  385. }
  386. }
  387. return x;
  388. }
  389. private string[] modes = new string[] { "binary", "count", "tfidf", "freq" };
  390. }
  391. }