diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
index f4f9de5f..98769a21 100644
--- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs
+++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
@@ -5,8 +5,6 @@ using System.Text;
 using Tensorflow.Keras.Utils;
 using NumSharp;
 using System.Linq;
-using NumSharp.Utilities;
-using Tensorflow.Queues;
 
 namespace Tensorflow.Keras.Datasets
 {
@@ -17,10 +15,8 @@ namespace Tensorflow.Keras.Datasets
     /// </summary>
     public class Imdb
     {
-        //string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
-        string origin_folder = "http://ai.stanford.edu/~amaas/data/sentiment/";
-        //string file_name = "imdb.npz";
-        string file_name = "aclImdb_v1.tar.gz";
+        string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
+        string file_name = "imdb.npz";
         string dest_folder = "imdb";
 
         /// <summary>
@@ -41,66 +37,38 @@ namespace Tensorflow.Keras.Datasets
             int maxlen = -1,
             int seed = 113,
             int start_char = 1,
-            int oov_char = 2,
+            int oov_char= 2,
             int index_from = 3)
         {
             var dst = Download();
 
-            var vocab = BuildVocabulary(Path.Combine(dst, "imdb.vocab"), start_char, oov_char, index_from);
-
-            var (x_train,y_train) = GetDataSet(Path.Combine(dst, "train"));
-            var (x_test, y_test) = GetDataSet(Path.Combine(dst, "test"));
-
-            return new DatasetPass
+            var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
+            var x_train_string = new string[lines.Length];
+            var y_train = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
+            for (int i = 0; i < lines.Length; i++)
             {
-                Train = (x_train, y_train),
-                Test = (x_test, y_test)
-            };
-        }
+                y_train[i] = long.Parse(lines[i].Substring(0, 1));
+                x_train_string[i] = lines[i].Substring(2);
+            }
 
-        private static Dictionary<string, int> BuildVocabulary(string path,
-            int start_char,
-            int oov_char,
-            int index_from)
-        {
-            var words = File.ReadAllLines(path);
-            var result = new Dictionary<string, int>();
-            var idx = index_from;
+            var x_train = np.array(x_train_string);
 
-            foreach (var word in words)
+            lines = File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
+            var x_test_string = new string[lines.Length];
+            var y_test = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
+            for (int i = 0; i < lines.Length; i++)
             {
-                result[word] = idx;
-                idx += 1;
+                y_test[i] = long.Parse(lines[i].Substring(0, 1));
+                x_test_string[i] = lines[i].Substring(2);
             }
 
-            return result;
-        }
-
-        private static (NDArray, NDArray) GetDataSet(string path)
-        {
-            var posFiles = Directory.GetFiles(Path.Combine(path, "pos")).Slice(0,10);
-            var negFiles = Directory.GetFiles(Path.Combine(path, "neg")).Slice(0,10);
-
-            var x_string = new string[posFiles.Length + negFiles.Length];
-            var y = new int[posFiles.Length + negFiles.Length];
-            var trg = 0;
-            var longest = 0;
+            var x_test = np.array(x_test_string);
 
-            for (int i = 0; i < posFiles.Length; i++, trg++)
-            {
-                y[trg] = 1;
-                x_string[trg] = File.ReadAllText(posFiles[i]);
-                longest = Math.Max(longest, x_string[trg].Length);
-            }
-            for (int i = 0; i < posFiles.Length; i++, trg++)
+            return new DatasetPass
             {
-                y[trg] = 0;
-                x_string[trg] = File.ReadAllText(negFiles[i]);
-                longest = Math.Max(longest, x_string[trg].Length);
-            }
-            var x = np.array(x_string);
-
-            return (x, y);
+                Train = (x_train, y_train),
+                Test = (x_test, y_test)
+            };
         }
 
         (NDArray, NDArray) LoadX(byte[] bytes)
@@ -122,9 +90,8 @@ namespace Tensorflow.Keras.Datasets
 
             Web.Download(origin_folder + file_name, dst, file_name);
 
-            Tensorflow.Keras.Utils.Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);
-
-            return Path.Combine(dst, "aclImdb");
+            return dst;
+            // return Path.Combine(dst, file_name);
         }
     }
 }
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Text
         /// <summary>
         /// Updates internal vocabulary based on a list of texts.
         /// </summary>
-        /// <param name="texts">A list of strings, each containing one or more tokens.</param>
+        /// <param name="texts"></param>
         /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
         public void fit_on_texts(IEnumerable<string> texts)
         {
@@ -90,7 +90,7 @@ namespace Tensorflow.Keras.Text
             }
 
             var wcounts = word_counts.AsEnumerable().ToList();
-            wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order.
+            wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
 
             var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
             sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
@@ -120,12 +120,7 @@ namespace Tensorflow.Keras.Text
             }
         }
 
-        /// <summary>
-        /// Updates internal vocabulary based on a list of texts.
-        /// </summary>
-        /// <param name="texts">A list of list of strings, each containing one token.</param>
-        /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
-        public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
+        public void fit_on_texts(IEnumerable<IList<string>> texts)
         {
             foreach (var seq in texts)
             {
@@ -202,7 +197,7 @@ namespace Tensorflow.Keras.Text
         /// <summary>
         /// </summary>
         /// Only top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.
-        public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
+        public IList<int[]> texts_to_sequences(IEnumerable<IList<string>> texts)
         {
             return texts_to_sequences_generator(texts).ToArray();
         }
@@ -229,13 +224,6 @@ namespace Tensorflow.Keras.Text
             });
         }
 
-        public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
-        {
-            int oov_index = -1;
-            var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
-            return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
-        }
-
         private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
         {
             var vect = new List<int>();
@@ -256,7 +244,7 @@ namespace Tensorflow.Keras.Text
                         vect.Add(i);
                     }
                 }
-                else if (oov_index != -1)
+                else if(oov_index != -1)
                 {
                     vect.Add(oov_index);
                 }
@@ -265,6 +253,13 @@ namespace Tensorflow.Keras.Text
             return vect;
         }
 
+        public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IList<string>> texts)
+        {
+            int oov_index = -1;
+            var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+            return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
+        }
+
         /// <summary>
         /// Transforms each sequence into a list of text.
         /// </summary>
@@ -276,7 +271,7 @@ namespace Tensorflow.Keras.Text
             return sequences_to_texts_generator(sequences).ToArray();
         }
 
-        public IEnumerable<string> sequences_to_texts_generator(IEnumerable<IList<int>> sequences)
+        public IEnumerable<string> sequences_to_texts_generator(IEnumerable<int[]> sequences)
         {
             int oov_index = -1;
             var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
@@ -285,7 +280,7 @@ namespace Tensorflow.Keras.Text
             return sequences.Select(seq =>
             {
                 var bldr = new StringBuilder();
 
-                for (var i = 0; i < seq.Count; i++)
+                for (var i = 0; i < seq.Length; i++)
                 {
                     if (i > 0) bldr.Append(' ');
@@ -319,7 +314,7 @@ namespace Tensorflow.Keras.Text
         /// <summary>
         /// </summary>
         /// <param name="sequences"></param>
        /// <returns></returns>
-        public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences)
+        public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
         {
             throw new NotImplementedException("sequences_to_matrix");
         }
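A short usage sketch of the Tokenizer surface touched above, mirroring the unit tests that follow. It assumes only the `keras.preprocessing.text.Tokenizer` factory the tests themselves use:

```csharp
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: "<OOV>");
tokenizer.fit_on_texts(new[] { "it was the best of times", "it was the worst of times" });

// The OOV token is prepended to the frequency-sorted vocabulary, so
// word_index["<OOV>"] == 1 and every other index shifts up by one.
var sequences = tokenizer.texts_to_sequences(new[] { "it was the very worst of times" });

// "very" was never seen during fit, so ConvertToSequence maps it to the
// OOV index; round-tripping restores it as "<OOV>".
var restored = tokenizer.sequences_to_texts(sequences);
```

diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
index ad4f91bf..ebde87fa 100644
--- a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
@@ -15,23 +15,23 @@ namespace TensorFlowNET.Keras.UnitTest
     {
         private readonly string[] texts = new string[] {
             "It was the best of times, it was the worst of times.",
-            "Mr and Mrs Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.",
+            "this is a new dawn, an era to follow the previous era. It can not be said to start anew.",
             "It was the best of times, it was the worst of times.",
-            "Mr and Mrs Dursley of number four, Privet Drive.",
+            "this is a new dawn, an era to follow the previous era.",
         };
 
         private readonly string[][] tokenized_texts = new string[][] {
             new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
-            new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive","were","proud","to","say","that","they","were","perfectly","normal","thank","you","very","much"},
+            new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era","It","can","not","be","said","to","start","anew" },
             new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
-            new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive"},
+            new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era" },
         };
 
         private readonly string[] processed_texts = new string[] {
             "it was the best of times it was the worst of times",
-            "mr and mrs dursley of number four privet drive were proud to say that they were perfectly normal thank you very much",
+            "this is a new dawn an era to follow the previous era it can not be said to start anew",
             "it was the best of times it was the worst of times",
-            "mr and mrs dursley of number four privet drive",
+            "this is a new dawn an era to follow the previous era",
         };
 
         private const string OOV = "<OOV>";
@@ -42,11 +42,11 @@ namespace TensorFlowNET.Keras.UnitTest
             var tokenizer = keras.preprocessing.text.Tokenizer();
             tokenizer.fit_on_texts(texts);
 
-            Assert.AreEqual(27, tokenizer.word_index.Count);
+            Assert.AreEqual(23, tokenizer.word_index.Count);
 
             Assert.AreEqual(7, tokenizer.word_index["worst"]);
-            Assert.AreEqual(12, tokenizer.word_index["number"]);
-            Assert.AreEqual(16, tokenizer.word_index["were"]);
+            Assert.AreEqual(12, tokenizer.word_index["dawn"]);
+            Assert.AreEqual(16, tokenizer.word_index["follow"]);
         }
 
         [TestMethod]
@@ -56,11 +56,11 @@ namespace TensorFlowNET.Keras.UnitTest
            // Use the list version, where the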
 tokenization has already been done.
             tokenizer.fit_on_texts(tokenized_texts);
 
-            Assert.AreEqual(27, tokenizer.word_index.Count);
+            Assert.AreEqual(23, tokenizer.word_index.Count);
 
             Assert.AreEqual(7, tokenizer.word_index["worst"]);
-            Assert.AreEqual(12, tokenizer.word_index["number"]);
-            Assert.AreEqual(16, tokenizer.word_index["were"]);
+            Assert.AreEqual(12, tokenizer.word_index["dawn"]);
+            Assert.AreEqual(16, tokenizer.word_index["follow"]);
         }
 
         [TestMethod]
@@ -69,12 +69,12 @@ namespace TensorFlowNET.Keras.UnitTest
             var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
             tokenizer.fit_on_texts(texts);
 
-            Assert.AreEqual(28, tokenizer.word_index.Count);
+            Assert.AreEqual(24, tokenizer.word_index.Count);
             Assert.AreEqual(1, tokenizer.word_index[OOV]);
 
             Assert.AreEqual(8, tokenizer.word_index["worst"]);
-            Assert.AreEqual(13, tokenizer.word_index["number"]);
-            Assert.AreEqual(17, tokenizer.word_index["were"]);
+            Assert.AreEqual(13, tokenizer.word_index["dawn"]);
+            Assert.AreEqual(17, tokenizer.word_index["follow"]);
         }
 
         [TestMethod]
@@ -84,12 +84,12 @@ namespace TensorFlowNET.Keras.UnitTest
             // Use the list version, where the tokenization has already been done.
             tokenizer.fit_on_texts(tokenized_texts);
 
-            Assert.AreEqual(28, tokenizer.word_index.Count);
+            Assert.AreEqual(24, tokenizer.word_index.Count);
             Assert.AreEqual(1, tokenizer.word_index[OOV]);
 
             Assert.AreEqual(8, tokenizer.word_index["worst"]);
-            Assert.AreEqual(13, tokenizer.word_index["number"]);
-            Assert.AreEqual(17, tokenizer.word_index["were"]);
+            Assert.AreEqual(13, tokenizer.word_index["dawn"]);
+            Assert.AreEqual(17, tokenizer.word_index["follow"]);
         }
 
         [TestMethod]
@@ -102,7 +102,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(4, sequences.Count);
 
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
         }
 
         [TestMethod]
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(4, sequences.Count);
 
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
         }
 
         [TestMethod]
@@ -200,7 +200,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(4, sequences.Count);
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
 
             for (var i = 0; i < sequences.Count; i++)
                 for (var j = 0; j < sequences[i].Length; j++)
@@ -217,7 +217,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(4, sequences.Count);
             Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
-            Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+            Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
 
             var oov_count = 0;
             for (var i = 0; i < sequences.Count; i++)
@@ -225,7 +225,7 @@ namespace TensorFlowNET.Keras.UnitTest
                     if (tokenizer.word_index[OOV] == sequences[i][j])
                         oov_count += 1;
 
-            Assert.AreEqual(9, oov_count);
+            Assert.AreEqual(5, oov_count);
         }
 
         [TestMethod]
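The updated expected counts can be checked by hand. A throwaway sketch (not part of the test suite, `System.Linq` assumed) that reproduces them from the `processed_texts` fixture above:

```csharp
// The new corpus has 23 distinct lowercased tokens, hence word_index.Count == 23.
// Prepending an OOV token makes it 24 and shifts every later index up by one,
// which is why "worst" moves from 7 to 8 in the OOV tests.
var vocabSize = processed_texts
    .SelectMany(t => t.Split(' '))
    .Distinct()
    .Count();   // 23
```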
@@ -238,15 +238,15 @@ namespace TensorFlowNET.Keras.UnitTest
             var padded = keras.preprocessing.sequence.pad_sequences(sequences);
 
             Assert.AreEqual(4, padded.shape[0]);
-            Assert.AreEqual(22, padded.shape[1]);
+            Assert.AreEqual(20, padded.shape[1]);
 
             var firstRow = padded[0];
             var secondRow = padded[1];
 
-            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 19].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
             for (var i = 0; i < 8; i++)
                 Assert.AreEqual(0, padded[0, i].GetInt32());
-            Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
             for (var i = 0; i < 20; i++)
                 Assert.AreNotEqual(0, padded[1, i].GetInt32());
         }
@@ -269,7 +269,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
             for (var i = 0; i < 3; i++)
                 Assert.AreEqual(0, padded[0, i].GetInt32());
-            Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 3].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
             for (var i = 0; i < 15; i++)
                 Assert.AreNotEqual(0, padded[1, i].GetInt32());
         }
@@ -292,7 +292,7 @@ namespace TensorFlowNET.Keras.UnitTest
             Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
             for (var i = 12; i < 15; i++)
                 Assert.AreEqual(0, padded[0, i].GetInt32());
-            Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+            Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
             for (var i = 0; i < 15; i++)
                 Assert.AreNotEqual(0, padded[1, i].GetInt32());
         }
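The revised padding asserts all follow from default "pre" padding against the shorter corpus. A self-contained sketch of that behaviour (an illustration of the expected semantics, not the library's implementation):

```csharp
using System;
using System.Linq;

static class PrePadSketch
{
    // Left-pads every sequence with zeros up to the longest length, mirroring
    // what keras.preprocessing.sequence.pad_sequences does by default.
    public static int[][] PadPre(int[][] seqs)
    {
        var maxlen = seqs.Max(s => s.Length);                     // 20 for the new corpus (was 22)
        return seqs.Select(s =>
        {
            var row = new int[maxlen];                            // zero-filled by default
            Array.Copy(s, 0, row, maxlen - s.Length, s.Length);   // right-align the tokens
            return row;
        }).ToArray();
    }
}
```

A 12-token row in a width-20 matrix gains 8 leading zeros, moving "worst" from column 9 to 17 (hence the `padded[0, 19]` to `padded[0, 17]` change); with `maxlen: 15` the 20-token row is pre-truncated by 5, moving "previous" from column 10 to 5.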