
Consolidate dataset map args.

Oceania2018, 4 years ago, commit 34a5b32eff (tags/yolov3)
7 changed files with 24 additions and 27 deletions
1. src/TensorFlowNET.Core/Data/DatasetV2.cs (+1, -1)
2. src/TensorFlowNET.Core/Data/IDatasetV2.cs (+1, -1)
3. src/TensorFlowNET.Core/Data/ParallelMapDataset.cs (+1, -1)
4. src/TensorFlowNET.Core/Data/TensorDataset.cs (+1, -2)
5. src/TensorFlowNET.Core/Functions/ConcreteFunction.cs (+7, -12)
6. src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs (+12, -9)
7. src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+1, -1)
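All seven files serve a single refactor: the dataset map delegate is consolidated from a fixed three-tensor tuple into one Tensors bundle. On the interface the change reads:

// Before: the delegate was pinned to exactly three tensors.
IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1);
// After: one Tensors in, one Tensors out, whatever the arity.
IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1);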

src/TensorFlowNET.Core/Data/DatasetV2.cs (+1, -1)

@@ -60,7 +60,7 @@ namespace Tensorflow
                 preserve_cardinality: preserve_cardinality,
                 use_legacy_function: use_legacy_function);
 
-        public IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1)
+        public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1)
            => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
 
         public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
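Callers now pass a single lambda over Tensors instead of a tuple-typed delegate. A minimal caller-side sketch (the dataset construction and tensor names are illustrative, not from this commit):

// inputs bundles every component of the element; index into it and
// return a new bundle of the same arity.
var ds = tf.data.Dataset.from_tensor_slices(features, labels);
ds = ds.map(inputs =>
{
    var x = tf.cast(inputs[0], tf.float32);
    return new Tensors(x, inputs[1]);
});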


src/TensorFlowNET.Core/Data/IDatasetV2.cs (+1, -1)

@@ -60,7 +60,7 @@ namespace Tensorflow
             bool preserve_cardinality = true,
             bool use_legacy_function = false);
 
-        IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
+        IDatasetV2 map(Func<Tensors, Tensors> map_func,
             int num_parallel_calls = -1);
 
         IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);


src/TensorFlowNET.Core/Data/ParallelMapDataset.cs (+1, -1)

@@ -9,7 +9,7 @@ namespace Tensorflow
     public class ParallelMapDataset : UnaryDataset
     {
         public ParallelMapDataset(IDatasetV2 input_dataset,
-            Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
+            Func<Tensors, Tensors> map_func,
             int num_parallel_calls = -1,
             bool use_inter_op_parallelism = true,
             bool preserve_cardinality = false,
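The same delegate type now flows through ParallelMapDataset, so sequential and parallel maps share one code path. A hedged sketch of requesting parallelism (values illustrative):

// Assumption: num_parallel_calls = -1 keeps the default, while a
// positive value asks for that many concurrent map_func invocations.
var mapped = ds.map(inputs => new Tensors(tf.square(inputs[0])), num_parallel_calls: 4);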


src/TensorFlowNET.Core/Data/TensorDataset.cs (+1, -2)

@@ -12,8 +12,7 @@ namespace Tensorflow
         public TensorDataset(Tensors elements)
         {
             _tensors = elements;
-            var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
-            structure = batched_spec.Select(x => x._unbatch()).ToArray();
+            structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();
 
             variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
         }
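The removed _unbatch() call had stripped the leading dimension from each spec. Since TensorDataset backs from_tensor, which wraps the whole Tensors as a single dataset element, the full shape appears to be the correct element spec. A sketch of the expected behavior:

// Sketch: from_tensor yields exactly one element with the full shape,
// so the element spec is (32, 10) rather than the unbatched (10).
var t = tf.ones(new TensorShape(32, 10));
var ds = tf.data.Dataset.from_tensor(t);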


src/TensorFlowNET.Core/Functions/ConcreteFunction.cs (+7, -12)

@@ -83,7 +83,7 @@ namespace Tensorflow.Functions
             graph.Exit();
         }
 
-        public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
+        public ConcreteFunction(Func<Tensors, Tensors> func,
             TF_DataType[] dtypes, TensorShape[] shapes)
         {
             string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
@@ -92,19 +92,14 @@ namespace Tensorflow.Functions
             using var graph = new FuncGraph(func_name);
             graph.as_default();
 
-            var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
-            var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
-            var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
-            var outputs = func(input1, (input2, input3));
-
-            Outputs = new[] { outputs.Item1, outputs.Item2 };
-            OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
+            var inputs = new Tensors();
+            foreach(var (i, dtype) in enumerate(dtypes))
+                inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args"));
+            Outputs = func(inputs);
+            OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray();
 
             var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
-            _handle = graph.ToGraph(opers,
-                new[] { input1, input2, input3 },
-                new[] { outputs.Item1, outputs.Item2 },
-                null);
+            _handle = graph.ToGraph(opers, inputs, Outputs, null);
             graph.Exit();
         }
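With placeholders built in a loop, the constructor no longer hard-codes a one-input, two-output shape: any arity supported by the dtypes/shapes arrays traces through the same path. A hedged construction sketch (function body and shapes are illustrative):

// Two float inputs, one output; one placeholder is created per
// dtype/shape pair and handed to the delegate as a single Tensors.
var fn = new ConcreteFunction(inputs =>
{
    var sum = tf.add(inputs[0], inputs[1]);
    return new Tensors(sum);
},
new[] { tf.float32, tf.float32 },
new[] { new TensorShape(-1), new TensorShape(-1) });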



src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs (+12, -9)

@@ -1,4 +1,5 @@
 using System;
+using System.Linq;
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;

@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             num_samples = args.X.shape[0];
             var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
             _batch_size = batch_size;
-            _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f)));
+            _size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f)));
             num_full_batches = num_samples / batch_size;
             var _partial_batch_size = num_samples % batch_size;
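The Ceiling-to-Floor switch changes what _size reports when the sample count is not a multiple of the batch size: with num_samples = 100 and batch_size = 32, Ceiling gave _size = 4 while Floor gives 3, so the trailing partial batch of 100 % 32 = 4 samples presumably no longer counts toward the adapter's size.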

@@ -29,12 +30,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             indices_dataset = indices_dataset.repeat(args.Epochs);
             indices_dataset = indices_dataset.map(permutation).prefetch(1);
             indices_dataset = indices_dataset.flat_map(slice_batch_indices);
-            var elements = new Tensors();
+            var inputs = new Tensors();
             if (args.X != null)
-                elements.Add(args.X);
+                inputs.Add(args.X);
             if (args.Y != null)
-                elements.Add(args.Y);
-            dataset = slice_inputs(indices_dataset, elements);
+                inputs.Add(args.Y);
+            dataset = slice_inputs(indices_dataset, inputs);
         }
 
         Tensor permutation(Tensor tensor)
@@ -64,11 +65,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             var dataset2 = tf.data.Dataset.from_tensor(elements).repeat();
             var dataset = tf.data.Dataset.zip(indices_dataset, dataset2);
 
-            dataset = dataset.map((batch, data) =>
+            dataset = dataset.map(inputs =>
             {
-                var x = gen_array_ops.gather_v2(data.Item1, batch, 0);
-                var y = gen_array_ops.gather_v2(data.Item2, batch, 0);
-                return (x, y);
+                var indices = inputs[0];
+                var results = inputs.Skip(1)
+                    .Select(x => gen_array_ops.gather_v2(x, indices, 0))
+                    .ToArray();
+                return new Tensors(results);
             });
 
             dataset = dataset.with_options(new DatasetOptions { });
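After the zip, the map callback receives one flat Tensors whose first entry is the batch-index tensor and whose remaining entries are the data tensors; gather_v2 then selects the indexed rows along axis 0, for any number of data tensors rather than exactly two. A small sketch of that gather step (values illustrative):

// gather_v2 picks rows by index along axis 0.
var data = tf.constant(new[,] { { 1, 2 }, { 3, 4 }, { 5, 6 } });
var indices = tf.constant(new[] { 2, 0 });
var rows = gen_array_ops.gather_v2(data, indices, 0); // [[5, 6], [1, 2]]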


src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+1, -1)

@@ -97,7 +97,7 @@ namespace Tensorflow.Keras.Engine
                     // callbacks.on_train_batch_begin(step)
                     var results = step_function(iterator);
                     var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
-                    Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
+                    Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
                 }
             }
         }

