|
|
|
@@ -1,4 +1,5 @@ |
|
|
|
using System; |
|
|
|
using System.Linq; |
|
|
|
using Tensorflow.Keras.ArgsDefinition; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
num_samples = args.X.shape[0]; |
|
|
|
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; |
|
|
|
_batch_size = batch_size; |
|
|
|
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f))); |
|
|
|
_size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f))); |
|
|
|
num_full_batches = num_samples / batch_size; |
|
|
|
var _partial_batch_size = num_samples % batch_size; |
|
|
|
|
|
|
|
@@ -29,12 +30,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
indices_dataset = indices_dataset.repeat(args.Epochs); |
|
|
|
indices_dataset = indices_dataset.map(permutation).prefetch(1); |
|
|
|
indices_dataset = indices_dataset.flat_map(slice_batch_indices); |
|
|
|
var elements = new Tensors(); |
|
|
|
var inputs = new Tensors(); |
|
|
|
if (args.X != null) |
|
|
|
elements.Add(args.X); |
|
|
|
inputs.Add(args.X); |
|
|
|
if (args.Y != null) |
|
|
|
elements.Add(args.Y); |
|
|
|
dataset = slice_inputs(indices_dataset, elements); |
|
|
|
inputs.Add(args.Y); |
|
|
|
dataset = slice_inputs(indices_dataset, inputs); |
|
|
|
} |
|
|
|
|
|
|
|
Tensor permutation(Tensor tensor) |
|
|
|
@@ -64,11 +65,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
var dataset2 = tf.data.Dataset.from_tensor(elements).repeat(); |
|
|
|
var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); |
|
|
|
|
|
|
|
dataset = dataset.map((batch, data) => |
|
|
|
dataset = dataset.map(inputs => |
|
|
|
{ |
|
|
|
var x = gen_array_ops.gather_v2(data.Item1, batch, 0); |
|
|
|
var y = gen_array_ops.gather_v2(data.Item2, batch, 0); |
|
|
|
return (x, y); |
|
|
|
var indices = inputs[0]; |
|
|
|
var results = inputs.Skip(1) |
|
|
|
.Select(x => gen_array_ops.gather_v2(x, indices, 0)) |
|
|
|
.ToArray(); |
|
|
|
return new Tensors(results); |
|
|
|
}); |
|
|
|
|
|
|
|
dataset = dataset.with_options(new DatasetOptions { }); |
|
|
|
|