You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TextClassificationWithMovieReviews.cs 3.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow;
  6. using NumSharp.Core;
  7. using Newtonsoft.Json;
  8. using System.Linq;
  9. using Keras;
  10. using System.Text.RegularExpressions;
  11. namespace TensorFlowNET.Examples
  12. {
  /// <summary>
  /// Example that downloads the IMDB movie-review dataset archive, parses its numeric
  /// text files and pads the review sequences for text classification.
  /// Data preparation is unfinished: <see cref="PrepareData"/> currently throws
  /// <see cref="NotImplementedException"/> after downloading and parsing the files.
  /// </summary>
  public class TextClassificationWithMovieReviews : Python, IExample
  {
      // Working directory the dataset archive is downloaded into and extracted to.
      private readonly string dir = "text_classification_with_movie_reviews";

      // File name of the dataset archive, used both in the download URL and on disk.
      private readonly string dataFile = "imdb.zip";

      // Value passed as `maxlen` to pad_sequences for both train and test data.
      private const int MaxSequenceLength = 256;

      // Matches each run of digits in a line of a numeric data file; built once instead of per line.
      private static readonly Regex NumberPattern = new Regex(@"\d+", RegexOptions.Compiled);

      /// <summary>
      /// Entry point: loads the dataset, prints the training-set size and pads the
      /// train and test sequences with the <c>&lt;PAD&gt;</c> index.
      /// </summary>
      public void Run()
      {
          var ((train_data, train_labels), (test_data, _)) = PrepareData();
          Console.WriteLine($"Training entries: {train_data.size}, labels: {train_labels.size}");

          // A dictionary mapping words to an integer index
          var word_index = GetWordIndex();
          int padValue = word_index["<PAD>"];

          // Pad train and test data with identical settings so both end up with the same sequence length.
          train_data = keras.preprocessing.sequence.pad_sequences(train_data,
              value: padValue,
              padding: "post",
              maxlen: MaxSequenceLength);
          test_data = keras.preprocessing.sequence.pad_sequences(test_data,
              value: padValue,
              padding: "post",
              maxlen: MaxSequenceLength);
      }

      /// <summary>
      /// Downloads the dataset archive into <see cref="dir"/>, extracts it there and reads the
      /// train/test data, label and index files.
      /// </summary>
      /// <returns>((train data, train labels), (test data, test labels)) once implemented.</returns>
      /// <exception cref="NotImplementedException">
      /// Always thrown: turning the parsed files into <see cref="NDArray"/>s is not finished yet.
      /// </exception>
      private ((NDArray, NDArray), (NDArray, NDArray)) PrepareData()
      {
          Directory.CreateDirectory(dir);

          // Download and extract the archive. Use the same file name as the URL so the two cannot diverge.
          string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
          string zipFile = Path.Join(dir, dataFile);
          Utility.Web.Download(url, zipFile);
          Utility.Compress.UnZip(zipFile, dir);

          // prepare training dataset
          var x_train = ReadData(Path.Join(dir, "x_train.txt"));
          var labels_train = ReadData(Path.Join(dir, "y_train.txt"));
          var indices_train = ReadData(Path.Join(dir, "indices_train.txt"));
          // x_train = x_train[indices_train];
          // labels_train = labels_train[indices_train];

          // prepare test dataset
          var x_test = ReadData(Path.Join(dir, "x_test.txt"));
          var labels_test = ReadData(Path.Join(dir, "y_test.txt"));
          var indices_test = ReadData(Path.Join(dir, "indices_test.txt"));
          // x_test = x_test[indices_test];
          // labels_test = labels_test[indices_test];

          // TODO: not completed. Intended remaining steps:
          /*var xs = x_train.hstack(x_test);
          var labels = labels_train.hstack(labels_test);
          var idx = x_train.size;
          var y_train = labels_train;
          var y_test = labels_test;
          return ((x_train, y_train), (x_test, y_test));*/
          throw new NotImplementedException();
      }

      /// <summary>
      /// Parses a text file in which every line holds a list of integers separated by
      /// commas (or any other non-digit characters). Minus signs are not recognised, so
      /// only the digits of each number are kept.
      /// </summary>
      /// <param name="file">Path of the file to read.</param>
      /// <returns>One array per line with that line's integers in order; a line without digits yields an empty array.</returns>
      /// <exception cref="FormatException">A matched digit run contains non-ASCII digits.</exception>
      /// <exception cref="OverflowException">A number does not fit in <see cref="int"/>.</exception>
      private int[][] ReadData(string file)
      {
          var lines = new List<int[]>();
          foreach (var line in File.ReadLines(file))
          {
              var matches = NumberPattern.Matches(line);
              var data = new int[matches.Count];
              for (int i = 0; i < data.Length; i++)
              {
                  // Machine-readable data: parse culture-independently.
                  data[i] = int.Parse(matches[i].Value, System.Globalization.CultureInfo.InvariantCulture);
              }
              lines.Add(data);
          }
          return lines.ToArray();
      }

      /// <summary>
      /// Loads <c>imdb_word_index.json</c> from <see cref="dir"/>, shifts every word index up
      /// by 3, and assigns indices 0-3 to the special tokens.
      /// </summary>
      /// <returns>
      /// Word-to-index map that also contains <c>&lt;PAD&gt;</c>=0, <c>&lt;START&gt;</c>=1,
      /// <c>&lt;UNK&gt;</c>=2 and <c>&lt;UNUSED&gt;</c>=3.
      /// </returns>
      /// <exception cref="InvalidDataException">The JSON file deserializes to <c>null</c> (e.g. it is empty or contains <c>null</c>).</exception>
      private Dictionary<string, int> GetWordIndex()
      {
          var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json"));
          var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json)
              ?? throw new InvalidDataException("imdb_word_index.json does not contain a word index.");

          var result = new Dictionary<string, int>(dict.Count + 4);
          foreach (var pair in dict)
          {
              result[pair.Key] = pair.Value + 3;
          }

          // Special tokens are assigned last, so they overwrite any identically named entry from the file.
          result["<PAD>"] = 0;
          result["<START>"] = 1;
          result["<UNK>"] = 2; // unknown
          result["<UNUSED>"] = 3;
          return result;
      }
  }
  82. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。