@@ -60,7 +60,7 @@ namespace Tensorflow
             preserve_cardinality: preserve_cardinality,
             use_legacy_function: use_legacy_function);
-        public IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1)
+        public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1)
             => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
         public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
@@ -60,7 +60,7 @@ namespace Tensorflow
             bool preserve_cardinality = true,
             bool use_legacy_function = false);
-        IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
+        IDatasetV2 map(Func<Tensors, Tensors> map_func,
             int num_parallel_calls = -1);
         IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
@@ -9,7 +9,7 @@ namespace Tensorflow
     public class ParallelMapDataset : UnaryDataset
     {
         public ParallelMapDataset(IDatasetV2 input_dataset,
-            Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
+            Func<Tensors, Tensors> map_func,
             int num_parallel_calls = -1,
             bool use_inter_op_parallelism = true,
             bool preserve_cardinality = false,
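
Across these three signatures the map delegate is widened from the hard-coded `Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)>` shape to `Func<Tensors, Tensors>`, so a map function can take and return an arbitrary number of component tensors. A minimal usage sketch against the new signature (the dataset construction is hypothetical, assuming a `from_tensor_slices` overload that accepts a features/labels pair):

```csharp
// Hypothetical caller of the new map signature: all components of an
// element arrive as a single Tensors collection, indexed positionally.
var ds = tf.data.Dataset.from_tensor_slices(features, labels);
ds = ds.map(inputs =>
{
    var x = inputs[0];              // features component
    var y = inputs[1];              // labels component
    return new Tensors(x + 1f, y);  // return any number of tensors
}, num_parallel_calls: 2);
```
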
@@ -12,8 +12,7 @@ namespace Tensorflow
         public TensorDataset(Tensors elements)
         {
             _tensors = elements;
-            var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
-            structure = batched_spec.Select(x => x._unbatch()).ToArray();
+            structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();
             variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
         }
@@ -83,7 +83,7 @@ namespace Tensorflow.Functions
             graph.Exit();
         }
-        public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
+        public ConcreteFunction(Func<Tensors, Tensors> func,
             TF_DataType[] dtypes, TensorShape[] shapes)
         {
             string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
@@ -92,19 +92,14 @@ namespace Tensorflow.Functions
             using var graph = new FuncGraph(func_name);
             graph.as_default();
-            var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
-            var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
-            var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
-            var outputs = func(input1, (input2, input3));
-            Outputs = new[] { outputs.Item1, outputs.Item2 };
-            OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
+            var inputs = new Tensors();
+            foreach (var (i, dtype) in enumerate(dtypes))
+                inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args"));
+            Outputs = func(inputs);
+            OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray();
             var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
-            _handle = graph.ToGraph(opers,
-                new[] { input1, input2, input3 },
-                new[] { outputs.Item1, outputs.Item2 },
-                null);
+            _handle = graph.ToGraph(opers, inputs, Outputs, null);
             graph.Exit();
         }
@@ -1,4 +1,5 @@
 using System;
+using System.Linq;
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             num_samples = args.X.shape[0];
             var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
             _batch_size = batch_size;
-            _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f)));
+            _size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f)));
             num_full_batches = num_samples / batch_size;
             var _partial_batch_size = num_samples % batch_size;
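
The switch from `Math.Ceiling` to `Math.Floor` changes `_size` whenever the sample count is not an exact multiple of the batch size: the trailing partial batch is no longer counted as a step. Checking the arithmetic with made-up numbers:

```csharp
int num_samples = 100, batch_size = 32;
var ceil  = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f))); // 4: counts the partial batch
var floor = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f)));   // 3: full batches only
var full  = num_samples / batch_size;                                       // 3, agrees with floor
var rest  = num_samples % batch_size;                                       // 4 leftover samples
```
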
@@ -29,12 +30,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             indices_dataset = indices_dataset.repeat(args.Epochs);
             indices_dataset = indices_dataset.map(permutation).prefetch(1);
             indices_dataset = indices_dataset.flat_map(slice_batch_indices);
-            var elements = new Tensors();
+            var inputs = new Tensors();
             if (args.X != null)
-                elements.Add(args.X);
+                inputs.Add(args.X);
             if (args.Y != null)
-                elements.Add(args.Y);
-            dataset = slice_inputs(indices_dataset, elements);
+                inputs.Add(args.Y);
+            dataset = slice_inputs(indices_dataset, inputs);
         }

         Tensor permutation(Tensor tensor)
@@ -64,11 +65,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             var dataset2 = tf.data.Dataset.from_tensor(elements).repeat();
             var dataset = tf.data.Dataset.zip(indices_dataset, dataset2);
-            dataset = dataset.map((batch, data) =>
+            dataset = dataset.map(inputs =>
             {
-                var x = gen_array_ops.gather_v2(data.Item1, batch, 0);
-                var y = gen_array_ops.gather_v2(data.Item2, batch, 0);
-                return (x, y);
+                var indices = inputs[0];
+                var results = inputs.Skip(1)
+                    .Select(x => gen_array_ops.gather_v2(x, indices, 0))
+                    .ToArray();
+                return new Tensors(results);
             });
             dataset = dataset.with_options(new DatasetOptions { });
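
After the `zip`, each element reaches the lambda as one flat `Tensors`: index 0 holds the batch indices produced by `indices_dataset`, and the remaining entries are the data tensors registered earlier (X alone, or X and Y), so the gather now scales with the number of inputs instead of assuming exactly two. A standalone sketch of the same slicing for the two-tensor case (names are illustrative):

```csharp
Func<Tensors, Tensors> slice = inputs =>
{
    var indices = inputs[0];                                 // batch indices
    var x = gen_array_ops.gather_v2(inputs[1], indices, 0);  // rows of X for this batch
    var y = gen_array_ops.gather_v2(inputs[2], indices, 0);  // rows of Y for this batch
    return new Tensors(x, y);
};
```
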
@@ -97,7 +97,7 @@ namespace Tensorflow.Keras.Engine
                     // callbacks.on_train_batch_begin(step)
                     var results = step_function(iterator);
                     var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
-                    Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
+                    Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
                 }
             }
         }