You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Imdb.cs 3.2 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow.Keras.Utils;
  6. using NumSharp;
  7. using System.Linq;
  8. namespace Tensorflow.Keras.Datasets
  9. {
  /// <summary>
  /// This is a dataset of 25,000 movie reviews from IMDB, labeled by sentiment
  /// (positive/negative).
  /// </summary>
  /// <remarks>
  /// As implemented here, <see cref="load_data"/> reads two text files,
  /// <c>imdb_train.txt</c> and <c>imdb_test.txt</c>, from the dataset folder.
  /// On every line the first character is parsed as the integer label and the
  /// text starting at the third character is kept as the review.
  /// </remarks>
  public class Imdb
  {
      // Base URL the archive is downloaded from.
      readonly string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";

      // Name of the archive, both remotely and in the local folder.
      readonly string file_name = "imdb.npz";

      // Sub-folder of the system temp directory that holds the dataset files.
      readonly string dest_folder = "imdb";

      /// <summary>
      /// Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
      /// </summary>
      /// <param name="path">Currently unused by this implementation.</param>
      /// <param name="num_words">Currently unused by this implementation.</param>
      /// <param name="skip_top">Currently unused by this implementation.</param>
      /// <param name="maxlen">Currently unused by this implementation.</param>
      /// <param name="seed">Currently unused by this implementation.</param>
      /// <param name="start_char">Currently unused by this implementation.</param>
      /// <param name="oov_char">Currently unused by this implementation.</param>
      /// <param name="index_from">Currently unused by this implementation.</param>
      /// <returns>
      /// A <c>DatasetPass</c> whose <c>Train</c> and <c>Test</c> members each hold
      /// a (reviews, labels) pair read from <c>imdb_train.txt</c> and
      /// <c>imdb_test.txt</c> respectively.
      /// </returns>
      public DatasetPass load_data(string path = "imdb.npz",
          int num_words = -1,
          int skip_top = 0,
          int maxlen = -1,
          int seed = 113,
          int start_char = 1,
          int oov_char = 2,
          int index_from = 3)
      {
          var dst = Download();

          // NOTE(review): Download() fetches "imdb.npz", while the files read
          // below are "imdb_train.txt" / "imdb_test.txt". How those text files
          // get into the folder is not visible in this file - confirm.
          var train = LoadLabeledText(Path.Combine(dst, "imdb_train.txt"));

          // Read the test file for the test split. Previously the result of
          // reading "imdb_test.txt" was discarded and the training lines were
          // parsed a second time, so Test was a copy of Train.
          var test = LoadLabeledText(Path.Combine(dst, "imdb_test.txt"));

          return new DatasetPass
          {
              Train = train,
              Test = test
          };
      }

      /// <summary>
      /// Reads a text file in which every line is a one-character integer label,
      /// one separator character, and then the review text.
      /// </summary>
      /// <param name="filePath">Full path of the file to read.</param>
      /// <returns>The review texts and the Int64 labels, in file order.</returns>
      (NDArray, NDArray) LoadLabeledText(string filePath)
      {
          var lines = File.ReadAllLines(filePath);
          var texts = new string[lines.Length];
          var labels = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);

          for (int i = 0; i < lines.Length; i++)
          {
              // First character is the label; index 1 is skipped as the separator.
              labels[i] = long.Parse(lines[i].Substring(0, 1));
              texts[i] = lines[i].Substring(2);
          }

          return (np.array(texts), labels);
      }

      /// <summary>
      /// Returns the "x_train.npy" and "x_test.npy" entries of an in-memory npz archive.
      /// </summary>
      /// <param name="bytes">Raw bytes of the npz archive.</param>
      (NDArray, NDArray) LoadX(byte[] bytes)
      {
          var y = np.Load_Npz<byte[]>(bytes);
          return (y["x_train.npy"], y["x_test.npy"]);
      }

      /// <summary>
      /// Returns the "y_train.npy" and "y_test.npy" entries of an in-memory npz archive.
      /// </summary>
      /// <param name="bytes">Raw bytes of the npz archive.</param>
      (NDArray, NDArray) LoadY(byte[] bytes)
      {
          var y = np.Load_Npz<long[]>(bytes);
          return (y["y_train.npy"], y["y_test.npy"]);
      }

      /// <summary>
      /// Creates the dataset folder under the system temp directory and asks
      /// <c>Web.Download</c> to fetch the archive into it.
      /// </summary>
      /// <returns>The path of the local dataset folder.</returns>
      string Download()
      {
          var dst = Path.Combine(Path.GetTempPath(), dest_folder);

          // CreateDirectory does nothing when the folder already exists.
          Directory.CreateDirectory(dst);

          Web.Download(origin_folder + file_name, dst, file_name);

          return dst;
      }
  }
  85. }