| @@ -2,7 +2,7 @@ | |||||
| <PropertyGroup> | <PropertyGroup> | ||||
| <OutputType>Exe</OutputType> | <OutputType>Exe</OutputType> | ||||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||||
| <TargetFramework>net5.0</TargetFramework> | |||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <AssemblyName>Tensorflow</AssemblyName> | <AssemblyName>Tensorflow</AssemblyName> | ||||
| <Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
| @@ -43,7 +43,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public partial class c_api | public partial class c_api | ||||
| { | { | ||||
| public const string TensorFlowLibName = @"D:\Projects\tensorflow-haiping\bazel-bin\tensorflow\tensorflow"; | |||||
| public const string TensorFlowLibName = "tensorflow"; | |||||
| public static string StringPiece(IntPtr handle) | public static string StringPiece(IntPtr handle) | ||||
| { | { | ||||
| @@ -24,6 +24,30 @@ namespace Tensorflow | |||||
| { | { | ||||
| string_ops ops = new string_ops(); | string_ops ops = new string_ops(); | ||||
| /// <summary> | |||||
| /// Converts all uppercase characters into their respective lowercase replacements. | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="encoding"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor lower(Tensor input, string encoding = "", string name = null) | |||||
| => ops.lower(input: input, encoding: encoding, name: name); | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="pattern"></param> | |||||
| /// <param name="rewrite"></param> | |||||
| /// <param name="replace_global"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor regex_replace(Tensor input, string pattern, string rewrite, | |||||
| bool replace_global = true, string name = null) | |||||
| => ops.regex_replace(input, pattern, rewrite, | |||||
| replace_global: replace_global, name: name); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return substrings from `Tensor` of strings. | /// Return substrings from `Tensor` of strings. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -14,6 +14,7 @@ namespace Tensorflow | |||||
| public class DatasetV2 : IDatasetV2 | public class DatasetV2 : IDatasetV2 | ||||
| { | { | ||||
| protected dataset_ops ops = new dataset_ops(); | protected dataset_ops ops = new dataset_ops(); | ||||
| public string[] class_names { get; set; } | |||||
| public Tensor variant_tensor { get; set; } | public Tensor variant_tensor { get; set; } | ||||
| public TensorSpec[] structure { get; set; } | public TensorSpec[] structure { get; set; } | ||||
| @@ -54,7 +55,7 @@ namespace Tensorflow | |||||
| public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | ||||
| => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | ||||
| public IDatasetV2 map(Func<Tensor, Tensor> map_func, | |||||
| public IDatasetV2 map(Func<Tensors, Tensors> map_func, | |||||
| bool use_inter_op_parallelism = true, | bool use_inter_op_parallelism = true, | ||||
| bool preserve_cardinality = true, | bool preserve_cardinality = true, | ||||
| bool use_legacy_function = false) | bool use_legacy_function = false) | ||||
| @@ -64,7 +65,7 @@ namespace Tensorflow | |||||
| preserve_cardinality: preserve_cardinality, | preserve_cardinality: preserve_cardinality, | ||||
| use_legacy_function: use_legacy_function); | use_legacy_function: use_legacy_function); | ||||
| public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1) | |||||
| public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls) | |||||
| => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); | => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); | ||||
| public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) | public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) | ||||
| @@ -6,6 +6,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> | public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> | ||||
| { | { | ||||
| string[] class_names { get; set; } | |||||
| Tensor variant_tensor { get; set; } | Tensor variant_tensor { get; set; } | ||||
| TensorShape[] output_shapes { get; } | TensorShape[] output_shapes { get; } | ||||
| @@ -62,13 +64,13 @@ namespace Tensorflow | |||||
| IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); | IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); | ||||
| IDatasetV2 map(Func<Tensor, Tensor> map_func, | |||||
| IDatasetV2 map(Func<Tensors, Tensors> map_func, | |||||
| bool use_inter_op_parallelism = true, | bool use_inter_op_parallelism = true, | ||||
| bool preserve_cardinality = true, | bool preserve_cardinality = true, | ||||
| bool use_legacy_function = false); | bool use_legacy_function = false); | ||||
| IDatasetV2 map(Func<Tensors, Tensors> map_func, | IDatasetV2 map(Func<Tensors, Tensors> map_func, | ||||
| int num_parallel_calls = -1); | |||||
| int num_parallel_calls); | |||||
| IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||||
| public class MapDataset : UnaryDataset | public class MapDataset : UnaryDataset | ||||
| { | { | ||||
| public MapDataset(IDatasetV2 input_dataset, | public MapDataset(IDatasetV2 input_dataset, | ||||
| Func<Tensor, Tensor> map_func, | |||||
| Func<Tensors, Tensors> map_func, | |||||
| bool use_inter_op_parallelism = true, | bool use_inter_op_parallelism = true, | ||||
| bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
| bool use_legacy_function = false) : base(input_dataset) | bool use_legacy_function = false) : base(input_dataset) | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow.Framework.Models | |||||
| if (_shape.ndim == 0) | if (_shape.ndim == 0) | ||||
| throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); | throw new ValueError("Unbatching a tensor is only supported for rank >= 1"); | ||||
| return new TensorSpec(_shape.dims[1..], _dtype); | |||||
| return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype); | |||||
| } | } | ||||
| public TensorSpec _batch(int dim = -1) | public TensorSpec _batch(int dim = -1) | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow.Gradients | |||||
| var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray()); | ||||
| Tensor image_shape = null; | Tensor image_shape = null; | ||||
| if (shape.is_fully_defined()) | if (shape.is_fully_defined()) | ||||
| image_shape = constant_op.constant(image.shape[1..3]); | |||||
| image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray()); | |||||
| else | else | ||||
| image_shape = array_ops.shape(image)["1:3"]; | image_shape = array_ops.shape(image)["1:3"]; | ||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class PreprocessingLayerArgs : LayerArgs | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class TextVectorizationArgs : PreprocessingLayerArgs | |||||
| { | |||||
| public Func<Tensor, Tensor> Standardize { get; set; } | |||||
| public string Split { get; set; } = "standardize"; | |||||
| public int MaxTokens { get; set; } = -1; | |||||
| public string OutputMode { get; set; } = "int"; | |||||
| public int OutputSequenceLength { get; set; } = -1; | |||||
| } | |||||
| } | |||||
| @@ -20,6 +20,54 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class string_ops | public class string_ops | ||||
| { | { | ||||
| public Tensor lower(Tensor input, string encoding = "", string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "StringLower", name, | |||||
| null, | |||||
| input, encoding); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new | |||||
| { | |||||
| input, | |||||
| encoding | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| public Tensor regex_replace(Tensor input, string pattern, string rewrite, | |||||
| bool replace_global = true, string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "StaticRegexReplace", name, | |||||
| null, | |||||
| input, | |||||
| "pattern", pattern, | |||||
| "rewrite", rewrite, | |||||
| "replace_global", replace_global); | |||||
| return results[0]; | |||||
| } | |||||
| var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new | |||||
| { | |||||
| input, | |||||
| pattern, | |||||
| rewrite, | |||||
| replace_global | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Return substrings from `Tensor` of strings. | /// Return substrings from `Tensor` of strings. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| namespace Tensorflow.Keras.Engine | |||||
| { | |||||
| public class CombinerPreprocessingLayer : Layer | |||||
| { | |||||
| PreprocessingLayerArgs args; | |||||
| public CombinerPreprocessingLayer(PreprocessingLayerArgs args) | |||||
| : base(args) | |||||
| { | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -39,7 +39,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| dataset = slice_inputs(indices_dataset, inputs); | dataset = slice_inputs(indices_dataset, inputs); | ||||
| } | } | ||||
| Tensor permutation(Tensor tensor) | |||||
| Tensors permutation(Tensors tensor) | |||||
| { | { | ||||
| var indices = math_ops.range(num_samples, dtype: dtypes.int64); | var indices = math_ops.range(num_samples, dtype: dtypes.int64); | ||||
| if (args.Shuffle) | if (args.Shuffle) | ||||
| @@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| .Select(x => gen_array_ops.gather_v2(x, indices, 0)) | .Select(x => gen_array_ops.gather_v2(x, indices, 0)) | ||||
| .ToArray(); | .ToArray(); | ||||
| return new Tensors(results); | return new Tensors(results); | ||||
| }); | |||||
| }, -1); | |||||
| return dataset.with_options(new DatasetOptions { }); | return dataset.with_options(new DatasetOptions { }); | ||||
| } | } | ||||
| @@ -62,8 +62,8 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| var y_t_rank = y_t.rank; | var y_t_rank = y_t.rank; | ||||
| var y_p_rank = y_p.rank; | var y_p_rank = y_p.rank; | ||||
| var y_t_last_dim = y_t.shape[^1]; | |||||
| var y_p_last_dim = y_p.shape[^1]; | |||||
| var y_t_last_dim = y_t.shape[y_t.shape.Length - 1]; | |||||
| var y_p_last_dim = y_p.shape[y_p.shape.Length - 1]; | |||||
| bool is_binary = y_p_last_dim == 1; | bool is_binary = y_p_last_dim == 1; | ||||
| bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; | bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1; | ||||
| @@ -14,6 +14,7 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Linq; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| @@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (set_inputs) | if (set_inputs) | ||||
| { | { | ||||
| // If an input layer (placeholder) is available. | // If an input layer (placeholder) is available. | ||||
| outputs = layer.InboundNodes[^1].Outputs; | |||||
| outputs = layer.InboundNodes.Last().Outputs; | |||||
| inputs = layer_utils.get_source_inputs(outputs[0]); | inputs = layer_utils.get_source_inputs(outputs[0]); | ||||
| built = true; | built = true; | ||||
| _has_explicit_input_shape = true; | _has_explicit_input_shape = true; | ||||
| @@ -11,6 +11,7 @@ using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow.Keras.Models; | using Tensorflow.Keras.Models; | ||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Keras.Utils; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| @@ -27,6 +28,7 @@ namespace Tensorflow.Keras | |||||
| public OptimizerApi optimizers { get; } = new OptimizerApi(); | public OptimizerApi optimizers { get; } = new OptimizerApi(); | ||||
| public MetricsApi metrics { get; } = new MetricsApi(); | public MetricsApi metrics { get; } = new MetricsApi(); | ||||
| public ModelsApi models { get; } = new ModelsApi(); | public ModelsApi models { get; } = new ModelsApi(); | ||||
| public KerasUtils utils { get; } = new KerasUtils(); | |||||
| public Sequential Sequential(List<ILayer> layers = null, | public Sequential Sequential(List<ILayer> layers = null, | ||||
| string name = null) | string name = null) | ||||
| @@ -73,7 +75,7 @@ namespace Tensorflow.Keras | |||||
| Tensor tensor = null) | Tensor tensor = null) | ||||
| { | { | ||||
| if (batch_input_shape != null) | if (batch_input_shape != null) | ||||
| shape = batch_input_shape.dims[1..]; | |||||
| shape = batch_input_shape.dims.Skip(1).ToArray(); | |||||
| var args = new InputLayerArgs | var args = new InputLayerArgs | ||||
| { | { | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Layers | |||||
| if (BatchInputShape != null) | if (BatchInputShape != null) | ||||
| { | { | ||||
| args.BatchSize = BatchInputShape.dims[0]; | args.BatchSize = BatchInputShape.dims[0]; | ||||
| args.InputShape = BatchInputShape.dims[1..]; | |||||
| args.InputShape = BatchInputShape.dims.Skip(1).ToArray(); | |||||
| } | } | ||||
| // moved to base class | // moved to base class | ||||
| @@ -9,6 +9,8 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| public partial class LayersApi | public partial class LayersApi | ||||
| { | { | ||||
| public Preprocessing preprocessing { get; } = new Preprocessing(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Functional interface for the batch normalization layer. | /// Functional interface for the batch normalization layer. | ||||
| /// http://arxiv.org/abs/1502.03167 | /// http://arxiv.org/abs/1502.03167 | ||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Layers | |||||
| { | |||||
| public class TextVectorization : CombinerPreprocessingLayer | |||||
| { | |||||
| TextVectorizationArgs args; | |||||
| public TextVectorization(TextVectorizationArgs args) | |||||
| : base(args) | |||||
| { | |||||
| args.DType = TF_DataType.TF_STRING; | |||||
| // string standardize = "lower_and_strip_punctuation", | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | |||||
| using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| @@ -45,7 +46,7 @@ namespace Tensorflow.Keras.Layers | |||||
| return array_ops.reshape(inputs, new[] { batch_dim, -1 }); | return array_ops.reshape(inputs, new[] { batch_dim, -1 }); | ||||
| } | } | ||||
| var non_batch_dims = ((int[])input_shape)[1..]; | |||||
| var non_batch_dims = ((int[])input_shape).Skip(1).ToArray(); | |||||
| var num = 1; | var num = 1; | ||||
| if (non_batch_dims.Length > 0) | if (non_batch_dims.Length > 0) | ||||
| { | { | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Layers | |||||
| public override TensorShape ComputeOutputShape(TensorShape input_shape) | public override TensorShape ComputeOutputShape(TensorShape input_shape) | ||||
| { | { | ||||
| if (input_shape.dims[1..].Contains(-1)) | |||||
| if (input_shape.dims.Skip(1).Contains(-1)) | |||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Preprocessings | namespace Tensorflow.Keras.Preprocessings | ||||
| { | { | ||||
| @@ -17,18 +18,21 @@ namespace Tensorflow.Keras.Preprocessings | |||||
| float validation_split, | float validation_split, | ||||
| string subset) | string subset) | ||||
| { | { | ||||
| if (string.IsNullOrEmpty(subset)) | |||||
| return (samples, labels); | |||||
| var num_val_samples = Convert.ToInt32(samples.Length * validation_split); | var num_val_samples = Convert.ToInt32(samples.Length * validation_split); | ||||
| if (subset == "training") | if (subset == "training") | ||||
| { | { | ||||
| Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); | Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); | ||||
| samples = samples[..^num_val_samples]; | |||||
| labels = labels[..^num_val_samples]; | |||||
| samples = samples.Take(samples.Length - num_val_samples).ToArray(); | |||||
| labels = labels.Take(labels.Length - num_val_samples).ToArray(); | |||||
| } | } | ||||
| else if (subset == "validation") | else if (subset == "validation") | ||||
| { | { | ||||
| Console.WriteLine($"Using {num_val_samples} files for validation."); | Console.WriteLine($"Using {num_val_samples} files for validation."); | ||||
| samples = samples[(samples.Length - num_val_samples)..]; | |||||
| labels = labels[(labels.Length - num_val_samples)..]; | |||||
| samples = samples.Skip(samples.Length - num_val_samples).ToArray(); | |||||
| labels = labels.Skip(labels.Length - num_val_samples).ToArray(); | |||||
| } | } | ||||
| else | else | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| @@ -21,40 +21,41 @@ namespace Tensorflow.Keras.Preprocessings | |||||
| /// file_paths, labels, class_names | /// file_paths, labels, class_names | ||||
| /// </returns> | /// </returns> | ||||
| public (string[], int[], string[]) index_directory(string directory, | public (string[], int[], string[]) index_directory(string directory, | ||||
| string labels, | |||||
| string[] formats = null, | string[] formats = null, | ||||
| string[] class_names = null, | string[] class_names = null, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int? seed = null, | int? seed = null, | ||||
| bool follow_links = false) | bool follow_links = false) | ||||
| { | { | ||||
| var labels = new List<int>(); | |||||
| var label_list = new List<int>(); | |||||
| var file_paths = new List<string>(); | var file_paths = new List<string>(); | ||||
| var class_dirs = Directory.GetDirectories(directory); | var class_dirs = Directory.GetDirectories(directory); | ||||
| class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar)[^1]).ToArray(); | |||||
| class_names = class_dirs.Select(x => x.Split(Path.DirectorySeparatorChar).Last()).ToArray(); | |||||
| for (var label = 0; label < class_dirs.Length; label++) | for (var label = 0; label < class_dirs.Length; label++) | ||||
| { | { | ||||
| var files = Directory.GetFiles(class_dirs[label]); | var files = Directory.GetFiles(class_dirs[label]); | ||||
| file_paths.AddRange(files); | file_paths.AddRange(files); | ||||
| labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); | |||||
| label_list.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); | |||||
| } | } | ||||
| var return_labels = labels.Select(x => x).ToArray(); | |||||
| var return_labels = label_list.Select(x => x).ToArray(); | |||||
| var return_file_paths = file_paths.Select(x => x).ToArray(); | var return_file_paths = file_paths.Select(x => x).ToArray(); | ||||
| if (shuffle) | if (shuffle) | ||||
| { | { | ||||
| if (!seed.HasValue) | if (!seed.HasValue) | ||||
| seed = np.random.randint((long)1e6); | seed = np.random.randint((long)1e6); | ||||
| var random_index = np.arange(labels.Count); | |||||
| var random_index = np.arange(label_list.Count); | |||||
| var rng = np.random.RandomState(seed.Value); | var rng = np.random.RandomState(seed.Value); | ||||
| rng.shuffle(random_index); | rng.shuffle(random_index); | ||||
| var index = random_index.ToArray<int>(); | var index = random_index.ToArray<int>(); | ||||
| for (int i = 0; i < labels.Count; i++) | |||||
| for (int i = 0; i < label_list.Count; i++) | |||||
| { | { | ||||
| return_labels[i] = labels[index[i]]; | |||||
| return_labels[i] = label_list[index[i]]; | |||||
| return_file_paths[i] = file_paths[index[i]]; | return_file_paths[i] = file_paths[index[i]]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,7 @@ | |||||
| using Tensorflow.Keras.Preprocessings; | |||||
| using System; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Keras.Preprocessings; | |||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| @@ -6,5 +9,18 @@ namespace Tensorflow.Keras | |||||
| { | { | ||||
| public Sequence sequence => new Sequence(); | public Sequence sequence => new Sequence(); | ||||
| public DatasetUtils dataset_utils => new DatasetUtils(); | public DatasetUtils dataset_utils => new DatasetUtils(); | ||||
| public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | |||||
| string split = "standardize", | |||||
| int max_tokens = -1, | |||||
| string output_mode = "int", | |||||
| int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs | |||||
| { | |||||
| Standardize = standardize, | |||||
| Split = split, | |||||
| MaxTokens = max_tokens, | |||||
| OutputMode = output_mode, | |||||
| OutputSequenceLength = output_sequence_length | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| @@ -43,6 +43,7 @@ namespace Tensorflow.Keras | |||||
| num_channels = 3; | num_channels = 3; | ||||
| var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, | var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, | ||||
| labels, | |||||
| formats: WHITELIST_FORMATS, | formats: WHITELIST_FORMATS, | ||||
| class_names: class_names, | class_names: class_names, | ||||
| shuffle: shuffle, | shuffle: shuffle, | ||||
| @@ -64,13 +65,30 @@ namespace Tensorflow.Keras | |||||
| string[] class_names = null, | string[] class_names = null, | ||||
| int batch_size = 32, | int batch_size = 32, | ||||
| bool shuffle = true, | bool shuffle = true, | ||||
| int max_length = -1, | |||||
| int? seed = null, | int? seed = null, | ||||
| float validation_split = 0.2f, | float validation_split = 0.2f, | ||||
| string subset = null) | |||||
| string subset = null, | |||||
| bool follow_links = false) | |||||
| { | { | ||||
| var (file_paths, label_list, class_name_list) = dataset_utils.index_directory( | |||||
| directory, | |||||
| labels, | |||||
| formats: new[] { ".txt" }, | |||||
| class_names: class_names, | |||||
| shuffle: shuffle, | |||||
| seed: seed, | |||||
| follow_links: follow_links); | |||||
| return null; | |||||
| (file_paths, label_list) = dataset_utils.get_training_or_validation_split( | |||||
| file_paths, label_list, validation_split, subset); | |||||
| var dataset = paths_and_labels_to_dataset(file_paths, label_list, label_mode, class_name_list.Length); | |||||
| if (shuffle) | |||||
| dataset = dataset.shuffle(batch_size * 8, seed: seed); | |||||
| dataset = dataset.batch(batch_size); | |||||
| dataset.class_names = class_name_list; | |||||
| return dataset; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | using System; | ||||
| using System.IO; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| @@ -34,5 +35,31 @@ namespace Tensorflow.Keras | |||||
| // img.set_shape((image_size[0], image_size[1], num_channels)); | // img.set_shape((image_size[0], image_size[1], num_channels)); | ||||
| return img; | return img; | ||||
| } | } | ||||
| public IDatasetV2 paths_and_labels_to_dataset(string[] image_paths, | |||||
| int[] labels, | |||||
| string label_mode, | |||||
| int num_classes, | |||||
| int max_length = -1) | |||||
| { | |||||
| var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); | |||||
| var string_ds = path_ds.map(x => path_to_string_content(x, max_length)); | |||||
| if (label_mode == "int") | |||||
| { | |||||
| var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes); | |||||
| string_ds = tf.data.Dataset.zip(string_ds, label_ds); | |||||
| } | |||||
| return string_ds; | |||||
| } | |||||
| Tensor path_to_string_content(Tensor path, int max_length) | |||||
| { | |||||
| var txt = tf.io.read_file(path); | |||||
| if (max_length > -1) | |||||
| txt = tf.strings.substr(txt, 0, max_length); | |||||
| return txt; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -6,7 +6,7 @@ | |||||
| <LangVersion>8.0</LangVersion> | <LangVersion>8.0</LangVersion> | ||||
| <RootNamespace>Tensorflow.Keras</RootNamespace> | <RootNamespace>Tensorflow.Keras</RootNamespace> | ||||
| <Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
| <Version>0.4.0</Version> | |||||
| <Version>0.4.1</Version> | |||||
| <Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
| <Product>Keras for .NET</Product> | <Product>Keras for .NET</Product> | ||||
| <Copyright>Apache 2.0, Haiping Chen 2020</Copyright> | <Copyright>Apache 2.0, Haiping Chen 2020</Copyright> | ||||
| @@ -34,8 +34,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
| <RepositoryType>Git</RepositoryType> | <RepositoryType>Git</RepositoryType> | ||||
| <SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
| <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | ||||
| <AssemblyVersion>0.4.0.0</AssemblyVersion> | |||||
| <FileVersion>0.4.0.0</FileVersion> | |||||
| <AssemblyVersion>0.4.1.0</AssemblyVersion> | |||||
| <FileVersion>0.4.1.0</FileVersion> | |||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| @@ -2,7 +2,7 @@ | |||||
| <PropertyGroup> | <PropertyGroup> | ||||
| <OutputType>Exe</OutputType> | <OutputType>Exe</OutputType> | ||||
| <TargetFramework>netcoreapp3.1</TargetFramework> | |||||
| <TargetFramework>net5.0</TargetFramework> | |||||
| <Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| @@ -56,7 +56,7 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ | |||||
| 1. Build static library | 1. Build static library | ||||
| `bazel build --output_base=C:/tmp/tfcompilation build --config=opt //tensorflow:tensorflow` | |||||
| `bazel build --output_base=C:/tmp/tfcompilation --config=opt //tensorflow:tensorflow` | |||||
| 2. Build pip package | 2. Build pip package | ||||
| @@ -70,6 +70,16 @@ Set ENV `BAZEL_VC=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\ | |||||
| `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` | `pip install C:/tmp/tensorflow_pkg/tensorflow-1.15.0-cp36-cp36m-win_amd64.whl` | ||||
| ### Build from source for MacOS | |||||
| ```shell | |||||
| $ cd /usr/local/lib/bazel/bin | |||||
| $ curl -LO https://release.bazel.build/3.7.2/release/bazel-3.7.2-darwin-x86_64 | |||||
| $ chmod +x bazel-3.7.2-darwin-x86_64 | |||||
| $ cd ~/Projects/tensorflow | |||||
| $ bazel build --config=opt //tensorflow:tensorflow | |||||
| ``` | |||||
| ### Build specific version for tf.net | ### Build specific version for tf.net | ||||
| https://github.com/SciSharp/tensorflow | https://github.com/SciSharp/tensorflow | ||||