diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
index 9359d975..d363b00f 100644
--- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj
+++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp3.1
+ net5.0
Tensorflow
Tensorflow
AnyCPU;x64
diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs
index 11c17abd..10f678e0 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.cs
@@ -43,7 +43,7 @@ namespace Tensorflow
///
public partial class c_api
{
- public const string TensorFlowLibName = @"D:\Projects\tensorflow-haiping\bazel-bin\tensorflow\tensorflow";
+ public const string TensorFlowLibName = "tensorflow";
public static string StringPiece(IntPtr handle)
{
diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs
index be0cf765..f580a67d 100644
--- a/src/TensorFlowNET.Core/APIs/tf.strings.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs
@@ -24,6 +24,30 @@ namespace Tensorflow
{
string_ops ops = new string_ops();
+ ///
+ /// Converts all uppercase characters into their respective lowercase replacements.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor lower(Tensor input, string encoding = "", string name = null)
+ => ops.lower(input: input, encoding: encoding, name: name);
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor regex_replace(Tensor input, string pattern, string rewrite,
+ bool replace_global = true, string name = null)
+ => ops.regex_replace(input, pattern, rewrite,
+ replace_global: replace_global, name: name);
+
///
/// Return substrings from `Tensor` of strings.
///
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 2abe9970..0297eb6b 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -14,6 +14,7 @@ namespace Tensorflow
public class DatasetV2 : IDatasetV2
{
protected dataset_ops ops = new dataset_ops();
+ public string[] class_names { get; set; }
public Tensor variant_tensor { get; set; }
public TensorSpec[] structure { get; set; }
@@ -54,7 +55,7 @@ namespace Tensorflow
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
- public IDatasetV2 map(Func map_func,
+ public IDatasetV2 map(Func map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = true,
bool use_legacy_function = false)
@@ -64,7 +65,7 @@ namespace Tensorflow
preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function);
- public IDatasetV2 map(Func map_func, int num_parallel_calls = -1)
+ public IDatasetV2 map(Func map_func, int num_parallel_calls)
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
public IDatasetV2 flat_map(Func map_func)
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index 4d9b00d2..d0e372dc 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -6,6 +6,8 @@ namespace Tensorflow
{
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
{
+ string[] class_names { get; set; }
+
Tensor variant_tensor { get; set; }
TensorShape[] output_shapes { get; }
@@ -62,13 +64,13 @@ namespace Tensorflow
IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);
- IDatasetV2 map(Func map_func,
+ IDatasetV2 map(Func map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = true,
bool use_legacy_function = false);
IDatasetV2 map(Func map_func,
- int num_parallel_calls = -1);
+ int num_parallel_calls);
IDatasetV2 flat_map(Func map_func);
diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs
index 6f753f55..5786a340 100644
--- a/src/TensorFlowNET.Core/Data/MapDataset.cs
+++ b/src/TensorFlowNET.Core/Data/MapDataset.cs
@@ -10,7 +10,7 @@ namespace Tensorflow
public class MapDataset : UnaryDataset
{
public MapDataset(IDatasetV2 input_dataset,
- Func map_func,
+ Func map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
diff --git a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
index 0d5aa7d0..5f333547 100644
--- a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
+++ b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
@@ -15,7 +15,7 @@ namespace Tensorflow.Framework.Models
if (_shape.ndim == 0)
throw new ValueError("Unbatching a tensor is only supported for rank >= 1");
- return new TensorSpec(_shape.dims[1..], _dtype);
+ return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype);
}
public TensorSpec _batch(int dim = -1)
diff --git a/src/TensorFlowNET.Core/Gradients/image_grad.cs b/src/TensorFlowNET.Core/Gradients/image_grad.cs
index ccc70fea..08636298 100644
--- a/src/TensorFlowNET.Core/Gradients/image_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/image_grad.cs
@@ -30,7 +30,7 @@ namespace Tensorflow.Gradients
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
Tensor image_shape = null;
if (shape.is_fully_defined())
- image_shape = constant_op.constant(image.shape[1..3]);
+ image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray());
else
image_shape = array_ops.shape(image)["1:3"];
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
new file mode 100644
index 00000000..28ccf9f7
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class PreprocessingLayerArgs : LayerArgs
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
new file mode 100644
index 00000000..ab55da4e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class TextVectorizationArgs : PreprocessingLayerArgs
+ {
+ public Func Standardize { get; set; }
+ public string Split { get; set; } = "standardize";
+ public int MaxTokens { get; set; } = -1;
+ public string OutputMode { get; set; } = "int";
+ public int OutputSequenceLength { get; set; } = -1;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs
index 49b4c3e9..deba7be4 100644
--- a/src/TensorFlowNET.Core/Operations/string_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/string_ops.cs
@@ -20,6 +20,54 @@ namespace Tensorflow
{
public class string_ops
{
+ public Tensor lower(Tensor input, string encoding = "", string name = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "StringLower", name,
+ null,
+ input, encoding);
+
+ return results[0];
+ }
+
+ var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new
+ {
+ input,
+ encoding
+ });
+
+ return _op.output;
+ }
+
+ public Tensor regex_replace(Tensor input, string pattern, string rewrite,
+ bool replace_global = true, string name = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "StaticRegexReplace", name,
+ null,
+ input,
+ "pattern", pattern,
+ "rewrite", rewrite,
+ "replace_global", replace_global);
+
+ return results[0];
+ }
+
+ var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new
+ {
+ input,
+ pattern,
+ rewrite,
+ replace_global
+ });
+
+ return _op.output;
+ }
+
///
/// Return substrings from `Tensor` of strings.
///
diff --git a/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs
new file mode 100644
index 00000000..11adfe9f
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Engine
+{
+ public class CombinerPreprocessingLayer : Layer
+ {
+ PreprocessingLayerArgs args;
+
+ public CombinerPreprocessingLayer(PreprocessingLayerArgs args)
+ : base(args)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index 478e0e8b..98fd4741 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
dataset = slice_inputs(indices_dataset, inputs);
}
- Tensor permutation(Tensor tensor)
+ Tensors permutation(Tensors tensor)
{
var indices = math_ops.range(num_samples, dtype: dtypes.int64);
if (args.Shuffle)
@@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
.Select(x => gen_array_ops.gather_v2(x, indices, 0))
.ToArray();
return new Tensors(results);
- });
+ }, -1);
return dataset.with_options(new DatasetOptions { });
}
diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
index f0abc29f..3870c29b 100644
--- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
@@ -62,8 +62,8 @@ namespace Tensorflow.Keras.Engine
{
var y_t_rank = y_t.rank;
var y_p_rank = y_p.rank;
- var y_t_last_dim = y_t.shape[^1];
- var y_p_last_dim = y_p.shape[^1];
+ var y_t_last_dim = y_t.shape[y_t.shape.Length - 1];
+ var y_p_last_dim = y_p.shape[y_p.shape.Length - 1];
bool is_binary = y_p_last_dim == 1;
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;
diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs
index 58cc73f3..d06810f5 100644
--- a/src/TensorFlowNET.Keras/Engine/Sequential.cs
+++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using System.Linq;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
@@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Engine
if (set_inputs)
{
// If an input layer (placeholder) is available.
- outputs = layer.InboundNodes[^1].Outputs;
+ outputs = layer.InboundNodes.Last().Outputs;
inputs = layer_utils.get_source_inputs(outputs[0]);
built = true;
_has_explicit_input_shape = true;
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 50f80b6d..b5209e76 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -11,6 +11,7 @@ using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving;
+using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras
{
@@ -27,6 +28,7 @@ namespace Tensorflow.Keras
public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();
public ModelsApi models { get; } = new ModelsApi();
+ public KerasUtils utils { get; } = new KerasUtils();
public Sequential Sequential(List layers = null,
string name = null)
@@ -73,7 +75,7 @@ namespace Tensorflow.Keras
Tensor tensor = null)
{
if (batch_input_shape != null)
- shape = batch_input_shape.dims[1..];
+ shape = batch_input_shape.dims.Skip(1).ToArray();
var args = new InputLayerArgs
{
diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
index b1e2844c..55e20166 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Layers
if (BatchInputShape != null)
{
args.BatchSize = BatchInputShape.dims[0];
- args.InputShape = BatchInputShape.dims[1..];
+ args.InputShape = BatchInputShape.dims.Skip(1).ToArray();
}
// moved to base class
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index dc9ee749..e735f81e 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -9,6 +9,8 @@ namespace Tensorflow.Keras.Layers
{
public partial class LayersApi
{
+ public Preprocessing preprocessing { get; } = new Preprocessing();
+
///
/// Functional interface for the batch normalization layer.
/// http://arxiv.org/abs/1502.03167
diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
new file mode 100644
index 00000000..a66be94b
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class TextVectorization : CombinerPreprocessingLayer
+ {
+ TextVectorizationArgs args;
+
+ public TextVectorization(TextVectorizationArgs args)
+ : base(args)
+ {
+ args.DType = TF_DataType.TF_STRING;
+ // string standardize = "lower_and_strip_punctuation",
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
index f376c7d5..1b59ca82 100644
--- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
@@ -1,4 +1,5 @@
using System;
+using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
@@ -45,7 +46,7 @@ namespace Tensorflow.Keras.Layers
return array_ops.reshape(inputs, new[] { batch_dim, -1 });
}
- var non_batch_dims = ((int[])input_shape)[1..];
+ var non_batch_dims = ((int[])input_shape).Skip(1).ToArray();
var num = 1;
if (non_batch_dims.Length > 0)
{
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
index e8f7d01c..ecabc8f1 100644
--- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
@@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Layers
public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
- if (input_shape.dims[1..].Contains(-1))
+ if (input_shape.dims.Skip(1).Contains(-1))
{
throw new NotImplementedException("");
}
diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs
index 4e089fb6..a80e960a 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs
@@ -1,4 +1,5 @@
using System;
+using System.Linq;
namespace Tensorflow.Keras.Preprocessings
{
@@ -17,18 +18,21 @@ namespace Tensorflow.Keras.Preprocessings
float validation_split,
string subset)
{
+ if (string.IsNullOrEmpty(subset))
+ return (samples, labels);
+
var num_val_samples = Convert.ToInt32(samples.Length * validation_split);
if (subset == "training")
{
Console.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
- samples = samples[..^num_val_samples];
- labels = labels[..^num_val_samples];
+ samples = samples.Take(samples.Length - num_val_samples).ToArray();
+ labels = labels.Take(labels.Length - num_val_samples).ToArray();
}
else if (subset == "validation")
{
Console.WriteLine($"Using {num_val_samples} files for validation.");
- samples = samples[(samples.Length - num_val_samples)..];
- labels = labels[(labels.Length - num_val_samples)..];
+ samples = samples.Skip(samples.Length - num_val_samples).ToArray();
+ labels = labels.Skip(labels.Length - num_val_samples).ToArray();
}
else
throw new NotImplementedException("");
diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
index cf7ef12c..6b62b9b2 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
@@ -21,40 +21,41 @@ namespace Tensorflow.Keras.Preprocessings
/// file_paths, labels, class_names
///
public (string[], int[], string[]) index_directory(string directory,
+ string labels,
string[] formats = null,
string[] class_names = null,
bool shuffle = true,
int? seed = null,
bool follow_links = false)
{
- var labels = new List();
+ var label_list = new List();
var file_paths = new List();
var class_dirs = Directory.GetDirectories(directory);
- class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray();
+ class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar).Last()).ToArray();
for (var label = 0; label < class_dirs.Length; label++)
{
var files = Directory.GetFiles(class_dirs[label]);
file_paths.AddRange(files);
- labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
+ label_list.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
}
- var return_labels = labels.Select(x => x).ToArray();
+ var return_labels = label_list.Select(x => x).ToArray();
var return_file_paths = file_paths.Select(x => x).ToArray();
if (shuffle)
{
if (!seed.HasValue)
seed = np.random.randint((long)1e6);
- var random_index = np.arange(labels.Count);
+ var random_index = np.arange(label_list.Count);
var rng = np.random.RandomState(seed.Value);
rng.shuffle(random_index);
var index = random_index.ToArray();
- for (int i = 0; i < labels.Count; i++)
+ for (int i = 0; i < label_list.Count; i++)
{
- return_labels[i] = labels[index[i]];
+ return_labels[i] = label_list[index[i]];
return_file_paths[i] = file_paths[index[i]];
}
}
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
index 2d418509..6c33e9f5 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
@@ -1,4 +1,7 @@
-using Tensorflow.Keras.Preprocessings;
+using System;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Preprocessings;
namespace Tensorflow.Keras
{
@@ -6,5 +9,18 @@ namespace Tensorflow.Keras
{
public Sequence sequence => new Sequence();
public DatasetUtils dataset_utils => new DatasetUtils();
+
+ public TextVectorization TextVectorization(Func standardize = null,
+ string split = "standardize",
+ int max_tokens = -1,
+ string output_mode = "int",
+ int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs
+ {
+ Standardize = standardize,
+ Split = split,
+ MaxTokens = max_tokens,
+ OutputMode = output_mode,
+ OutputSequenceLength = output_sequence_length
+ });
}
}
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
index c17799ac..8d7513a6 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
@@ -43,6 +43,7 @@ namespace Tensorflow.Keras
num_channels = 3;
var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory,
+ labels,
formats: WHITELIST_FORMATS,
class_names: class_names,
shuffle: shuffle,
@@ -64,13 +65,30 @@ namespace Tensorflow.Keras
string[] class_names = null,
int batch_size = 32,
bool shuffle = true,
+ int max_length = -1,
int? seed = null,
float validation_split = 0.2f,
- string subset = null)
+ string subset = null,
+ bool follow_links = false)
{
-
+ var (file_paths, label_list, class_name_list) = dataset_utils.index_directory(
+ directory,
+ labels,
+ formats: new[] { ".txt" },
+ class_names: class_names,
+ shuffle: shuffle,
+ seed: seed,
+ follow_links: follow_links);
- return null;
+ (file_paths, label_list) = dataset_utils.get_training_or_validation_split(
+ file_paths, label_list, validation_split, subset);
+
+ var dataset = paths_and_labels_to_dataset(file_paths, label_list, label_mode, class_name_list.Length);
+ if (shuffle)
+ dataset = dataset.shuffle(batch_size * 8, seed: seed);
+ dataset = dataset.batch(batch_size);
+ dataset.class_names = class_name_list;
+ return dataset;
}
}
}
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
index abf07735..dba2cded 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
@@ -1,4 +1,5 @@
using System;
+using System.IO;
using static Tensorflow.Binding;
namespace Tensorflow.Keras
@@ -34,5 +35,31 @@ namespace Tensorflow.Keras
// img.set_shape((image_size[0], image_size[1], num_channels));
return img;
}
+
+ public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths,
+ int[] labels,
+ string label_mode,
+ int num_classes,
+ int max_length = -1)
+ {
+ var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
+ var string_ds = path_ds.map(x => path_to_string_content(x, max_length));
+
+ if (label_mode == "int")
+ {
+ var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
+ string_ds = tf.data.Dataset.zip(string_ds, label_ds);
+ }
+
+ return string_ds;
+ }
+
+ Tensor path_to_string_content(Tensor path, int max_length)
+ {
+ var txt = tf.io.read_file(path);
+ if (max_length > -1)
+ txt = tf.strings.substr(txt, 0, max_length);
+ return txt;
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 5694e8f5..9eeb4634 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -6,7 +6,7 @@
8.0
Tensorflow.Keras
AnyCPU;x64
- 0.4.0
+ 0.4.1
Haiping Chen
Keras for .NET
Apache 2.0, Haiping Chen 2020
@@ -34,8 +34,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
Git
true
Open.snk
- 0.4.0.0
- 0.4.0.0
+ 0.4.1.0
+ 0.4.1.0
LICENSE
diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
index 1160fa4f..60955e68 100644
--- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
+++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp3.1
+ net5.0
AnyCPU;x64
diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md
index a08959a7..ae04c398 100644
--- a/tensorflowlib/README.md
+++ b/tensorflowlib/README.md
@@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\
1. Build static library
-`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow`
+`bazel build --output_base=C:/tmp/tfcompilation --config=opt //tensorflow:tensorflow`
2. Build pip package
@@ -70,6 +70,16 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\
`pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl`
+### Build from source for MacOS
+
+```shell
+$ cd /usr/local/lib/bazel/bin
+$ curl -LO https://release.bazel.build/3.7.2/release/bazel-3.7.2-darwin-x86_64
+$ chmod +x bazel-3.7.2-darwin-x86_64
+$ cd ~/Projects/tensorflow
+$ bazel build --config=opt //tensorflow:tensorflow
+```
+
### Build specific version for tf.net
https://github.com/SciSharp/tensorflow