Browse Source

Consolidate dataset map args.

tags/yolov3
Oceania2018 4 years ago
parent
commit
34a5b32eff
7 changed files with 24 additions and 27 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Data/ParallelMapDataset.cs
  4. +1
    -2
      src/TensorFlowNET.Core/Data/TensorDataset.cs
  5. +7
    -12
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  6. +12
    -9
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  7. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+ 1
- 1
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -60,7 +60,7 @@ namespace Tensorflow
preserve_cardinality: preserve_cardinality, preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function); use_legacy_function: use_legacy_function);


public IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1)
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1)
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);


public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)


+ 1
- 1
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -60,7 +60,7 @@ namespace Tensorflow
bool preserve_cardinality = true, bool preserve_cardinality = true,
bool use_legacy_function = false); bool use_legacy_function = false);


IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
IDatasetV2 map(Func<Tensors, Tensors> map_func,
int num_parallel_calls = -1); int num_parallel_calls = -1);


IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);


+ 1
- 1
src/TensorFlowNET.Core/Data/ParallelMapDataset.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow
public class ParallelMapDataset : UnaryDataset public class ParallelMapDataset : UnaryDataset
{ {
public ParallelMapDataset(IDatasetV2 input_dataset, public ParallelMapDataset(IDatasetV2 input_dataset,
Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func,
Func<Tensors, Tensors> map_func,
int num_parallel_calls = -1, int num_parallel_calls = -1,
bool use_inter_op_parallelism = true, bool use_inter_op_parallelism = true,
bool preserve_cardinality = false, bool preserve_cardinality = false,


+ 1
- 2
src/TensorFlowNET.Core/Data/TensorDataset.cs View File

@@ -12,8 +12,7 @@ namespace Tensorflow
public TensorDataset(Tensors elements) public TensorDataset(Tensors elements)
{ {
_tensors = elements; _tensors = elements;
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
structure = batched_spec.Select(x => x._unbatch()).ToArray();
structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();


variant_tensor = ops.tensor_dataset(_tensors, output_shapes); variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
} }


+ 7
- 12
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -83,7 +83,7 @@ namespace Tensorflow.Functions
graph.Exit(); graph.Exit();
} }


public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
public ConcreteFunction(Func<Tensors, Tensors> func,
TF_DataType[] dtypes, TensorShape[] shapes) TF_DataType[] dtypes, TensorShape[] shapes)
{ {
string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
@@ -92,19 +92,14 @@ namespace Tensorflow.Functions
using var graph = new FuncGraph(func_name); using var graph = new FuncGraph(func_name);
graph.as_default(); graph.as_default();


var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args");
var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args");
var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args");
var outputs = func(input1, (input2, input3));

Outputs = new[] { outputs.Item1, outputs.Item2 };
OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() };
var inputs = new Tensors();
foreach(var (i, dtype) in enumerate(dtypes))
inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args"));
Outputs = func(inputs);
OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray();


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new[] { input1, input2, input3 },
new[] { outputs.Item1, outputs.Item2 },
null);
_handle = graph.ToGraph(opers, inputs, Outputs, null);
graph.Exit(); graph.Exit();
} }




+ 12
- 9
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
num_samples = args.X.shape[0]; num_samples = args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size; _batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f)));
_size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f)));
num_full_batches = num_samples / batch_size; num_full_batches = num_samples / batch_size;
var _partial_batch_size = num_samples % batch_size; var _partial_batch_size = num_samples % batch_size;


@@ -29,12 +30,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters
indices_dataset = indices_dataset.repeat(args.Epochs); indices_dataset = indices_dataset.repeat(args.Epochs);
indices_dataset = indices_dataset.map(permutation).prefetch(1); indices_dataset = indices_dataset.map(permutation).prefetch(1);
indices_dataset = indices_dataset.flat_map(slice_batch_indices); indices_dataset = indices_dataset.flat_map(slice_batch_indices);
var elements = new Tensors();
var inputs = new Tensors();
if (args.X != null) if (args.X != null)
elements.Add(args.X);
inputs.Add(args.X);
if (args.Y != null) if (args.Y != null)
elements.Add(args.Y);
dataset = slice_inputs(indices_dataset, elements);
inputs.Add(args.Y);
dataset = slice_inputs(indices_dataset, inputs);
} }


Tensor permutation(Tensor tensor) Tensor permutation(Tensor tensor)
@@ -64,11 +65,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters
var dataset2 = tf.data.Dataset.from_tensor(elements).repeat(); var dataset2 = tf.data.Dataset.from_tensor(elements).repeat();
var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); var dataset = tf.data.Dataset.zip(indices_dataset, dataset2);


dataset = dataset.map((batch, data) =>
dataset = dataset.map(inputs =>
{ {
var x = gen_array_ops.gather_v2(data.Item1, batch, 0);
var y = gen_array_ops.gather_v2(data.Item2, batch, 0);
return (x, y);
var indices = inputs[0];
var results = inputs.Skip(1)
.Select(x => gen_array_ops.gather_v2(x, indices, 0))
.ToArray();
return new Tensors(results);
}); });


dataset = dataset.with_options(new DatasetOptions { }); dataset = dataset.with_options(new DatasetOptions { });


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -97,7 +97,7 @@ namespace Tensorflow.Keras.Engine
// callbacks.on_train_batch_begin(step) // callbacks.on_train_batch_begin(step)
var results = step_function(iterator); var results = step_function(iterator);
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
} }
} }
} }


Loading…
Cancel
Save