From 4ef675faf9a9a999bbb7a4f511bf2894a8eaac0f Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 13 Feb 2021 08:15:08 -0600 Subject: [PATCH] Consolidate MapDataset function. --- .../Tensorflow.Console.csproj | 2 +- src/TensorFlowNET.Core/APIs/c_api.cs | 2 +- src/TensorFlowNET.Core/APIs/tf.strings.cs | 24 ++++++++++ src/TensorFlowNET.Core/Data/DatasetV2.cs | 5 +- src/TensorFlowNET.Core/Data/IDatasetV2.cs | 6 ++- src/TensorFlowNET.Core/Data/MapDataset.cs | 2 +- .../Framework/Models/TensorSpec.cs | 2 +- .../Gradients/image_grad.cs | 2 +- .../Preprocessing/PreprocessingLayerArgs.cs | 10 ++++ .../Preprocessing/TextVectorizationArgs.cs | 15 ++++++ .../Operations/string_ops.cs | 48 +++++++++++++++++++ .../Engine/CombinerPreprocessingLayer.cs | 18 +++++++ .../DataAdapters/TensorLikeDataAdapter.cs | 4 +- .../Engine/MetricsContainer.cs | 4 +- src/TensorFlowNET.Keras/Engine/Sequential.cs | 3 +- src/TensorFlowNET.Keras/KerasInterface.cs | 4 +- .../Layers/Core/InputLayer.cs | 2 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 2 + .../Layers/Preprocessing/TextVectorization.cs | 20 ++++++++ .../Layers/Reshaping/Flatten.cs | 3 +- .../Layers/Reshaping/Reshape.cs | 2 +- ...tUtils.get_training_or_validation_split.cs | 12 +++-- .../DatasetUtils.index_directory.cs | 15 +++--- .../Preprocessings/Preprocessing.cs | 18 ++++++- ...processing.image_dataset_from_directory.cs | 24 ++++++++-- ...eprocessing.paths_and_labels_to_dataset.cs | 27 +++++++++++ .../Tensorflow.Keras.csproj | 6 +-- .../Tensorflow.Benchmark.csproj | 2 +- tensorflowlib/README.md | 12 ++++- 29 files changed, 258 insertions(+), 38 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs create mode 100644 src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj index 9359d975..d363b00f 100644 --- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj +++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.1 + net5.0 Tensorflow Tensorflow AnyCPU;x64 diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 11c17abd..10f678e0 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -43,7 +43,7 @@ namespace Tensorflow /// public partial class c_api { - public const string TensorFlowLibName = @"D:\Projects\tensorflow-haiping\bazel-bin\tensorflow\tensorflow"; + public const string TensorFlowLibName = "tensorflow"; public static string StringPiece(IntPtr handle) { diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs index be0cf765..f580a67d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.strings.cs +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -24,6 +24,30 @@ namespace Tensorflow { string_ops ops = new string_ops(); + /// + /// Converts all uppercase characters into their respective lowercase replacements. + /// + /// + /// + /// + /// + public Tensor lower(Tensor input, string encoding = "", string name = null) + => ops.lower(input: input, encoding: encoding, name: name); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor regex_replace(Tensor input, string pattern, string rewrite, + bool replace_global = true, string name = null) + => ops.regex_replace(input, pattern, rewrite, + replace_global: replace_global, name: name); + /// /// Return substrings from `Tensor` of strings. /// diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 2abe9970..0297eb6b 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -14,6 +14,7 @@ namespace Tensorflow public class DatasetV2 : IDatasetV2 { protected dataset_ops ops = new dataset_ops(); + public string[] class_names { get; set; } public Tensor variant_tensor { get; set; } public TensorSpec[] structure { get; set; } @@ -54,7 +55,7 @@ namespace Tensorflow public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); - public IDatasetV2 map(Func map_func, + public IDatasetV2 map(Func map_func, bool use_inter_op_parallelism = true, bool preserve_cardinality = true, bool use_legacy_function = false) @@ -64,7 +65,7 @@ namespace Tensorflow preserve_cardinality: preserve_cardinality, use_legacy_function: use_legacy_function); - public IDatasetV2 map(Func map_func, int num_parallel_calls = -1) + public IDatasetV2 map(Func map_func, int num_parallel_calls) => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); public IDatasetV2 flat_map(Func map_func) diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 4d9b00d2..d0e372dc 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -6,6 +6,8 @@ namespace Tensorflow { public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> { + string[] class_names { get; set; } + Tensor variant_tensor { get; set; } TensorShape[] output_shapes { get; } @@ -62,13 +64,13 @@ namespace Tensorflow IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); - IDatasetV2 map(Func map_func, + IDatasetV2 map(Func map_func, bool use_inter_op_parallelism = true, bool preserve_cardinality = true, bool use_legacy_function = false); IDatasetV2 map(Func map_func, - int num_parallel_calls = -1); + int num_parallel_calls); IDatasetV2 flat_map(Func map_func); diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs index 6f753f55..5786a340 100644 --- a/src/TensorFlowNET.Core/Data/MapDataset.cs +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -10,7 +10,7 @@ namespace Tensorflow public class MapDataset : UnaryDataset { public MapDataset(IDatasetV2 input_dataset, - Func map_func, + Func map_func, bool use_inter_op_parallelism = true, bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) diff --git a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs index 0d5aa7d0..5f333547 100644 --- a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs +++ b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Framework.Models if (_shape.ndim == 0) throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); - return new TensorSpec(_shape.dims[1..], _dtype); + return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype); } public TensorSpec _batch(int dim = -1) diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs index ccc70fea..08636298 100644 --- a/src/TensorFlowNET.Core/Gradients/image_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Gradients var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); Tensor image_shape = null; if (shape.is_fully_defined()) - image_shape = constant_op.constant(image.shape[1..3]); + image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray()); else image_shape = array_ops.shape(image)["1:3"]; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs new file mode 100644 index 00000000..28ccf9f7 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class PreprocessingLayerArgs : LayerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs new file mode 100644 index 00000000..ab55da4e --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class TextVectorizationArgs : PreprocessingLayerArgs + { + public Func Standardize { get; set; } + public string Split { get; set; } = "standardize"; + public int MaxTokens { get; set; } = -1; + public string OutputMode { get; set; } = "int"; + public int OutputSequenceLength { get; set; } = -1; + } +} diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs index 49b4c3e9..deba7be4 100644 --- a/src/TensorFlowNET.Core/Operations/string_ops.cs +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -20,6 +20,54 @@ namespace Tensorflow { public class string_ops { + public Tensor lower(Tensor input, string encoding = "", string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "StringLower", name, + null, + input, encoding); + + return results[0]; + } + + var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new + { + input, + encoding + }); + + return _op.output; + } + + public Tensor regex_replace(Tensor input, string pattern, string rewrite, + bool replace_global = true, string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "StaticRegexReplace", name, + null, + input, + "pattern", pattern, + "rewrite", rewrite, + "replace_global", replace_global); + + return results[0]; + } + + var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new + { + input, + pattern, + rewrite, + replace_global + }); + + return _op.output; + } + /// /// Return substrings from `Tensor` of strings. /// diff --git a/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs new file mode 100644 index 00000000..11adfe9f --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine +{ + public class CombinerPreprocessingLayer : Layer + { + PreprocessingLayerArgs args; + + public CombinerPreprocessingLayer(PreprocessingLayerArgs args) + : base(args) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 478e0e8b..98fd4741 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters dataset = slice_inputs(indices_dataset, inputs); } - Tensor permutation(Tensor tensor) + Tensors permutation(Tensors tensor) { var indices = math_ops.range(num_samples, dtype: dtypes.int64); if (args.Shuffle) @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters .Select(x => gen_array_ops.gather_v2(x, indices, 0)) .ToArray(); return new Tensors(results); - }); + }, -1); return dataset.with_options(new DatasetOptions { }); } diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index f0abc29f..3870c29b 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -62,8 +62,8 @@ namespace Tensorflow.Keras.Engine { var y_t_rank = y_t.rank; var y_p_rank = y_p.rank; - var y_t_last_dim = y_t.shape[^1]; - var y_p_last_dim = y_p.shape[^1]; + var y_t_last_dim = y_t.shape[y_t.shape.Length - 1]; + var y_p_last_dim = y_p.shape[y_p.shape.Length - 1]; bool is_binary = y_p_last_dim == 1; bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 58cc73f3..d06810f5 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System.Linq; using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; @@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Engine if (set_inputs) { // If an input layer (placeholder) is available. - outputs = layer.InboundNodes[^1].Outputs; + outputs = layer.InboundNodes.Last().Outputs; inputs = layer_utils.get_source_inputs(outputs[0]); built = true; _has_explicit_input_shape = true; diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index 50f80b6d..b5209e76 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -11,6 +11,7 @@ using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Models; using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Utils; namespace Tensorflow.Keras { @@ -27,6 +28,7 @@ namespace Tensorflow.Keras public OptimizerApi optimizers { get; } = new OptimizerApi(); public MetricsApi metrics { get; } = new MetricsApi(); public ModelsApi models { get; } = new ModelsApi(); + public KerasUtils utils { get; } = new KerasUtils(); public Sequential Sequential(List layers = null, string name = null) @@ -73,7 +75,7 @@ namespace Tensorflow.Keras Tensor tensor = null) { if (batch_input_shape != null) - shape = batch_input_shape.dims[1..]; + shape = batch_input_shape.dims.Skip(1).ToArray(); var args = new InputLayerArgs { diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index b1e2844c..55e20166 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Layers if (BatchInputShape != null) { args.BatchSize = BatchInputShape.dims[0]; - args.InputShape = BatchInputShape.dims[1..]; + args.InputShape = BatchInputShape.dims.Skip(1).ToArray(); } // moved to base class diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index dc9ee749..e735f81e 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -9,6 +9,8 @@ namespace Tensorflow.Keras.Layers { public partial class LayersApi { + public Preprocessing preprocessing { get; } = new Preprocessing(); + /// /// Functional interface for the batch normalization layer. /// http://arxiv.org/abs/1502.03167 diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs new file mode 100644 index 00000000..a66be94b --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class TextVectorization : CombinerPreprocessingLayer + { + TextVectorizationArgs args; + + public TextVectorization(TextVectorizationArgs args) + : base(args) + { + args.DType = TF_DataType.TF_STRING; + // string standardize = "lower_and_strip_punctuation", + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index f376c7d5..1b59ca82 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using Tensorflow.Framework; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; @@ -45,7 +46,7 @@ namespace Tensorflow.Keras.Layers return array_ops.reshape(inputs, new[] { batch_dim, -1 }); } - var non_batch_dims = ((int[])input_shape)[1..]; + var non_batch_dims = ((int[])input_shape).Skip(1).ToArray(); var num = 1; if (non_batch_dims.Length > 0) { diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index e8f7d01c..ecabc8f1 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Layers public override TensorShape ComputeOutputShape(TensorShape input_shape) { - if (input_shape.dims[1..].Contains(-1)) + if (input_shape.dims.Skip(1).Contains(-1)) { throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs index 4e089fb6..a80e960a 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; namespace Tensorflow.Keras.Preprocessings { @@ -17,18 +18,21 @@ namespace Tensorflow.Keras.Preprocessings float validation_split, string subset) { + if (string.IsNullOrEmpty(subset)) + return (samples, labels); + var num_val_samples = Convert.ToInt32(samples.Length * validation_split); if (subset == "training") { Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); - samples = samples[..^num_val_samples]; - labels = labels[..^num_val_samples]; + samples = samples.Take(samples.Length - num_val_samples).ToArray(); + labels = labels.Take(labels.Length - num_val_samples).ToArray(); } else if (subset == "validation") { Console.WriteLine($"Using {num_val_samples} files for validation."); - samples = samples[(samples.Length - num_val_samples)..]; - labels = labels[(labels.Length - num_val_samples)..]; + samples = samples.Skip(samples.Length - num_val_samples).ToArray(); + labels = labels.Skip(labels.Length - num_val_samples).ToArray(); } else throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs index cf7ef12c..6b62b9b2 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs @@ -21,40 +21,41 @@ namespace Tensorflow.Keras.Preprocessings /// file_paths, labels, class_names /// public (string[], int[], string[]) index_directory(string directory, + string labels, string[] formats = null, string[] class_names = null, bool shuffle = true, int? seed = null, bool follow_links = false) { - var labels = new List(); + var label_list = new List(); var file_paths = new List(); var class_dirs = Directory.GetDirectories(directory); - class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray(); + class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar).Last()).ToArray(); for (var label = 0; label < class_dirs.Length; label++) { var files = Directory.GetFiles(class_dirs[label]); file_paths.AddRange(files); - labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); + label_list.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); } - var return_labels = labels.Select(x => x).ToArray(); + var return_labels = label_list.Select(x => x).ToArray(); var return_file_paths = file_paths.Select(x => x).ToArray(); if (shuffle) { if (!seed.HasValue) seed = np.random.randint((long)1e6); - var random_index = np.arange(labels.Count); + var random_index = np.arange(label_list.Count); var rng = np.random.RandomState(seed.Value); rng.shuffle(random_index); var index = random_index.ToArray(); - for (int i = 0; i < labels.Count; i++) + for (int i = 0; i < label_list.Count; i++) { - return_labels[i] = labels[index[i]]; + return_labels[i] = label_list[index[i]]; return_file_paths[i] = file_paths[index[i]]; } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs index 2d418509..6c33e9f5 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs @@ -1,4 +1,7 @@ -using Tensorflow.Keras.Preprocessings; +using System; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Preprocessings; namespace Tensorflow.Keras { @@ -6,5 +9,18 @@ namespace Tensorflow.Keras { public Sequence sequence => new Sequence(); public DatasetUtils dataset_utils => new DatasetUtils(); + + public TextVectorization TextVectorization(Func standardize = null, + string split = "standardize", + int max_tokens = -1, + string output_mode = "int", + int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs + { + Standardize = standardize, + Split = split, + MaxTokens = max_tokens, + OutputMode = output_mode, + OutputSequenceLength = output_sequence_length + }); } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index c17799ac..8d7513a6 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -43,6 +43,7 @@ namespace Tensorflow.Keras num_channels = 3; var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, + labels, formats: WHITELIST_FORMATS, class_names: class_names, shuffle: shuffle, @@ -64,13 +65,30 @@ namespace Tensorflow.Keras string[] class_names = null, int batch_size = 32, bool shuffle = true, + int max_length = -1, int? seed = null, float validation_split = 0.2f, - string subset = null) + string subset = null, + bool follow_links = false) { - + var (file_paths, label_list, class_name_list) = dataset_utils.index_directory( + directory, + labels, + formats: new[] { ".txt" }, + class_names: class_names, + shuffle: shuffle, + seed: seed, + follow_links: follow_links); - return null; + (file_paths, label_list) = dataset_utils.get_training_or_validation_split( + file_paths, label_list, validation_split, subset); + + var dataset = paths_and_labels_to_dataset(file_paths, label_list, label_mode, class_name_list.Length); + if (shuffle) + dataset = dataset.shuffle(batch_size * 8, seed: seed); + dataset = dataset.batch(batch_size); + dataset.class_names = class_name_list; + return dataset; } } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs index abf07735..dba2cded 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs @@ -1,4 +1,5 @@ using System; +using System.IO; using static Tensorflow.Binding; namespace Tensorflow.Keras @@ -34,5 +35,31 @@ namespace Tensorflow.Keras // img.set_shape((image_size[0], image_size[1], num_channels)); return img; } + + public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths, + int[] labels, + string label_mode, + int num_classes, + int max_length = -1) + { + var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); + var string_ds = path_ds.map(x => path_to_string_content(x, max_length)); + + if (label_mode == "int") + { + var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes); + string_ds = tf.data.Dataset.zip(string_ds, label_ds); + } + + return string_ds; + } + + Tensor path_to_string_content(Tensor path, int max_length) + { + var txt = tf.io.read_file(path); + if (max_length > -1) + txt = tf.strings.substr(txt, 0, max_length); + return txt; + } } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 5694e8f5..9eeb4634 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -6,7 +6,7 @@ 8.0 Tensorflow.Keras AnyCPU;x64 - 0.4.0 + 0.4.1 Haiping Chen Keras for .NET Apache 2.0, Haiping Chen 2020 @@ -34,8 +34,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac Git true Open.snk - 0.4.0.0 - 0.4.0.0 + 0.4.1.0 + 0.4.1.0 LICENSE diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index 1160fa4f..60955e68 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.1 + net5.0 AnyCPU;x64 diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index a08959a7..ae04c398 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ 1. Build static library -`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` +`bazel build --output_base=C:/tmp/tfcompilation --config=opt //tensorflow:tensorflow` 2. Build pip package @@ -70,6 +70,16 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` +### Build from source for MacOS + +```shell +$ cd /usr/local/lib/bazel/bin +$ curl -LO https://release.bazel.build/3.7.2/release/bazel-3.7.2-darwin-x86_64 +$ chmod +x bazel-3.7.2-darwin-x86_64 +$ cd ~/Projects/tensorflow +$ bazel build --config=opt //tensorflow:tensorflow +``` + ### Build specific version for tf.net https://github.com/SciSharp/tensorflow