diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
index f4f9de5f..98769a21 100644
--- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs
+++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
@@ -5,8 +5,6 @@ using System.Text;
using Tensorflow.Keras.Utils;
using NumSharp;
using System.Linq;
-using NumSharp.Utilities;
-using Tensorflow.Queues;
namespace Tensorflow.Keras.Datasets
{
@@ -17,10 +15,8 @@ namespace Tensorflow.Keras.Datasets
/// </summary>
public class Imdb
{
- //string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
- string origin_folder = "http://ai.stanford.edu/~amaas/data/sentiment/";
- //string file_name = "imdb.npz";
- string file_name = "aclImdb_v1.tar.gz";
+ string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
+ string file_name = "imdb.npz";
string dest_folder = "imdb";
/// <summary>
@@ -41,66 +37,38 @@ namespace Tensorflow.Keras.Datasets
int maxlen = -1,
int seed = 113,
int start_char = 1,
int oov_char = 2,
int index_from = 3)
{
var dst = Download();
- var vocab = BuildVocabulary(Path.Combine(dst, "imdb.vocab"), start_char, oov_char, index_from);
-
- var (x_train,y_train) = GetDataSet(Path.Combine(dst, "train"));
- var (x_test, y_test) = GetDataSet(Path.Combine(dst, "test"));
-
- return new DatasetPass
+ var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
+ var x_train_string = new string[lines.Length];
+ var y_train = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
+ for (int i = 0; i < lines.Length; i++)
{
- Train = (x_train, y_train),
- Test = (x_test, y_test)
- };
- }
+ y_train[i] = long.Parse(lines[i].Substring(0, 1));
+ x_train_string[i] = lines[i].Substring(2);
+ }
- private static Dictionary<string, int> BuildVocabulary(string path,
- int start_char,
- int oov_char,
- int index_from)
- {
- var words = File.ReadAllLines(path);
- var result = new Dictionary<string, int>();
- var idx = index_from;
+ var x_train = np.array(x_train_string);
- foreach (var word in words)
+ lines = File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
+ var x_test_string = new string[lines.Length];
+ var y_test = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
+ for (int i = 0; i < lines.Length; i++)
{
- result[word] = idx;
- idx += 1;
+ y_test[i] = long.Parse(lines[i].Substring(0, 1));
+ x_test_string[i] = lines[i].Substring(2);
}
- return result;
- }
-
- private static (NDArray, NDArray) GetDataSet(string path)
- {
- var posFiles = Directory.GetFiles(Path.Combine(path, "pos")).Slice(0,10);
- var negFiles = Directory.GetFiles(Path.Combine(path, "neg")).Slice(0,10);
-
- var x_string = new string[posFiles.Length + negFiles.Length];
- var y = new int[posFiles.Length + negFiles.Length];
- var trg = 0;
- var longest = 0;
+ var x_test = np.array(x_test_string);
- for (int i = 0; i < posFiles.Length; i++, trg++)
- {
- y[trg] = 1;
- x_string[trg] = File.ReadAllText(posFiles[i]);
- longest = Math.Max(longest, x_string[trg].Length);
- }
- for (int i = 0; i < posFiles.Length; i++, trg++)
+ return new DatasetPass
{
- y[trg] = 0;
- x_string[trg] = File.ReadAllText(negFiles[i]);
- longest = Math.Max(longest, x_string[trg].Length);
- }
- var x = np.array(x_string);
-
- return (x, y);
+ Train = (x_train, y_train),
+ Test = (x_test, y_test)
+ };
}
(NDArray, NDArray) LoadX(byte[] bytes)
@@ -122,9 +90,8 @@ namespace Tensorflow.Keras.Datasets
Web.Download(origin_folder + file_name, dst, file_name);
- Tensorflow.Keras.Utils.Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);
-
- return Path.Combine(dst, "aclImdb");
+ return dst;
}
}
}
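Note on the rewritten loader: each line of `imdb_train.txt` / `imdb_test.txt` is assumed to begin with a one-character label followed by a one-character separator and the review text, which is what `Substring(0, 1)` and `Substring(2)` imply. A minimal consumption sketch follows; the file layout, the remaining `load_data` defaults, and the sample reviews are assumptions, not confirmed by this diff:

```csharp
// Assumed line format, inferred from Substring(0, 1) / Substring(2):
//   <label><sep><review text>, e.g. "1 a wonderful little film"
var imdb = new Tensorflow.Keras.Datasets.Imdb();
var dataset = imdb.load_data();          // Download() fetches origin_folder + file_name first
var (x_train, y_train) = dataset.Train;  // x_train: NDArray of review strings
var (x_test, y_test) = dataset.Test;     // y_*: int64 labels parsed from column 0
```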
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
index 8bf7cf38..aaca1cb9 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Text
/// <summary>
/// Updates internal vocabulary based on a list of texts.
/// </summary>
- /// <param name="texts">A list of strings, each containing one or more tokens.</param>
+ /// <param name="texts"></param>
/// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
public void fit_on_texts(IEnumerable<string> texts)
{
@@ -90,7 +90,7 @@ namespace Tensorflow.Keras.Text
}
var wcounts = word_counts.AsEnumerable().ToList();
- wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order.
+ wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
sorted_voc.AddRange(word_counts.Select(kv => kv.Key));
@@ -120,12 +120,7 @@ namespace Tensorflow.Keras.Text
}
}
- /// <summary>
- /// Updates internal vocabulary based on a list of texts.
- /// </summary>
- /// <param name="texts">A list of list of strings, each containing one token.</param>
- /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
- public void fit_on_texts(IEnumerable<IList<string>> texts)
+ public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
{
foreach (var seq in texts)
{
@@ -202,7 +197,7 @@ namespace Tensorflow.Keras.Text
/// <param name="texts"></param>
/// <returns></returns>
/// <remarks>Only top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
- public IList<int[]> texts_to_sequences(IEnumerable<IList<string>> texts)
+ public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
{
return texts_to_sequences_generator(texts).ToArray();
}
@@ -229,13 +224,6 @@ namespace Tensorflow.Keras.Text
});
}
- public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IList<string>> texts)
- {
- int oov_index = -1;
- var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
- return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
- }
-
private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
{
var vect = new List<int>();
@@ -256,7 +244,7 @@ namespace Tensorflow.Keras.Text
vect.Add(i);
}
}
else if (oov_index != -1)
{
vect.Add(oov_index);
}
@@ -265,6 +253,13 @@ namespace Tensorflow.Keras.Text
return vect;
}
+ public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
+ {
+ int oov_index = -1;
+ var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+ return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
+ }
+
/// <summary>
/// Transforms each sequence into a list of text.
/// </summary>
@@ -276,7 +271,7 @@ namespace Tensorflow.Keras.Text
return sequences_to_texts_generator(sequences).ToArray();
}
- public IEnumerable<string> sequences_to_texts_generator(IEnumerable<IList<int>> sequences)
+ public IEnumerable<string> sequences_to_texts_generator(IEnumerable<int[]> sequences)
{
int oov_index = -1;
var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
@@ -285,7 +280,7 @@ namespace Tensorflow.Keras.Text
{
var bldr = new StringBuilder();
- for (var i = 0; i < seq.Count; i++)
+ for (var i = 0; i < seq.Length; i++)
{
if (i > 0) bldr.Append(' ');
@@ -319,7 +314,7 @@ namespace Tensorflow.Keras.Text
/// </summary>
/// <param name="sequences"></param>
/// <returns></returns>
- public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences)
+ public NDArray sequences_to_matrix(IEnumerable<int[]> sequences)
{
throw new NotImplementedException("sequences_to_matrix");
}
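For orientation, the overload pairs above accept either raw strings (tokenized internally) or already-tokenized word lists. A minimal usage sketch; the `keras.preprocessing.text.Tokenizer` factory, the `oov_token` behaviour, and the OOV index of 1 are as exercised by the unit tests below, while the sample strings and the string-accepting `texts_to_sequences` overload are assumptions:

```csharp
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: "<OOV>");

// Overload 1: raw strings, tokenized internally.
tokenizer.fit_on_texts(new[] { "it was the best of times", "it was the worst of times" });

// Overload 2: pre-tokenized input (IEnumerable<IEnumerable<string>>).
tokenizer.fit_on_texts(new[] { new[] { "it", "was", "the", "best" } });

// Known words map to their frequency rank; unseen words map to the OOV index.
var sequences = tokenizer.texts_to_sequences(new[] { "the best of eras" });
int oov = tokenizer.word_index["<OOV>"];  // 1 when oov_token is set
```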
diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
index ad4f91bf..ebde87fa 100644
--- a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
@@ -15,23 +15,23 @@ namespace TensorFlowNET.Keras.UnitTest
{
private readonly string[] texts = new string[] {
"It was the best of times, it was the worst of times.",
- "Mr and Mrs Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.",
+ "this is a new dawn, an era to follow the previous era. It can not be said to start anew.",
"It was the best of times, it was the worst of times.",
- "Mr and Mrs Dursley of number four, Privet Drive.",
+ "this is a new dawn, an era to follow the previous era.",
};
private readonly string[][] tokenized_texts = new string[][] {
new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
- new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive","were","proud","to","say","that","they","were","perfectly","normal","thank","you","very","much"},
+ new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era","It","can","not","be","said","to","start","anew" },
new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
- new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive"},
+ new string[] {"this","is","a","new","dawn","an","era","to","follow","the","previous","era" },
};
private readonly string[] processed_texts = new string[] {
"it was the best of times it was the worst of times",
- "mr and mrs dursley of number four privet drive were proud to say that they were perfectly normal thank you very much",
+ "this is a new dawn an era to follow the previous era it can not be said to start anew",
"it was the best of times it was the worst of times",
- "mr and mrs dursley of number four privet drive",
+ "this is a new dawn an era to follow the previous era",
};
private const string OOV = "<OOV>";
@@ -42,11 +42,11 @@ namespace TensorFlowNET.Keras.UnitTest
var tokenizer = keras.preprocessing.text.Tokenizer();
tokenizer.fit_on_texts(texts);
- Assert.AreEqual(27, tokenizer.word_index.Count);
+ Assert.AreEqual(23, tokenizer.word_index.Count);
Assert.AreEqual(7, tokenizer.word_index["worst"]);
- Assert.AreEqual(12, tokenizer.word_index["number"]);
- Assert.AreEqual(16, tokenizer.word_index["were"]);
+ Assert.AreEqual(12, tokenizer.word_index["dawn"]);
+ Assert.AreEqual(16, tokenizer.word_index["follow"]);
}
[TestMethod]
@@ -56,11 +56,11 @@ namespace TensorFlowNET.Keras.UnitTest
// Use the list version, where the tokenization has already been done.
tokenizer.fit_on_texts(tokenized_texts);
- Assert.AreEqual(27, tokenizer.word_index.Count);
+ Assert.AreEqual(23, tokenizer.word_index.Count);
Assert.AreEqual(7, tokenizer.word_index["worst"]);
- Assert.AreEqual(12, tokenizer.word_index["number"]);
- Assert.AreEqual(16, tokenizer.word_index["were"]);
+ Assert.AreEqual(12, tokenizer.word_index["dawn"]);
+ Assert.AreEqual(16, tokenizer.word_index["follow"]);
}
[TestMethod]
@@ -69,12 +69,12 @@ namespace TensorFlowNET.Keras.UnitTest
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
tokenizer.fit_on_texts(texts);
- Assert.AreEqual(28, tokenizer.word_index.Count);
+ Assert.AreEqual(24, tokenizer.word_index.Count);
Assert.AreEqual(1, tokenizer.word_index[OOV]);
Assert.AreEqual(8, tokenizer.word_index["worst"]);
- Assert.AreEqual(13, tokenizer.word_index["number"]);
- Assert.AreEqual(17, tokenizer.word_index["were"]);
+ Assert.AreEqual(13, tokenizer.word_index["dawn"]);
+ Assert.AreEqual(17, tokenizer.word_index["follow"]);
}
[TestMethod]
@@ -84,12 +84,12 @@ namespace TensorFlowNET.Keras.UnitTest
// Use the list version, where the tokenization has already been done.
tokenizer.fit_on_texts(tokenized_texts);
- Assert.AreEqual(28, tokenizer.word_index.Count);
+ Assert.AreEqual(24, tokenizer.word_index.Count);
Assert.AreEqual(1, tokenizer.word_index[OOV]);
Assert.AreEqual(8, tokenizer.word_index["worst"]);
- Assert.AreEqual(13, tokenizer.word_index["number"]);
- Assert.AreEqual(17, tokenizer.word_index["were"]);
+ Assert.AreEqual(13, tokenizer.word_index["dawn"]);
+ Assert.AreEqual(17, tokenizer.word_index["follow"]);
}
[TestMethod]
@@ -102,7 +102,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(4, sequences.Count);
Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
- Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
}
[TestMethod]
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(4, sequences.Count);
Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
- Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
}
[TestMethod]
@@ -200,7 +200,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(4, sequences.Count);
Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
- Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
for (var i = 0; i < sequences.Count; i++)
for (var j = 0; j < sequences[i].Length; j++)
@@ -217,7 +217,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(4, sequences.Count);
Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
- Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+ Assert.AreEqual(tokenizer.word_index["previous"], sequences[1][10]);
var oov_count = 0;
for (var i = 0; i < sequences.Count; i++)
@@ -225,7 +225,7 @@ namespace TensorFlowNET.Keras.UnitTest
if (tokenizer.word_index[OOV] == sequences[i][j])
oov_count += 1;
- Assert.AreEqual(9, oov_count);
+ Assert.AreEqual(5, oov_count);
}
[TestMethod]
@@ -238,15 +238,15 @@ namespace TensorFlowNET.Keras.UnitTest
var padded = keras.preprocessing.sequence.pad_sequences(sequences);
Assert.AreEqual(4, padded.shape[0]);
- Assert.AreEqual(22, padded.shape[1]);
+ Assert.AreEqual(20, padded.shape[1]);
var firstRow = padded[0];
var secondRow = padded[1];
- Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 19].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
for (var i = 0; i < 8; i++)
Assert.AreEqual(0, padded[0, i].GetInt32());
- Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
for (var i = 0; i < 20; i++)
Assert.AreNotEqual(0, padded[1, i].GetInt32());
}
@@ -269,7 +269,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
for (var i = 0; i < 3; i++)
Assert.AreEqual(0, padded[0, i].GetInt32());
- Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 3].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
for (var i = 0; i < 15; i++)
Assert.AreNotEqual(0, padded[1, i].GetInt32());
}
@@ -292,7 +292,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
for (var i = 12; i < 15; i++)
Assert.AreEqual(0, padded[0, i].GetInt32());
- Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
for (var i = 0; i < 15; i++)
Assert.AreNotEqual(0, padded[1, i].GetInt32());
}
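The rebased column indices in these assertions follow from pad_sequences' default pre-padding and pre-truncation. A self-contained sketch of that arithmetic, in plain C# and independent of the actual pad_sequences implementation:

```csharp
using System;
using System.Linq;

class PaddingArithmetic
{
    // Pre-pad with zeros (and pre-truncate) to maxlen, mirroring the
    // pad_sequences defaults the tests above rely on.
    static int[] PrePad(int[] seq, int maxlen) =>
        Enumerable.Repeat(0, Math.Max(0, maxlen - seq.Length))
                  .Concat(seq.Skip(Math.Max(0, seq.Length - maxlen)))
                  .ToArray();

    static void Main()
    {
        var seq1 = Enumerable.Range(1, 12).ToArray();  // 12 tokens, like sentence 1
        var seq2 = Enumerable.Range(1, 20).ToArray();  // 20 tokens, like sentence 2

        // The longest row has 20 tokens, so row 1 gets 8 leading zeros:
        // token index 9 ("worst") lands at column 8 + 9 = 17.
        Console.WriteLine(PrePad(seq1, 20)[17] == seq1[9]);  // True

        // With maxlen: 15, the 20-token row loses its first 5 tokens:
        // token index 10 ("previous") moves to column 10 - 5 = 5.
        Console.WriteLine(PrePad(seq2, 15)[5] == seq2[10]);  // True
    }
}
```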