diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs index 7792beae..10340063 100644 --- a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs +++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs @@ -269,6 +269,24 @@ namespace TensorFlowNET.Keras.UnitTest Assert.AreNotEqual(0, padded[1, i].GetInt32()); } + [TestMethod] + public void PadSequencesPrePaddingTrunc_Larger() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 45); + + Assert.AreEqual(4, padded.shape[0]); + Assert.AreEqual(45, padded.shape[1]); + + Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 42].GetInt32()); + for (var i = 0; i < 33; i++) + Assert.AreEqual(0, padded[0, i].GetInt32()); + Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 33].GetInt32()); + } + [TestMethod] public void PadSequencesPostPaddingTrunc() { @@ -289,6 +307,24 @@ namespace TensorFlowNET.Keras.UnitTest Assert.AreNotEqual(0, padded[1, i].GetInt32()); } + [TestMethod] + public void PadSequencesPostPaddingTrunc_Larger() + { + var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV); + tokenizer.fit_on_texts(texts); + + var sequences = tokenizer.texts_to_sequences(texts); + var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 45, padding: "post", truncating: "post"); + + Assert.AreEqual(4, padded.shape[0]); + Assert.AreEqual(45, padded.shape[1]); + + Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32()); + for (var i = 32; i < 45; i++) + Assert.AreEqual(0, padded[0, i].GetInt32()); + Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32()); + } + [TestMethod] public void TextToMatrixBinary() {