
Fix preserve_cardinality for ParallelMapDataset.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit 500f0c0cca
10 changed files with 81 additions and 23 deletions
  1. src/TensorFlowNET.Core/Data/DatasetV2.cs  (+3, -1)
  2. src/TensorFlowNET.Core/Data/ParallelMapDataset.cs  (+12, -4)
  3. src/TensorFlowNET.Core/Functions/ConcreteFunction.cs  (+2, -2)
  4. src/TensorFlowNET.Core/Tensors/Tensors.cs  (+2, -0)
  5. src/TensorFlowNET.Core/Tensors/tensor_util.cs  (+4, -0)
  6. src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs  (+1, -0)
  7. src/TensorFlowNET.Keras/Engine/Model.Compile.cs  (+12, -12)
  8. src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs  (+26, -1)
  9. src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs  (+8, -3)
  10. test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs  (+11, -0)

src/TensorFlowNET.Core/Data/DatasetV2.cs  (+3, -1)

@@ -66,7 +66,9 @@ namespace Tensorflow
                 use_legacy_function: use_legacy_function);
 
         public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
-            => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
+            => new ParallelMapDataset(this, map_func,
+                num_parallel_calls: num_parallel_calls,
+                preserve_cardinality: true);
 
         public OwnedIterator make_one_shot_iterator()
         {

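Note: with preserve_cardinality: true, a parallel map no longer degrades a known-length dataset to unknown cardinality. A minimal usage sketch (variable names are illustrative; the calls mirror the regression test added at the bottom of this commit):

    using static Tensorflow.Binding;

    // num_parallel_calls: -1 means AUTOTUNE; the cardinality should still be 10.
    var dataset = tf.data.Dataset.range(10)
        .map(x => x, num_parallel_calls: -1);
    var cardinality = dataset.dataset_cardinality();   // expected: 10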

src/TensorFlowNET.Core/Data/ParallelMapDataset.cs  (+12, -4)

@@ -15,18 +15,26 @@ namespace Tensorflow
             bool preserve_cardinality = false,
             bool use_legacy_function = false) : base(input_dataset)
         {
-            var func = new ConcreteFunction(map_func,
-                input_dataset.element_spec.Select(x => x.dtype).ToArray(),
-                input_dataset.element_spec.Select(x => x.shape).ToArray());
+            var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");
+            func.Enter();
+            var inputs = new Tensors();
+            foreach (var input in input_dataset.element_spec)
+                inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
+            var outputs = map_func(inputs);
+            func.ToGraph(inputs, outputs);
+            func.Exit();
+
             structure = func.OutputStructure;
 
             var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64,
                 name: "num_parallel_calls");
             variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor,
                 _num_parallel_calls,
                 func,
                 output_types,
-                output_shapes);
+                output_shapes,
+                use_inter_op_parallelism: use_inter_op_parallelism,
+                preserve_cardinality: preserve_cardinality);
         }
     }
 }

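Note: the constructor now traces map_func itself. It enters the ConcreteFunction's graph, builds one placeholder per entry in the input dataset's element_spec, runs the user lambda once to capture its outputs, and exits the graph, instead of using the ConcreteFunction(Func, dtypes, shapes) overload that is commented out in the next file. A small sketch of what that implies for callers (the Console output exists only to illustrate the single trace; everything else uses the API exercised in this commit):

    using System;
    using static Tensorflow.Binding;

    var ds = tf.data.Dataset.range(3)
        .map(x =>
        {
            // Runs once, while the lambda is traced into a graph at dataset
            // construction time; it is not re-invoked per element during iteration.
            Console.WriteLine("tracing map_func");
            return x[0] + 1;
        }, num_parallel_calls: 2);
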
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs  (+2, -2)

@@ -71,7 +71,7 @@ namespace Tensorflow.Functions
             func_graph.Exit();
         }
 
-        public ConcreteFunction(Func<Tensors, Tensors> func,
+        /*public ConcreteFunction(Func<Tensors, Tensors> func,
             TF_DataType[] dtypes, TensorShape[] shapes)
         {
             string func_name = $"{func.Method.Name}_{ops.uid_function()}";
@@ -89,7 +89,7 @@ namespace Tensorflow.Functions
             var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
             func_graph.ToGraph(opers, inputs, Outputs, null);
             func_graph.Exit();
-        }
+        }*/
 
         public void ToGraph(Tensors inputs, Tensors outputs)
         {


src/TensorFlowNET.Core/Tensors/Tensors.cs  (+2, -0)

@@ -38,6 +38,8 @@ namespace Tensorflow
             }
         }
 
+        public Tensor this[params string[] slices]
+            => items.First()[slices];
         public Tensors(params Tensor[] tensors)
         {
             items.AddRange(tensors);

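Note: the new indexer lets a Tensors collection answer string-slice lookups by forwarding them to its first tensor. A sketch, assuming the usual Tensor string-slice syntax (the constant values are made up for illustration):

    using static Tensorflow.Binding;

    var t = tf.constant(new[] { 1, 2, 3, 4 });
    var ts = new Tensors(t);
    // Forwards to items.First()[...], i.e. the same result as t["1:"].
    var tail = ts["1:"];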

src/TensorFlowNET.Core/Tensors/tensor_util.cs  (+4, -0)

@@ -585,6 +585,10 @@ would not be rank 1.", tensor.op.get_attr("axis")));
                 else
                     return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
             }
+            else if (dtype == TF_DataType.TF_VARIANT)
+            {
+                return "<unprintable>";
+            }
 
             var nd = tensor.numpy();

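Note: variant tensors (for example the handles that back tf.data datasets) have no numpy representation, so the summary helper now short-circuits instead of calling tensor.numpy() on them. A sketch of where this shows up, assuming Tensor's text rendering goes through this helper:

    using static Tensorflow.Binding;

    var ds = tf.data.Dataset.range(5);
    // ds.variant_tensor is a DT_VARIANT handle; its value summary should now
    // render as "<unprintable>" rather than attempting to materialize the value.
    System.Console.WriteLine(ds.variant_tensor.ToString());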


src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs  (+1, -0)

@@ -100,6 +100,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
                 using var data_iterator = new OwnedIterator(_dataset);
                 yield return (epoch, data_iterator);
             }
+            // _adapter.on_epoch_end()
         }
 
         public IEnumerable<int> steps()


src/TensorFlowNET.Keras/Engine/Model.Compile.cs  (+12, -12)

@@ -33,22 +33,22 @@ namespace Tensorflow.Keras.Engine
 
         public void compile(string optimizer, string loss, string[] metrics)
         {
-            switch (optimizer)
+            var _optimizer = optimizer switch
             {
-                case "rmsprop":
-                    this.optimizer = new RMSprop(new RMSpropArgs
-                    {
+                "rmsprop" => new RMSprop(new RMSpropArgs
+                {
 
-                    });
-                    break;
-            }
+                }),
+                _ => throw new NotImplementedException("")
+            };
 
-            int experimental_steps_per_execution = 1;
-            _configure_steps_per_execution(experimental_steps_per_execution);
-
-            _reset_compile_cache();
+            var _loss = loss switch
+            {
+                "mse" => new MeanSquaredError(),
+                _ => throw new NotImplementedException("")
+            };
 
-            _is_compiled = true;
+            compile(optimizer: _optimizer, loss: _loss, metrics: metrics);
         }
     }
 }

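Note: the string overload now resolves "rmsprop" and "mse" into concrete optimizer/loss objects and forwards them to the main compile overload; any other string throws NotImplementedException. Calling code keeps the familiar Keras shape (model stands for any built Keras Model; the metrics value is illustrative):

    // Only "rmsprop" and "mse" are recognized by the string overload at this point.
    model.compile(optimizer: "rmsprop",
                  loss: "mse",
                  metrics: new[] { "accuracy" });
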
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs  (+26, -1)

@@ -49,7 +49,32 @@ namespace Tensorflow.Keras.Engine
             Binding.tf_output_redirect.WriteLine($"Testing...");
             foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
             {
-                // reset_metrics();
+                reset_metrics();
+                // callbacks.on_epoch_begin(epoch)
+                // data_handler.catch_stop_iteration();
+                IEnumerable<(string, Tensor)> results = null;
+                foreach (var step in data_handler.steps())
+                {
+                    // callbacks.on_train_batch_begin(step)
+                    results = test_function(iterator);
+                }
+                Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
+            }
+        }
+
+        public void evaluate(IDatasetV2 x)
+        {
+            data_handler = new DataHandler(new DataHandlerArgs
+            {
+                Dataset = x,
+                Model = this,
+                StepsPerExecution = _steps_per_execution
+            });
+
+            Binding.tf_output_redirect.WriteLine($"Testing...");
+            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+            {
+                reset_metrics();
                 // callbacks.on_epoch_begin(epoch)
                 // data_handler.catch_stop_iteration();
                 IEnumerable<(string, Tensor)> results = null;

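Note: the new overload feeds an IDatasetV2 straight into a DataHandler, resets metrics per epoch, and prints whatever test_function returns per metric. A usage sketch (x_test and y_test are placeholder NDArrays, not defined here; from_tensor_slices and batch are the standard TF.NET dataset builders):

    using static Tensorflow.Binding;

    var eval_set = tf.data.Dataset
        .from_tensor_slices(x_test, y_test)   // yields (features, labels) pairs
        .batch(32);
    model.evaluate(eval_set);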

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs  (+8, -3)

@@ -124,10 +124,11 @@ namespace Tensorflow.Keras
 
             var start_positions_tensor = tf.constant(start_positions);
             var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat();
-            var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds);
+            var r = tf.data.Dataset.range(len(start_positions));
+            var z = tf.data.Dataset.zip(r, positions_ds);
             var indices = z.map(m =>
             {
-                var (i, positions) = (m[0], m[1]);
+                var (i, positions) = m;
                 return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
             }, num_parallel_calls: -1);
             var dataset = sequences_from_indices(data, indices, start_index, end_index);
@@ -142,7 +143,11 @@ namespace Tensorflow.Keras
         {
             var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]);
             dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds)
-                .map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1);
+                .map(x =>
+                {
+                    var (steps, indx) = x;
+                    return array_ops.gather(steps, indx);
+                }, num_parallel_calls: -1);
             return dataset;
         }
     }

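Note: both lambdas above now deconstruct the zipped element instead of indexing m[0]/m[1]; an element of a zipped dataset is a Tensors holding one tensor per input dataset. A reduced sketch of the same pattern (the datasets and the addition are illustrative only):

    using static Tensorflow.Binding;

    var left = tf.data.Dataset.range(4);
    var right = tf.data.Dataset.range(4);
    var summed = tf.data.Dataset.zip(left, right)
        .map(m =>
        {
            var (a, b) = m;   // same deconstruction as in the diff above
            return a + b;
        }, num_parallel_calls: -1);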

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs  (+11, -0)

@@ -147,7 +147,18 @@ namespace TensorFlowNET.UnitTest.Dataset
         public void Cardinality()
         {
             var dataset = tf.data.Dataset.range(10);
             var cardinality = dataset.dataset_cardinality();
             Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
+            dataset = dataset.map(x => x[0] + 1);
+            cardinality = dataset.dataset_cardinality();
+            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
+        }
+
+        [TestMethod]
+        public void CardinalityWithAutoTune()
+        {
+            var dataset = tf.data.Dataset.range(10);
+            dataset = dataset.map(x => x, num_parallel_calls: -1);
+            var cardinality = dataset.dataset_cardinality();
+            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
         }

