Browse Source

Consolidate MapDataset function.

tags/v0.40-tf2.4-tstring
Oceania2018 4 years ago
parent
commit
4ef675faf9
29 changed files with 258 additions and 38 deletions
  1. +1
    -1
      src/TensorFlowNET.Console/Tensorflow.Console.csproj
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/c_api.cs
  3. +24
    -0
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  4. +3
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  5. +4
    -2
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Data/MapDataset.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Gradients/image_grad.cs
  9. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
  10. +15
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
  11. +48
    -0
      src/TensorFlowNET.Core/Operations/string_ops.cs
  12. +18
    -0
      src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs
  13. +2
    -2
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  14. +2
    -2
      src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
  15. +2
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  16. +3
    -1
      src/TensorFlowNET.Keras/KerasInterface.cs
  17. +1
    -1
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  18. +2
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  19. +20
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
  20. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
  21. +1
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
  22. +8
    -4
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs
  23. +8
    -7
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
  24. +17
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
  25. +21
    -3
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  26. +27
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
  27. +3
    -3
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  28. +1
    -1
      src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
  29. +11
    -1
      tensorflowlib/README.md

+ 1
- 1
src/TensorFlowNET.Console/Tensorflow.Console.csproj View File

@@ -2,7 +2,7 @@


<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net5.0</TargetFramework>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<AssemblyName>Tensorflow</AssemblyName> <AssemblyName>Tensorflow</AssemblyName>
<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>


+ 1
- 1
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public partial class c_api public partial class c_api
{ {
public const string TensorFlowLibName = @"D:\Projects\tensorflow-haiping\bazel-bin\tensorflow\tensorflow";
public const string TensorFlowLibName = "tensorflow";


public static string StringPiece(IntPtr handle) public static string StringPiece(IntPtr handle)
{ {


+ 24
- 0
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -24,6 +24,30 @@ namespace Tensorflow
{ {
string_ops ops = new string_ops(); string_ops ops = new string_ops();


/// <summary>
/// Converts all uppercase characters into their respective lowercase replacements.
/// </summary>
/// <param name="input"></param>
/// <param name="encoding"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor lower(Tensor input, string encoding = "", string name = null)
=> ops.lower(input: input, encoding: encoding, name: name);

/// <summary>
///
/// </summary>
/// <param name="input"></param>
/// <param name="pattern"></param>
/// <param name="rewrite"></param>
/// <param name="replace_global"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor regex_replace(Tensor input, string pattern, string rewrite,
bool replace_global = true, string name = null)
=> ops.regex_replace(input, pattern, rewrite,
replace_global: replace_global, name: name);

/// <summary> /// <summary>
/// Return substrings from `Tensor` of strings. /// Return substrings from `Tensor` of strings.
/// </summary> /// </summary>


+ 3
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -14,6 +14,7 @@ namespace Tensorflow
public class DatasetV2 : IDatasetV2 public class DatasetV2 : IDatasetV2
{ {
protected dataset_ops ops = new dataset_ops(); protected dataset_ops ops = new dataset_ops();
public string[] class_names { get; set; }
public Tensor variant_tensor { get; set; } public Tensor variant_tensor { get; set; }


public TensorSpec[] structure { get; set; } public TensorSpec[] structure { get; set; }
@@ -54,7 +55,7 @@ namespace Tensorflow
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);


public IDatasetV2 map(Func<Tensor, Tensor> map_func,
public IDatasetV2 map(Func<Tensors, Tensors> map_func,
bool use_inter_op_parallelism = true, bool use_inter_op_parallelism = true,
bool preserve_cardinality = true, bool preserve_cardinality = true,
bool use_legacy_function = false) bool use_legacy_function = false)
@@ -64,7 +65,7 @@ namespace Tensorflow
preserve_cardinality: preserve_cardinality, preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function); use_legacy_function: use_legacy_function);


public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1)
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);


public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)


+ 4
- 2
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -6,6 +6,8 @@ namespace Tensorflow
{ {
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
{ {
string[] class_names { get; set; }

Tensor variant_tensor { get; set; } Tensor variant_tensor { get; set; }


TensorShape[] output_shapes { get; } TensorShape[] output_shapes { get; }
@@ -62,13 +64,13 @@ namespace Tensorflow


IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);


IDatasetV2 map(Func<Tensor, Tensor> map_func,
IDatasetV2 map(Func<Tensors, Tensors> map_func,
bool use_inter_op_parallelism = true, bool use_inter_op_parallelism = true,
bool preserve_cardinality = true, bool preserve_cardinality = true,
bool use_legacy_function = false); bool use_legacy_function = false);


IDatasetV2 map(Func<Tensors, Tensors> map_func, IDatasetV2 map(Func<Tensors, Tensors> map_func,
int num_parallel_calls = -1);
int num_parallel_calls);


IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);




+ 1
- 1
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow
public class MapDataset : UnaryDataset public class MapDataset : UnaryDataset
{ {
public MapDataset(IDatasetV2 input_dataset, public MapDataset(IDatasetV2 input_dataset,
Func<Tensor, Tensor> map_func,
Func<Tensors, Tensors> map_func,
bool use_inter_op_parallelism = true, bool use_inter_op_parallelism = true,
bool preserve_cardinality = false, bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset) bool use_legacy_function = false) : base(input_dataset)


+ 1
- 1
src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow.Framework.Models
if (_shape.ndim == 0) if (_shape.ndim == 0)
throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); throw new ValueError("Unbatching a tensor is only supported for rank >= 1");


return new TensorSpec(_shape.dims[1..], _dtype);
return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype);
} }


public TensorSpec _batch(int dim = -1) public TensorSpec _batch(int dim = -1)


+ 1
- 1
src/TensorFlowNET.Core/Gradients/image_grad.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow.Gradients
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
Tensor image_shape = null; Tensor image_shape = null;
if (shape.is_fully_defined()) if (shape.is_fully_defined())
image_shape = constant_op.constant(image.shape[1..3]);
image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray());
else else
image_shape = array_ops.shape(image)["1:3"]; image_shape = array_ops.shape(image)["1:3"];




+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class PreprocessingLayerArgs : LayerArgs
{
}
}

+ 15
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class TextVectorizationArgs : PreprocessingLayerArgs
{
public Func<Tensor, Tensor> Standardize { get; set; }
public string Split { get; set; } = "standardize";
public int MaxTokens { get; set; } = -1;
public string OutputMode { get; set; } = "int";
public int OutputSequenceLength { get; set; } = -1;
}
}

+ 48
- 0
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -20,6 +20,54 @@ namespace Tensorflow
{ {
public class string_ops public class string_ops
{ {
public Tensor lower(Tensor input, string encoding = "", string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StringLower", name,
null,
input, encoding);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new
{
input,
encoding
});

return _op.output;
}

public Tensor regex_replace(Tensor input, string pattern, string rewrite,
bool replace_global = true, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StaticRegexReplace", name,
null,
input,
"pattern", pattern,
"rewrite", rewrite,
"replace_global", replace_global);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new
{
input,
pattern,
rewrite,
replace_global
});

return _op.output;
}

/// <summary> /// <summary>
/// Return substrings from `Tensor` of strings. /// Return substrings from `Tensor` of strings.
/// </summary> /// </summary>


+ 18
- 0
src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine
{
public class CombinerPreprocessingLayer : Layer
{
PreprocessingLayerArgs args;

public CombinerPreprocessingLayer(PreprocessingLayerArgs args)
: base(args)
{
}
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
dataset = slice_inputs(indices_dataset, inputs); dataset = slice_inputs(indices_dataset, inputs);
} }


Tensor permutation(Tensor tensor)
Tensors permutation(Tensors tensor)
{ {
var indices = math_ops.range(num_samples, dtype: dtypes.int64); var indices = math_ops.range(num_samples, dtype: dtypes.int64);
if (args.Shuffle) if (args.Shuffle)
@@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
.Select(x => gen_array_ops.gather_v2(x, indices, 0)) .Select(x => gen_array_ops.gather_v2(x, indices, 0))
.ToArray(); .ToArray();
return new Tensors(results); return new Tensors(results);
});
}, -1);


return dataset.with_options(new DatasetOptions { }); return dataset.with_options(new DatasetOptions { });
} }


+ 2
- 2
src/TensorFlowNET.Keras/Engine/MetricsContainer.cs View File

@@ -62,8 +62,8 @@ namespace Tensorflow.Keras.Engine
{ {
var y_t_rank = y_t.rank; var y_t_rank = y_t.rank;
var y_p_rank = y_p.rank; var y_p_rank = y_p.rank;
var y_t_last_dim = y_t.shape[^1];
var y_p_last_dim = y_p.shape[^1];
var y_t_last_dim = y_t.shape[y_t.shape.Length - 1];
var y_p_last_dim = y_p.shape[y_p.shape.Length - 1];


bool is_binary = y_p_last_dim == 1; bool is_binary = y_p_last_dim == 1;
bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;


+ 2
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Linq;
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
@@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Engine
if (set_inputs) if (set_inputs)
{ {
// If an input layer (placeholder) is available. // If an input layer (placeholder) is available.
outputs = layer.InboundNodes[^1].Outputs;
outputs = layer.InboundNodes.Last().Outputs;
inputs = layer_utils.get_source_inputs(outputs[0]); inputs = layer_utils.get_source_inputs(outputs[0]);
built = true; built = true;
_has_explicit_input_shape = true; _has_explicit_input_shape = true;


+ 3
- 1
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -11,6 +11,7 @@ using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models; using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;


namespace Tensorflow.Keras namespace Tensorflow.Keras
{ {
@@ -27,6 +28,7 @@ namespace Tensorflow.Keras
public OptimizerApi optimizers { get; } = new OptimizerApi(); public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi(); public MetricsApi metrics { get; } = new MetricsApi();
public ModelsApi models { get; } = new ModelsApi(); public ModelsApi models { get; } = new ModelsApi();
public KerasUtils utils { get; } = new KerasUtils();


public Sequential Sequential(List<ILayer> layers = null, public Sequential Sequential(List<ILayer> layers = null,
string name = null) string name = null)
@@ -73,7 +75,7 @@ namespace Tensorflow.Keras
Tensor tensor = null) Tensor tensor = null)
{ {
if (batch_input_shape != null) if (batch_input_shape != null)
shape = batch_input_shape.dims[1..];
shape = batch_input_shape.dims.Skip(1).ToArray();


var args = new InputLayerArgs var args = new InputLayerArgs
{ {


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Layers
if (BatchInputShape != null) if (BatchInputShape != null)
{ {
args.BatchSize = BatchInputShape.dims[0]; args.BatchSize = BatchInputShape.dims[0];
args.InputShape = BatchInputShape.dims[1..];
args.InputShape = BatchInputShape.dims.Skip(1).ToArray();
} }


// moved to base class // moved to base class


+ 2
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -9,6 +9,8 @@ namespace Tensorflow.Keras.Layers
{ {
public partial class LayersApi public partial class LayersApi
{ {
public Preprocessing preprocessing { get; } = new Preprocessing();

/// <summary> /// <summary>
/// Functional interface for the batch normalization layer. /// Functional interface for the batch normalization layer.
/// http://arxiv.org/abs/1502.03167 /// http://arxiv.org/abs/1502.03167


+ 20
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class TextVectorization : CombinerPreprocessingLayer
{
TextVectorizationArgs args;

public TextVectorization(TextVectorizationArgs args)
: base(args)
{
args.DType = TF_DataType.TF_STRING;
// string standardize = "lower_and_strip_punctuation",
}
}
}

+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Linq;
using Tensorflow.Framework; using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
@@ -45,7 +46,7 @@ namespace Tensorflow.Keras.Layers
return array_ops.reshape(inputs, new[] { batch_dim, -1 }); return array_ops.reshape(inputs, new[] { batch_dim, -1 });
} }


var non_batch_dims = ((int[])input_shape)[1..];
var non_batch_dims = ((int[])input_shape).Skip(1).ToArray();
var num = 1; var num = 1;
if (non_batch_dims.Length > 0) if (non_batch_dims.Length > 0)
{ {


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Layers


public override TensorShape ComputeOutputShape(TensorShape input_shape) public override TensorShape ComputeOutputShape(TensorShape input_shape)
{ {
if (input_shape.dims[1..].Contains(-1))
if (input_shape.dims.Skip(1).Contains(-1))
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


+ 8
- 4
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Linq;


namespace Tensorflow.Keras.Preprocessings namespace Tensorflow.Keras.Preprocessings
{ {
@@ -17,18 +18,21 @@ namespace Tensorflow.Keras.Preprocessings
float validation_split, float validation_split,
string subset) string subset)
{ {
if (string.IsNullOrEmpty(subset))
return (samples, labels);

var num_val_samples = Convert.ToInt32(samples.Length * validation_split); var num_val_samples = Convert.ToInt32(samples.Length * validation_split);
if (subset == "training") if (subset == "training")
{ {
Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); Console.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
samples = samples[..^num_val_samples];
labels = labels[..^num_val_samples];
samples = samples.Take(samples.Length - num_val_samples).ToArray();
labels = labels.Take(labels.Length - num_val_samples).ToArray();
} }
else if (subset == "validation") else if (subset == "validation")
{ {
Console.WriteLine($"Using {num_val_samples} files for validation."); Console.WriteLine($"Using {num_val_samples} files for validation.");
samples = samples[(samples.Length - num_val_samples)..];
labels = labels[(labels.Length - num_val_samples)..];
samples = samples.Skip(samples.Length - num_val_samples).ToArray();
labels = labels.Skip(labels.Length - num_val_samples).ToArray();
} }
else else
throw new NotImplementedException(""); throw new NotImplementedException("");


+ 8
- 7
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs View File

@@ -21,40 +21,41 @@ namespace Tensorflow.Keras.Preprocessings
/// file_paths, labels, class_names /// file_paths, labels, class_names
/// </returns> /// </returns>
public (string[], int[], string[]) index_directory(string directory, public (string[], int[], string[]) index_directory(string directory,
string labels,
string[] formats = null, string[] formats = null,
string[] class_names = null, string[] class_names = null,
bool shuffle = true, bool shuffle = true,
int? seed = null, int? seed = null,
bool follow_links = false) bool follow_links = false)
{ {
var labels = new List<int>();
var label_list = new List<int>();
var file_paths = new List<string>(); var file_paths = new List<string>();


var class_dirs = Directory.GetDirectories(directory); var class_dirs = Directory.GetDirectories(directory);
class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray();
class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar).Last()).ToArray();


for (var label = 0; label < class_dirs.Length; label++) for (var label = 0; label < class_dirs.Length; label++)
{ {
var files = Directory.GetFiles(class_dirs[label]); var files = Directory.GetFiles(class_dirs[label]);
file_paths.AddRange(files); file_paths.AddRange(files);
labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
label_list.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
} }


var return_labels = labels.Select(x => x).ToArray();
var return_labels = label_list.Select(x => x).ToArray();
var return_file_paths = file_paths.Select(x => x).ToArray(); var return_file_paths = file_paths.Select(x => x).ToArray();


if (shuffle) if (shuffle)
{ {
if (!seed.HasValue) if (!seed.HasValue)
seed = np.random.randint((long)1e6); seed = np.random.randint((long)1e6);
var random_index = np.arange(labels.Count);
var random_index = np.arange(label_list.Count);
var rng = np.random.RandomState(seed.Value); var rng = np.random.RandomState(seed.Value);
rng.shuffle(random_index); rng.shuffle(random_index);
var index = random_index.ToArray<int>(); var index = random_index.ToArray<int>();


for (int i = 0; i < labels.Count; i++)
for (int i = 0; i < label_list.Count; i++)
{ {
return_labels[i] = labels[index[i]];
return_labels[i] = label_list[index[i]];
return_file_paths[i] = file_paths[index[i]]; return_file_paths[i] = file_paths[index[i]];
} }
} }


+ 17
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs View File

@@ -1,4 +1,7 @@
using Tensorflow.Keras.Preprocessings;
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Preprocessings;


namespace Tensorflow.Keras namespace Tensorflow.Keras
{ {
@@ -6,5 +9,18 @@ namespace Tensorflow.Keras
{ {
public Sequence sequence => new Sequence(); public Sequence sequence => new Sequence();
public DatasetUtils dataset_utils => new DatasetUtils(); public DatasetUtils dataset_utils => new DatasetUtils();

public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null,
string split = "standardize",
int max_tokens = -1,
string output_mode = "int",
int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs
{
Standardize = standardize,
Split = split,
MaxTokens = max_tokens,
OutputMode = output_mode,
OutputSequenceLength = output_sequence_length
});
} }
} }

+ 21
- 3
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -43,6 +43,7 @@ namespace Tensorflow.Keras
num_channels = 3; num_channels = 3;
var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory,
labels,
formats: WHITELIST_FORMATS, formats: WHITELIST_FORMATS,
class_names: class_names, class_names: class_names,
shuffle: shuffle, shuffle: shuffle,
@@ -64,13 +65,30 @@ namespace Tensorflow.Keras
string[] class_names = null, string[] class_names = null,
int batch_size = 32, int batch_size = 32,
bool shuffle = true, bool shuffle = true,
int max_length = -1,
int? seed = null, int? seed = null,
float validation_split = 0.2f, float validation_split = 0.2f,
string subset = null)
string subset = null,
bool follow_links = false)
{ {
var (file_paths, label_list, class_name_list) = dataset_utils.index_directory(
directory,
labels,
formats: new[] { ".txt" },
class_names: class_names,
shuffle: shuffle,
seed: seed,
follow_links: follow_links);


return null;
(file_paths, label_list) = dataset_utils.get_training_or_validation_split(
file_paths, label_list, validation_split, subset);

var dataset = paths_and_labels_to_dataset(file_paths, label_list, label_mode, class_name_list.Length);
if (shuffle)
dataset = dataset.shuffle(batch_size * 8, seed: seed);
dataset = dataset.batch(batch_size);
dataset.class_names = class_name_list;
return dataset;
} }
} }
} }

+ 27
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.IO;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras namespace Tensorflow.Keras
@@ -34,5 +35,31 @@ namespace Tensorflow.Keras
// img.set_shape((image_size[0], image_size[1], num_channels)); // img.set_shape((image_size[0], image_size[1], num_channels));
return img; return img;
} }

public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths,
int[] labels,
string label_mode,
int num_classes,
int max_length = -1)
{
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
var string_ds = path_ds.map(x => path_to_string_content(x, max_length));

if (label_mode == "int")
{
var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
string_ds = tf.data.Dataset.zip(string_ds, label_ds);
}

return string_ds;
}

Tensor path_to_string_content(Tensor path, int max_length)
{
var txt = tf.io.read_file(path);
if (max_length > -1)
txt = tf.strings.substr(txt, 0, max_length);
return txt;
}
} }
} }

+ 3
- 3
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -6,7 +6,7 @@
<LangVersion>8.0</LangVersion> <LangVersion>8.0</LangVersion>
<RootNamespace>Tensorflow.Keras</RootNamespace> <RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>
<Version>0.4.0</Version>
<Version>0.4.1</Version>
<Authors>Haiping Chen</Authors> <Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product> <Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright> <Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
@@ -34,8 +34,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType> <RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.4.0.0</AssemblyVersion>
<FileVersion>0.4.0.0</FileVersion>
<AssemblyVersion>0.4.1.0</AssemblyVersion>
<FileVersion>0.4.1.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
</PropertyGroup> </PropertyGroup>




+ 1
- 1
src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj View File

@@ -2,7 +2,7 @@


<PropertyGroup> <PropertyGroup>
<OutputType>Exe</OutputType> <OutputType>Exe</OutputType>
<TargetFramework>netcoreapp3.1</TargetFramework>
<TargetFramework>net5.0</TargetFramework>
<Platforms>AnyCPU;x64</Platforms> <Platforms>AnyCPU;x64</Platforms>
</PropertyGroup> </PropertyGroup>




+ 11
- 1
tensorflowlib/README.md View File

@@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\


1. Build static library 1. Build static library


`bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow`
`bazel build --output_base=C:/tmp/tfcompilation --config=opt //tensorflow:tensorflow`


2. Build pip package 2. Build pip package


@@ -70,6 +70,16 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\


`pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl`


### Build from source for MacOS

```shell
$ cd /usr/local/lib/bazel/bin
$ curl -LO https://release.bazel.build/3.7.2/release/bazel-3.7.2-darwin-x86_64
$ chmod +x bazel-3.7.2-darwin-x86_64
$ cd ~/Projects/tensorflow
$ bazel build --config=opt //tensorflow:tensorflow
```

### Build specific version for tf.net ### Build specific version for tf.net


https://github.com/SciSharp/tensorflow https://github.com/SciSharp/tensorflow


Loading…
Cancel
Save