diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs
new file mode 100644
index 00000000..9742203d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs
@@ -0,0 +1,34 @@
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class Pooling1DArgs : LayerArgs
+ {
+ /// <summary>
+ /// The pooling function to apply, e.g. `tf.nn.max_pool2d`.
+ /// </summary>
+ public IPoolFunction PoolFunction { get; set; }
+
+ /// <summary>
+ /// An integer specifying the size of the pooling window.
+ /// </summary>
+ public int PoolSize { get; set; }
+
+ /// <summary>
+ /// An integer specifying the strides of the pooling operation.
+ /// </summary>
+ public int Strides {
+ get { return _strides.HasValue ? _strides.Value : PoolSize; }
+ set { _strides = value; }
+ }
+ private int? _strides = null;
+
+ /// <summary>
+ /// The padding method, either 'valid' or 'same'.
+ /// </summary>
+ public string Padding { get; set; } = "valid";
+
+ /// <summary>
+ /// One of `channels_last` (default) or `channels_first`.
+ /// </summary>
+ public string DataFormat { get; set; }
+ }
+}
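Note: the `Strides` property above deliberately falls back to `PoolSize` until a stride is assigned, mirroring Keras' `strides=None` default. A minimal sketch of that fallback (illustrative only, not part of the diff):

    var args = new Pooling1DArgs { PoolSize = 3 };
    // No stride set yet, so Strides reports the pool size.
    Console.WriteLine(args.Strides);   // 3
    args.Strides = 1;
    Console.WriteLine(args.Strides);   // 1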
diff --git a/src/TensorFlowNET.Keras/Datasets/MNIST.cs b/src/TensorFlowNET.Keras/Datasets/MNIST.cs
index 9cdc56b5..8fa61b41 100644
--- a/src/TensorFlowNET.Keras/Datasets/MNIST.cs
+++ b/src/TensorFlowNET.Keras/Datasets/MNIST.cs
@@ -45,8 +45,8 @@ namespace Tensorflow.Keras.Datasets
(NDArray, NDArray) LoadX(byte[] bytes)
{
- var y = np.Load_Npz(bytes);
- return (y["x_train.npy"], y["x_test.npy"]);
+ var x = np.Load_Npz(bytes);
+ return (x["x_train.npy"], x["x_test.npy"]);
}
(NDArray, NDArray) LoadY(byte[] bytes)
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index e735f81e..9b889635 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -325,6 +325,16 @@ namespace Tensorflow.Keras.Layers
return input_layer.InboundNodes[0].Outputs;
}
+ public MaxPooling1D MaxPooling1D(int? pool_size = null,
+ int? strides = null,
+ string padding = "valid")
+ => new MaxPooling1D(new Pooling1DArgs
+ {
+ PoolSize = pool_size ?? 2,
+ Strides = strides ?? (pool_size ?? 2),
+ Padding = padding
+ });
+
public MaxPooling2D MaxPooling2D(TensorShape pool_size = null,
TensorShape strides = null,
string padding = "valid")
@@ -448,6 +458,20 @@ namespace Tensorflow.Keras.Layers
public GlobalAveragePooling2D GlobalAveragePooling2D()
=> new GlobalAveragePooling2D(new Pooling2DArgs { });
+ public GlobalAveragePooling1D GlobalAveragePooling1D(string data_format = "channels_last")
+ => new GlobalAveragePooling1D(new Pooling1DArgs { DataFormat = data_format });
+
+ public GlobalAveragePooling2D GlobalAveragePooling2D(string data_format = "channels_last")
+ => new GlobalAveragePooling2D(new Pooling2DArgs { DataFormat = data_format });
+
+ public GlobalMaxPooling1D GlobalMaxPooling1D(string data_format = "channels_last")
+ => new GlobalMaxPooling1D(new Pooling1DArgs { DataFormat = data_format });
+
+ public GlobalMaxPooling2D GlobalMaxPooling2D(string data_format = "channels_last")
+ => new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format });
+
Activation GetActivationByName(string name)
=> name switch
{
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
new file mode 100644
index 00000000..d2442bec
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class GlobalAveragePooling1D : GlobalPooling1D
+ {
+ public GlobalAveragePooling1D(Pooling1DArgs args)
+ : base(args)
+ {
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ if (data_format == "channels_last")
+ return math_ops.reduce_mean(inputs, new int[] { 1 }, false);
+ else
+ return math_ops.reduce_mean(inputs, new int[] { 2 }, false);
+ }
+ }
+}
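In the channels_last branch the mean is taken over axis 1 (the steps axis), so an input of shape (batch, steps, features) collapses to (batch, features); channels_first reduces axis 2 instead. A hedged shape sketch (assuming `using static Tensorflow.KerasApi;` and NumSharp's `np`):

    // channels_last: (batch, steps, features) -> (batch, features)
    var y = keras.layers.GlobalAveragePooling1D().Apply(np.zeros(4, 3, 5));   // shape (4, 5)
    // channels_first: (batch, features, steps) -> (batch, features)
    var y2 = keras.layers.GlobalAveragePooling1D(data_format: "channels_first")
        .Apply(np.zeros(4, 3, 5));                                            // shape (4, 3)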
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
new file mode 100644
index 00000000..c0d0d831
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class GlobalMaxPooling1D : GlobalPooling1D
+ {
+ public GlobalMaxPooling1D(Pooling1DArgs args)
+ : base(args)
+ {
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ if (data_format == "channels_last")
+ return math_ops.reduce_max(inputs, new int[] { 1 }, false);
+ else
+ return math_ops.reduce_max(inputs, new int[] { 2 }, false);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
new file mode 100644
index 00000000..6ab6b501
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class GlobalMaxPooling2D : GlobalPooling2D
+ {
+ public GlobalMaxPooling2D(Pooling2DArgs args)
+ : base(args)
+ {
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ if (data_format == "channels_last")
+ return math_ops.reduce_max(inputs, new int[] { 1, 2 }, false);
+ else
+ return math_ops.reduce_max(inputs, new int[] { 2, 3 }, false);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling1D.cs
new file mode 100644
index 00000000..04fadeeb
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/GlobalPooling1D.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Layers
+{
+ public abstract class GlobalPooling1D : Layer
+ {
+ Pooling1DArgs args;
+ protected string data_format => args.DataFormat;
+ protected InputSpec input_spec;
+
+ public GlobalPooling1D(Pooling1DArgs args) : base(args)
+ {
+ this.args = args;
+ args.DataFormat = conv_utils.normalize_data_format(data_format);
+ input_spec = new InputSpec(ndim: 3);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling1D.cs
new file mode 100644
index 00000000..c1deb9bf
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/MaxPooling1D.cs
@@ -0,0 +1,14 @@
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Operations;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class MaxPooling1D : Pooling1D
+ {
+ public MaxPooling1D(Pooling1DArgs args)
+ : base(args)
+ {
+ args.PoolFunction = new MaxPoolFunction();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
new file mode 100644
index 00000000..80b36c86
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
@@ -0,0 +1,62 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class Pooling1D : Layer
+ {
+ Pooling1DArgs args;
+ InputSpec input_spec;
+
+ public Pooling1D(Pooling1DArgs args)
+ : base(args)
+ {
+ this.args = args;
+ args.Padding = conv_utils.normalize_padding(args.Padding);
+ args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
+ input_spec = new InputSpec(ndim: 3);
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ int[] pool_shape;
+ int[] strides;
+ if (args.DataFormat == "channels_last")
+ {
+ pool_shape = new int[] { 1, args.PoolSize, 1 };
+ strides = new int[] { 1, args.Strides, 1 };
+ }
+ else
+ {
+ pool_shape = new int[] { 1, 1, args.PoolSize };
+ strides = new int[] { 1, 1, args.Strides };
+ }
+
+ var outputs = args.PoolFunction.Apply(
+ inputs,
+ ksize: pool_shape,
+ strides: strides,
+ padding: args.Padding.ToUpper(),
+ data_format: conv_utils.convert_data_format(args.DataFormat, 3));
+
+ return outputs;
+ }
+ }
+}
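The underlying TF pooling kernels take one window/stride entry per input dimension, so the layer embeds the scalar pool_size and strides into 3-element arrays around the steps axis. For pool_size: 2, strides: 1 the values handed to the pool function would be (illustrative):

    // channels_last  -> data_format "NWC": the window moves along axis 1 (steps)
    //   pool_shape = { 1, 2, 1 },  strides = { 1, 1, 1 }
    // channels_first -> data_format "NCW": the steps axis sits at position 2
    //   pool_shape = { 1, 1, 2 },  strides = { 1, 1, 1 }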
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
index 34aeb211..994a36d6 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
@@ -10,6 +10,10 @@ namespace Tensorflow.Keras
public Sequence sequence => new Sequence();
public DatasetUtils dataset_utils => new DatasetUtils();
+ public TextApi text => _text;
+
+ private static readonly TextApi _text = new TextApi();
+
public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null,
string split = "whitespace",
int max_tokens = -1,
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
new file mode 100644
index 00000000..29cbec8e
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Preprocessings/Tokenizer.cs
@@ -0,0 +1,440 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow.Keras.Text
+{
+ /// <summary>
+ /// Text tokenization API.
+ /// This class allows to vectorize a text corpus, by turning each text into either a sequence of integers
+ /// (each integer being the index of a token in a dictionary) or into a vector where the coefficient for
+ /// each token could be binary, based on word count, based on tf-idf...
+ /// </summary>
+ /// <remarks>
+ /// This code is a fairly straight port of the Python code for Keras text preprocessing found at:
+ /// https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/text.py
+ /// </remarks>
+ public class Tokenizer
+ {
+ private readonly int num_words;
+ private readonly string filters;
+ private readonly bool lower;
+ private readonly char split;
+ private readonly bool char_level;
+ private readonly string oov_token;
+ private readonly Func<string, IEnumerable<string>> analyzer;
+
+ private int document_count = 0;
+
+ private Dictionary<string, int> word_docs = new Dictionary<string, int>();
+ private Dictionary<string, int> word_counts = new Dictionary<string, int>();
+
+ public Dictionary<string, int> word_index = null;
+ public Dictionary<int, string> index_word = null;
+
+ private Dictionary<int, int> index_docs = null;
+
+ public Tokenizer(
+ int num_words = -1,
+ string filters = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
+ bool lower = true,
+ char split = ' ',
+ bool char_level = false,
+ string oov_token = null,
+ Func<string, IEnumerable<string>> analyzer = null)
+ {
+ this.num_words = num_words;
+ this.filters = filters;
+ this.lower = lower;
+ this.split = split;
+ this.char_level = char_level;
+ this.oov_token = oov_token;
+ this.analyzer = analyzer != null ? analyzer : (text) => TextApi.text_to_word_sequence(text, filters, lower, split);
+ }
+
+ /// <summary>
+ /// Updates internal vocabulary based on a list of texts.
+ /// </summary>
+ /// <param name="texts">A list of strings, each containing one or more tokens.</param>
+ /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
+ public void fit_on_texts(IEnumerable<string> texts)
+ {
+ foreach (var text in texts)
+ {
+ IEnumerable<string> seq = null;
+
+ document_count += 1;
+ if (char_level)
+ {
+ throw new NotImplementedException("char_level == true");
+ }
+ else
+ {
+ seq = analyzer(lower ? text.ToLower() : text);
+ }
+
+ foreach (var w in seq)
+ {
+ var count = 0;
+ word_counts.TryGetValue(w, out count);
+ word_counts[w] = count + 1;
+ }
+
+ foreach (var w in new HashSet<string>(seq))
+ {
+ var count = 0;
+ word_docs.TryGetValue(w, out count);
+ word_docs[w] = count + 1;
+ }
+ }
+
+ var wcounts = word_counts.AsEnumerable().ToList();
+ wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value)); // Note: '-' gives us descending order.
+
+ var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
+ sorted_voc.AddRange(wcounts.Select(kv => kv.Key)); // use the sorted list, not the unordered word_counts
+
+ if (num_words > -1)
+ {
+ sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList();
+ }
+
+ word_index = new Dictionary<string, int>(sorted_voc.Count);
+ index_word = new Dictionary<int, string>(sorted_voc.Count);
+ index_docs = new Dictionary<int, int>(word_docs.Count);
+
+ for (int i = 0; i < sorted_voc.Count; i++)
+ {
+ word_index.Add(sorted_voc[i], i + 1);
+ index_word.Add(i + 1, sorted_voc[i]);
+ }
+
+ foreach (var kv in word_docs)
+ {
+ var idx = -1;
+ if (word_index.TryGetValue(kv.Key, out idx))
+ {
+ index_docs.Add(idx, kv.Value);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Updates internal vocabulary based on a list of tokenized texts.
+ /// </summary>
+ /// <param name="texts">A list of lists of strings, each inner list holding the tokens of one text.</param>
+ /// <remarks>Required before using texts_to_sequences or texts_to_matrix.</remarks>
+ public void fit_on_texts(IEnumerable<IEnumerable<string>> texts)
+ {
+ foreach (var seq in texts)
+ {
+ foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
+ {
+ var count = 0;
+ word_counts.TryGetValue(w, out count);
+ word_counts[w] = count + 1;
+ }
+
+ foreach (var w in new HashSet<string>(seq.Select(s => lower ? s.ToLower() : s))) // count each word once per document
+ {
+ var count = 0;
+ word_docs.TryGetValue(w, out count);
+ word_docs[w] = count + 1;
+ }
+ }
+
+ var wcounts = word_counts.AsEnumerable().ToList();
+ wcounts.Sort((kv1, kv2) => -kv1.Value.CompareTo(kv2.Value));
+
+ var sorted_voc = (oov_token == null) ? new List<string>() : new List<string>() { oov_token };
+ sorted_voc.AddRange(wcounts.Select(kv => kv.Key)); // use the sorted list, not the unordered word_counts
+
+ if (num_words > -1)
+ {
+ sorted_voc = sorted_voc.Take((oov_token == null) ? num_words : num_words + 1).ToList();
+ }
+
+ word_index = new Dictionary<string, int>(sorted_voc.Count);
+ index_word = new Dictionary<int, string>(sorted_voc.Count);
+ index_docs = new Dictionary<int, int>(word_docs.Count);
+
+ for (int i = 0; i < sorted_voc.Count; i++)
+ {
+ word_index.Add(sorted_voc[i], i + 1);
+ index_word.Add(i + 1, sorted_voc[i]);
+ }
+
+ foreach (var kv in word_docs)
+ {
+ var idx = -1;
+ if (word_index.TryGetValue(kv.Key, out idx))
+ {
+ index_docs.Add(idx, kv.Value);
+ }
+ }
+ }
+
+ /// <summary>
+ /// Updates internal vocabulary based on a list of sequences.
+ /// </summary>
+ /// <param name="sequences">A list of sequences of integer word indices.</param>
+ /// <remarks>Required before using sequences_to_matrix (if fit_on_texts was never called).</remarks>
+ public void fit_on_sequences(IEnumerable<int[]> sequences)
+ {
+ throw new NotImplementedException("fit_on_sequences");
+ }
+
+ /// <summary>
+ /// Transforms each string in texts to a sequence of integers.
+ /// </summary>
+ /// <param name="texts">A list of strings, each containing one or more tokens.</param>
+ /// <returns>A list of sequences of integer word indices.</returns>
+ /// <remarks>Only the top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
+ public IList<int[]> texts_to_sequences(IEnumerable<string> texts)
+ {
+ return texts_to_sequences_generator(texts).ToArray();
+ }
+
+ /// <summary>
+ /// Transforms each sequence of tokens in texts to a sequence of integers.
+ /// </summary>
+ /// <param name="texts">A list of lists of strings, each inner list holding the tokens of one text.</param>
+ /// <returns>A list of sequences of integer word indices.</returns>
+ /// <remarks>Only the top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
+ public IList<int[]> texts_to_sequences(IEnumerable<IEnumerable<string>> texts)
+ {
+ return texts_to_sequences_generator(texts).ToArray();
+ }
+
+ public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<string> texts)
+ {
+ int oov_index = -1;
+ var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+
+ return texts.Select(text =>
+ {
+ IEnumerable<string> seq = null;
+
+ if (char_level)
+ {
+ throw new NotImplementedException("char_level == true");
+ }
+ else
+ {
+ seq = analyzer(lower ? text.ToLower() : text);
+ }
+
+ return ConvertToSequence(oov_index, seq).ToArray();
+ });
+ }
+
+ public IEnumerable<int[]> texts_to_sequences_generator(IEnumerable<IEnumerable<string>> texts)
+ {
+ int oov_index = -1;
+ var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+ return texts.Select(seq => ConvertToSequence(oov_index, seq).ToArray());
+ }
+
+ private List<int> ConvertToSequence(int oov_index, IEnumerable<string> seq)
+ {
+ var vect = new List<int>();
+ foreach (var w in seq.Select(s => lower ? s.ToLower() : s))
+ {
+ var i = -1;
+ if (word_index.TryGetValue(w, out i))
+ {
+ if (num_words != -1 && i >= num_words)
+ {
+ if (oov_index != -1)
+ {
+ vect.Add(oov_index);
+ }
+ }
+ else
+ {
+ vect.Add(i);
+ }
+ }
+ else if (oov_index != -1)
+ {
+ vect.Add(oov_index);
+ }
+ }
+
+ return vect;
+ }
+
+ /// <summary>
+ /// Transforms each sequence into a list of text.
+ /// </summary>
+ /// <param name="sequences">A list of sequences of integer word indices.</param>
+ /// <returns>A list of texts (strings).</returns>
+ /// <remarks>Only the top num_words-1 most frequent words will be taken into account. Only words known by the tokenizer will be taken into account.</remarks>
+ public IList<string> sequences_to_texts(IEnumerable<IList<int>> sequences)
+ {
+ return sequences_to_texts_generator(sequences).ToArray();
+ }
+
+ public IEnumerable<string> sequences_to_texts_generator(IEnumerable<IList<int>> sequences)
+ {
+ int oov_index = -1;
+ var _ = (oov_token != null) && word_index.TryGetValue(oov_token, out oov_index);
+
+ return sequences.Select(seq =>
+ {
+
+ var bldr = new StringBuilder();
+ for (var i = 0; i < seq.Count; i++)
+ {
+ if (i > 0) bldr.Append(' ');
+
+ string word = null;
+ if (index_word.TryGetValue(seq[i], out word))
+ {
+ if (num_words != -1 && seq[i] >= num_words) // compare the word index, not the loop position
+ {
+ if (oov_index != -1)
+ {
+ bldr.Append(oov_token);
+ }
+ }
+ else
+ {
+ bldr.Append(word);
+ }
+ }
+ else if (oov_index != -1)
+ {
+ bldr.Append(oov_token);
+ }
+ }
+
+ return bldr.ToString();
+ });
+ }
+
+ /// <summary>
+ /// Convert a list of texts to a Numpy matrix.
+ /// </summary>
+ /// <param name="texts">A sequence of strings, each containing one or more tokens.</param>
+ /// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
+ /// <returns>An NDArray of shape (number of texts, num_words).</returns>
+ public NDArray texts_to_matrix(IEnumerable<string> texts, string mode = "binary")
+ {
+ return sequences_to_matrix(texts_to_sequences(texts), mode);
+ }
+
+ /// <summary>
+ /// Convert a list of tokenized texts to a Numpy matrix.
+ /// </summary>
+ /// <param name="texts">A sequence of lists of strings, each inner list holding the tokens of one text.</param>
+ /// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
+ /// <returns>An NDArray of shape (number of texts, num_words).</returns>
+ public NDArray texts_to_matrix(IEnumerable<IEnumerable<string>> texts, string mode = "binary")
+ {
+ return sequences_to_matrix(texts_to_sequences(texts), mode);
+ }
+
+ /// <summary>
+ /// Converts a list of sequences into a Numpy matrix.
+ /// </summary>
+ /// <param name="sequences">A sequence of lists of integers, encoding tokens.</param>
+ /// <param name="mode">One of "binary", "count", "tfidf", "freq".</param>
+ /// <returns>An NDArray of shape (number of sequences, num_words).</returns>
+ public NDArray sequences_to_matrix(IEnumerable<IList<int>> sequences, string mode = "binary")
+ {
+ if (!modes.Contains(mode)) throw new InvalidArgumentError($"Unknown vectorization mode: {mode}");
+ var word_count = 0;
+
+ if (num_words == -1)
+ {
+ if (word_index != null)
+ {
+ word_count = word_index.Count + 1;
+ }
+ else
+ {
+ throw new InvalidOperationException("Specifya dimension ('num_words' arugment), or fit on some text data first.");
+ }
+ }
+ else
+ {
+ word_count = num_words;
+ }
+
+ if (mode == "tfidf" && this.document_count == 0)
+ {
+ throw new InvalidOperationException("Fit the Tokenizer on some text data before using the 'tfidf' mode.");
+ }
+
+ var x = np.zeros(sequences.Count(), word_count);
+
+ for (int i = 0; i < sequences.Count(); i++)
+ {
+ var seq = sequences.ElementAt(i);
+ if (seq == null || seq.Count == 0)
+ continue;
+
+ var counts = new Dictionary<int, int>();
+
+ var seq_length = seq.Count;
+
+ foreach (var j in seq)
+ {
+ if (j >= word_count)
+ continue;
+ var count = 0;
+ counts.TryGetValue(j, out count);
+ counts[j] = count + 1;
+ }
+
+ if (mode == "count")
+ {
+ foreach (var kv in counts)
+ {
+ var j = kv.Key;
+ var c = kv.Value;
+ x[i, j] = c;
+ }
+ }
+ else if (mode == "freq")
+ {
+ foreach (var kv in counts)
+ {
+ var j = kv.Key;
+ var c = kv.Value;
+ x[i, j] = ((double)c) / seq_length;
+ }
+ }
+ else if (mode == "binary")
+ {
+ foreach (var kv in counts)
+ {
+ var j = kv.Key;
+ var c = kv.Value;
+ x[i, j] = 1;
+ }
+ }
+ else if (mode == "tfidf")
+ {
+ foreach (var kv in counts)
+ {
+ var j = kv.Key;
+ var c = kv.Value;
+ var id = 0;
+ var _ = index_docs.TryGetValue(j, out id);
+ var tf = 1 + np.log(c);
+ var idf = np.log(1 + document_count / (1 + id));
+ x[i, j] = tf * idf;
+ }
+ }
+ }
+
+ return x;
+ }
+
+ private string[] modes = new string[] { "binary", "count", "tfidf", "freq" };
+ }
+}
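To make the ported API concrete, a short end-to-end sketch (hedged; it mirrors what the unit tests below exercise):

    var tokenizer = new Tensorflow.Keras.Text.Tokenizer(oov_token: "<OOV>");
    tokenizer.fit_on_texts(new[] { "it was the best of times", "it was the worst of times" });

    // Known words map to their 1-based indices; unknown words map to the <OOV> index.
    var sequences = tokenizer.texts_to_sequences(new[] { "the best of days" });

    // "binary" marks presence; "count", "freq" and "tfidf" weight occurrences differently.
    var matrix = tokenizer.texts_to_matrix(new[] { "the best of times" }, mode: "binary");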
diff --git a/src/TensorFlowNET.Keras/Sequence.cs b/src/TensorFlowNET.Keras/Sequence.cs
index a428a568..9f503aee 100644
--- a/src/TensorFlowNET.Keras/Sequence.cs
+++ b/src/TensorFlowNET.Keras/Sequence.cs
@@ -15,7 +15,9 @@
******************************************************************************/
using NumSharp;
+using NumSharp.Utilities;
using System;
+using System.Collections.Generic;
using System.Linq;
namespace Tensorflow.Keras
@@ -34,14 +36,18 @@ namespace Tensorflow.Keras
/// <param name="truncating">String, 'pre' or 'post'</param>
/// <param name="value">Float or String, padding value.</param>
/// <returns></returns>
- public NDArray pad_sequences(NDArray sequences,
+ public NDArray pad_sequences(IEnumerable<int[]> sequences,
int? maxlen = null,
string dtype = "int32",
string padding = "pre",
string truncating = "pre",
object value = null)
{
- int[] length = new int[sequences.size];
+ if (value != null) throw new NotImplementedException("padding with a specific value.");
+ if (padding != "pre" && padding != "post") throw new InvalidArgumentError("padding must be 'pre' or 'post'.");
+ if (truncating != "pre" && truncating != "post") throw new InvalidArgumentError("truncating must be 'pre' or 'post'.");
+
+ var length = sequences.Select(s => s.Length);
if (maxlen == null)
maxlen = length.Max();
@@ -49,19 +55,26 @@ namespace Tensorflow.Keras
if (value == null)
value = 0f;
- var nd = new NDArray(np.int32, new Shape(sequences.size, maxlen.Value));
-#pragma warning disable CS0162 // Unreachable code detected
+ var type = getNPType(dtype);
+ var nd = new NDArray(type, new Shape(length.Count(), maxlen.Value), true);
+
for (int i = 0; i < nd.shape[0]; i++)
-#pragma warning restore CS0162 // Unreachable code detected
{
- switch (sequences[i])
+ var s = sequences.ElementAt(i);
+ if (s.Length > maxlen.Value)
{
- default:
- throw new NotImplementedException("pad_sequences");
+ s = (truncating == "pre") ? s.Slice(s.Length - maxlen.Value, s.Length) : s.Slice(0, maxlen.Value);
}
+ var sliceString = (padding == "pre") ? $"{i},{maxlen - s.Length}:" : $"{i},:{s.Length}";
+ nd[sliceString] = np.array(s);
}
return nd;
}
+
+ private Type getNPType(string typeName)
+ {
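+ // Resolve a dtype name such as "int32" to the matching NumSharp type by reflecting over the static fields of NumSharp.np.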
+ return System.Type.GetType("NumSharp.np,NumSharp").GetField(typeName).GetValue(null) as Type;
+ }
}
}
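A quick sketch of the reworked contract (hedged, but consistent with the tests below): sequences shorter than maxlen are filled with the padding value on the chosen side, and longer ones are truncated on the chosen side.

    var padded = keras.preprocessing.sequence.pad_sequences(
        new[] { new[] { 1, 2, 3 }, new[] { 4, 5 } },
        maxlen: 4, padding: "post");
    // padded is a (2, 4) int32 NDArray:
    // [[1, 2, 3, 0],
    //  [4, 5, 0, 0]]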
diff --git a/src/TensorFlowNET.Keras/TextApi.cs b/src/TensorFlowNET.Keras/TextApi.cs
new file mode 100644
index 00000000..8ce8d685
--- /dev/null
+++ b/src/TensorFlowNET.Keras/TextApi.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Keras.Text;
+
+namespace Tensorflow.Keras
+{
+ public class TextApi
+ {
+ public Tensorflow.Keras.Text.Tokenizer Tokenizer(
+ int num_words = -1,
+ string filters = DefaultFilter,
+ bool lower = true,
+ char split = ' ',
+ bool char_level = false,
+ string oov_token = null,
+ Func<string, IEnumerable<string>> analyzer = null)
+ {
+ return new Keras.Text.Tokenizer(num_words, filters, lower, split, char_level, oov_token, analyzer);
+ }
+
+ public static IEnumerable<string> text_to_word_sequence(string text, string filters = DefaultFilter, bool lower = true, char split = ' ')
+ {
+ if (lower)
+ {
+ text = text.ToLower();
+ }
+ var newText = new String(text.Where(c => !filters.Contains(c)).ToArray());
+ return newText.Split(split);
+ }
+
+ private const string DefaultFilter = "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n";
+ }
+}
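text_to_word_sequence lower-cases the input (by default), strips the filter characters and splits on the separator; a small sketch:

    var words = TextApi.text_to_word_sequence("It was the best of times!");
    // -> "it", "was", "the", "best", "of", "times"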
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs
new file mode 100644
index 00000000..8bd0055f
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/PoolingTest.cs
@@ -0,0 +1,305 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using System.Linq;
+using Tensorflow;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.Keras.UnitTest
+{
+ /// <summary>
+ /// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers
+ /// </summary>
+ [TestClass]
+ public class PoolingTest : EagerModeTestBase
+ {
+ private NDArray input_array_1D = np.array(new float[,,]
+ {
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ });
+
+ private NDArray input_array_2D = np.array(new float[,,,]
+ {{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ },{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,3}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ }});
+
+ [TestMethod]
+ public void GlobalAverage1DPoolingChannelsLast()
+ {
+ var pool = keras.layers.GlobalAveragePooling1D();
+ var y = pool.Apply(input_array_1D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(5, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {1,2,3,3,3},
+ {4,5,6,3,3},
+ {7,8,9,3,3},
+ {7,8,9,3,3}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalAverage1DPoolingChannelsFirst()
+ {
+ var pool = keras.layers.GlobalAveragePooling1D(data_format: "channels_first");
+ var y = pool.Apply(input_array_1D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(3, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {2.4f, 2.4f, 2.4f},
+ {4.2f, 4.2f, 4.2f},
+ {6.0f, 6.0f, 6.0f},
+ {6.0f, 6.0f, 6.0f}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalAverage2DPoolingChannelsLast()
+ {
+ var pool = keras.layers.GlobalAveragePooling2D();
+ var y = pool.Apply(input_array_2D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(5, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {2.5f, 3.5f, 4.5f, 3.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
+ {2.5f, 3.5f, 4.5f, 3.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalAverage2DPoolingChannelsFirst()
+ {
+ var pool = keras.layers.GlobalAveragePooling2D(data_format: "channels_first");
+ var y = pool.Apply(input_array_2D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(2, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {2.4f, 4.2f},
+ {6.0f, 6.0f},
+ {2.4f, 4.2f},
+ {6.0f, 6.0f}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalMax1DPoolingChannelsLast()
+ {
+ var pool = keras.layers.GlobalMaxPooling1D();
+ var y = pool.Apply(input_array_1D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(5, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {1,2,3,3,3},
+ {4,5,6,3,3},
+ {7,8,9,3,3},
+ {7,8,9,3,3}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalMax1DPoolingChannelsFirst()
+ {
+ var pool = keras.layers.GlobalMaxPooling1D(data_format: "channels_first");
+ var y = pool.Apply(input_array_1D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(3, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {3.0f, 3.0f, 3.0f},
+ {6.0f, 6.0f, 6.0f},
+ {9.0f, 9.0f, 9.0f},
+ {9.0f, 9.0f, 9.0f}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalMax2DPoolingChannelsLast()
+ {
+ var input_array_2D = np.array(new float[,,,]
+ {{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ },{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ }});
+
+ var pool = keras.layers.GlobalMaxPooling2D();
+ var y = pool.Apply(input_array_2D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(5, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {4.0f, 5.0f, 6.0f, 9.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f, 3.0f, 9.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void GlobalMax2DPoolingChannelsFirst()
+ {
+ var input_array_2D = np.array(new float[,,,]
+ {{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ },{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ }});
+
+ var pool = keras.layers.GlobalMaxPooling2D(data_format: "channels_first");
+ var y = pool.Apply(input_array_2D);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(2, y.shape[1]);
+
+ var expected = np.array(new float[,]
+ {
+ {9.0f, 6.0f},
+ {9.0f, 9.0f},
+ {9.0f, 6.0f},
+ {9.0f, 9.0f}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod, Ignore("There's an error generated from TF complaining about the shape of the pool. Needs further investigation.")]
+ public void Max1DPoolingChannelsLast()
+ {
+ var x = input_array_1D;
+ var pool = keras.layers.MaxPooling1D(pool_size:2, strides:1);
+ var y = pool.Apply(x);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(2, y.shape[1]);
+ Assert.AreEqual(5, y.shape[2]);
+
+ var expected = np.array(new float[,,]
+ {
+ {{2.0f, 2.0f, 3.0f, 3.0f, 3.0f},
+ { 1.0f, 2.0f, 3.0f, 3.0f, 3.0f}},
+
+ {{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f, 3.0f, 3.0f}},
+
+ {{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}},
+
+ {{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+
+ [TestMethod]
+ public void Max2DPoolingChannelsLast()
+ {
+ var x = np.array(new float[,,,]
+ {{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,9,3}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ },{
+ {{1,2,3,3,3},{1,2,3,3,3},{1,2,3,3,9}},
+ {{4,5,6,3,3},{4,5,6,3,3},{4,5,6,3,3}},
+ },{
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}},
+ {{7,8,9,3,3},{7,8,9,3,3},{7,8,9,3,3}}
+ }});
+
+ var pool = keras.layers.MaxPooling2D(pool_size: 2, strides: 1);
+ var y = pool.Apply(x);
+
+ Assert.AreEqual(4, y.shape[0]);
+ Assert.AreEqual(1, y.shape[1]);
+ Assert.AreEqual(2, y.shape[2]);
+ Assert.AreEqual(5, y.shape[3]);
+
+ var expected = np.array(new float[,,,]
+ {
+ {{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f, 9.0f, 3.0f}}},
+
+
+ {{{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}},
+
+
+ {{{4.0f, 5.0f, 6.0f, 3.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f, 3.0f, 9.0f}}},
+
+
+ {{{7.0f, 8.0f, 9.0f, 3.0f, 3.0f},
+ {7.0f, 8.0f, 9.0f, 3.0f, 3.0f}}}
+ });
+
+ Assert.AreEqual(expected, y[0].numpy());
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
new file mode 100644
index 00000000..10340063
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
@@ -0,0 +1,412 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Linq;
+using System.Collections.Generic;
+using System.Text;
+using NumSharp;
+using static Tensorflow.KerasApi;
+using Tensorflow;
+
+namespace TensorFlowNET.Keras.UnitTest
+{
+ [TestClass]
+ public class PreprocessingTests : EagerModeTestBase
+ {
+ private readonly string[] texts = new string[] {
+ "It was the best of times, it was the worst of times.",
+ "Mr and Mrs Dursley of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.",
+ "It was the best of times, it was the worst of times.",
+ "Mr and Mrs Dursley of number four, Privet Drive.",
+ };
+
+ private readonly string[][] tokenized_texts = new string[][] {
+ new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
+ new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive","were","proud","to","say","that","they","were","perfectly","normal","thank","you","very","much"},
+ new string[] {"It","was","the","best","of","times","it","was","the","worst","of","times"},
+ new string[] {"mr","and","mrs","dursley","of","number","four","privet","drive"},
+ };
+
+ private readonly string[] processed_texts = new string[] {
+ "it was the best of times it was the worst of times",
+ "mr and mrs dursley of number four privet drive were proud to say that they were perfectly normal thank you very much",
+ "it was the best of times it was the worst of times",
+ "mr and mrs dursley of number four privet drive",
+ };
+
+ private const string OOV = "";
+
+ [TestMethod]
+ public void TokenizeWithNoOOV()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ Assert.AreEqual(27, tokenizer.word_index.Count);
+
+ Assert.AreEqual(7, tokenizer.word_index["worst"]);
+ Assert.AreEqual(12, tokenizer.word_index["number"]);
+ Assert.AreEqual(16, tokenizer.word_index["were"]);
+ }
+
+ [TestMethod]
+ public void TokenizeWithNoOOV_Tkn()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ Assert.AreEqual(27, tokenizer.word_index.Count);
+
+ Assert.AreEqual(7, tokenizer.word_index["worst"]);
+ Assert.AreEqual(12, tokenizer.word_index["number"]);
+ Assert.AreEqual(16, tokenizer.word_index["were"]);
+ }
+
+ [TestMethod]
+ public void TokenizeWithOOV()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ Assert.AreEqual(28, tokenizer.word_index.Count);
+
+ Assert.AreEqual(1, tokenizer.word_index[OOV]);
+ Assert.AreEqual(8, tokenizer.word_index["worst"]);
+ Assert.AreEqual(13, tokenizer.word_index["number"]);
+ Assert.AreEqual(17, tokenizer.word_index["were"]);
+ }
+
+ [TestMethod]
+ public void TokenizeWithOOV_Tkn()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ Assert.AreEqual(28, tokenizer.word_index.Count);
+
+ Assert.AreEqual(1, tokenizer.word_index[OOV]);
+ Assert.AreEqual(8, tokenizer.word_index["worst"]);
+ Assert.AreEqual(13, tokenizer.word_index["number"]);
+ Assert.AreEqual(17, tokenizer.word_index["were"]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequences()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequences_Tkn()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ var sequences = tokenizer.texts_to_sequences(tokenized_texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack_Tkn1()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ // Use the list version, where the tokenization has already been done.
+ var sequences = tokenizer.texts_to_sequences(tokenized_texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack_Tkn2()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ // Use the list version, where the tokenization has already been done.
+ tokenizer.fit_on_texts(tokenized_texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesAndBack_Tkn3()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ // Use the list version, where the tokenization has already been done.
+ var sequences = tokenizer.texts_to_sequences(tokenized_texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ var processed = tokenizer.sequences_to_texts(sequences);
+
+ Assert.AreEqual(4, processed.Count);
+
+ for (var i = 0; i < processed.Count; i++)
+ Assert.AreEqual(processed_texts[i], processed[i]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesWithOOV()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+
+ for (var i = 0; i < sequences.Count; i++)
+ for (var j = 0; j < sequences[i].Length; j++)
+ Assert.AreNotEqual(tokenizer.word_index[OOV], sequences[i][j]);
+ }
+
+ [TestMethod]
+ public void TokenizeTextsToSequencesWithOOVPresent()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV, num_words:20);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ Assert.AreEqual(4, sequences.Count);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], sequences[0][9]);
+ Assert.AreEqual(tokenizer.word_index["proud"], sequences[1][10]);
+
+ var oov_count = 0;
+ for (var i = 0; i < sequences.Count; i++)
+ for (var j = 0; j < sequences[i].Length; j++)
+ if (tokenizer.word_index[OOV] == sequences[i][j])
+ oov_count += 1;
+
+ Assert.AreEqual(9, oov_count);
+ }
+
+ [TestMethod]
+ public void PadSequencesWithDefaults()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ var padded = keras.preprocessing.sequence.pad_sequences(sequences);
+
+ Assert.AreEqual(4, padded.shape[0]);
+ Assert.AreEqual(22, padded.shape[1]);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 19].GetInt32());
+ for (var i = 0; i < 8; i++)
+ Assert.AreEqual(0, padded[0, i].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+ for (var i = 0; i < 20; i++)
+ Assert.AreNotEqual(0, padded[1, i].GetInt32());
+ }
+
+ [TestMethod]
+ public void PadSequencesPrePaddingTrunc()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ var padded = keras.preprocessing.sequence.pad_sequences(sequences,maxlen:15);
+
+ Assert.AreEqual(4, padded.shape[0]);
+ Assert.AreEqual(15, padded.shape[1]);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
+ for (var i = 0; i < 3; i++)
+ Assert.AreEqual(0, padded[0, i].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 3].GetInt32());
+ for (var i = 0; i < 15; i++)
+ Assert.AreNotEqual(0, padded[1, i].GetInt32());
+ }
+
+ [TestMethod]
+ public void PadSequencesPrePaddingTrunc_Larger()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 45);
+
+ Assert.AreEqual(4, padded.shape[0]);
+ Assert.AreEqual(45, padded.shape[1]);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 42].GetInt32());
+ for (var i = 0; i < 33; i++)
+ Assert.AreEqual(0, padded[0, i].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 33].GetInt32());
+ }
+
+ [TestMethod]
+ public void PadSequencesPostPaddingTrunc()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post");
+
+ Assert.AreEqual(4, padded.shape[0]);
+ Assert.AreEqual(15, padded.shape[1]);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
+ for (var i = 12; i < 15; i++)
+ Assert.AreEqual(0, padded[0, i].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+ for (var i = 0; i < 15; i++)
+ Assert.AreNotEqual(0, padded[1, i].GetInt32());
+ }
+
+ [TestMethod]
+ public void PadSequencesPostPaddingTrunc_Larger()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
+ tokenizer.fit_on_texts(texts);
+
+ var sequences = tokenizer.texts_to_sequences(texts);
+ var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 45, padding: "post", truncating: "post");
+
+ Assert.AreEqual(4, padded.shape[0]);
+ Assert.AreEqual(45, padded.shape[1]);
+
+ Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
+ for (var i = 32; i < 45; i++)
+ Assert.AreEqual(0, padded[0, i].GetInt32());
+ Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10].GetInt32());
+ }
+
+ [TestMethod]
+ public void TextToMatrixBinary()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ Assert.AreEqual(27, tokenizer.word_index.Count);
+
+ var matrix = tokenizer.texts_to_matrix(texts);
+
+ Assert.AreEqual(texts.Length, matrix.shape[0]);
+
+ CompareLists(new double[] { 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray());
+ CompareLists(new double[] { 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, matrix[1].ToArray());
+ }
+
+ [TestMethod]
+ public void TextToMatrixCount()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ Assert.AreEqual(27, tokenizer.word_index.Count);
+
+ var matrix = tokenizer.texts_to_matrix(texts, mode:"count");
+
+ Assert.AreEqual(texts.Length, matrix.shape[0]);
+
+ CompareLists(new double[] { 0, 2, 2, 2, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray());
+ CompareLists(new double[] { 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }, matrix[1].ToArray());
+ }
+
+ [TestMethod]
+ public void TextToMatrixFrequency()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ Assert.AreEqual(27, tokenizer.word_index.Count);
+
+ var matrix = tokenizer.texts_to_matrix(texts, mode: "freq");
+
+ Assert.AreEqual(texts.Length, matrix.shape[0]);
+
+ double t12 = 2.0 / 12.0;
+ double o12 = 1.0 / 12.0;
+ double t22 = 2.0 / 22.0;
+ double o22 = 1.0 / 22.0;
+
+ CompareLists(new double[] { 0, t12, t12, t12, o12, t12, t12, o12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray());
+ CompareLists(new double[] { 0, 0, 0, 0, 0, o22, 0, 0, o22, o22, o22, o22, o22, o22, o22, o22, t22, o22, o22, o22, o22, o22, o22, o22, o22, o22, o22, o22 }, matrix[1].ToArray());
+ }
+
+ [TestMethod]
+ public void TextToMatrixTDIDF()
+ {
+ var tokenizer = keras.preprocessing.text.Tokenizer();
+ tokenizer.fit_on_texts(texts);
+
+ Assert.AreEqual(27, tokenizer.word_index.Count);
+
+ var matrix = tokenizer.texts_to_matrix(texts, mode: "tfidf");
+
+ Assert.AreEqual(texts.Length, matrix.shape[0]);
+
+ double t1 = 1.1736001944781467;
+ double t2 = 0.69314718055994529;
+ double t3 = 1.860112299086919;
+ double t4 = 1.0986122886681098;
+ double t5 = 0.69314718055994529;
+
+ CompareLists(new double[] { 0, t1, t1, t1, t2, 0, t1, t2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, matrix[0].ToArray());
+ CompareLists(new double[] { 0, 0, 0, 0, 0, 0, 0, 0, t5, t5, t5, t5, t5, t5, t5, t5, t3, t4, t4, t4, t4, t4, t4, t4, t4, t4, t4, t4 }, matrix[1].ToArray());
+ }
+
+ private void CompareLists<T>(IList<T> expected, IList<T> actual)
+ {
+ Assert.AreEqual(expected.Count, actual.Count);
+ for (var i = 0; i < expected.Count; i++)
+ {
+ Assert.AreEqual(expected[i], actual[i]);
+ }
+ }
+
+ }
+}