Browse Source

Fix preserve_cardinality for ParallelMapDataset.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
500f0c0cca
10 changed files with 81 additions and 23 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +12
    -4
      src/TensorFlowNET.Core/Data/ParallelMapDataset.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  5. +4
    -0
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  6. +1
    -0
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  7. +12
    -12
      src/TensorFlowNET.Keras/Engine/Model.Compile.cs
  8. +26
    -1
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  9. +8
    -3
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  10. +11
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 3
- 1
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -66,7 +66,9 @@ namespace Tensorflow
use_legacy_function: use_legacy_function); use_legacy_function: use_legacy_function);


public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls) public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
=> new ParallelMapDataset(this, map_func,
num_parallel_calls: num_parallel_calls,
preserve_cardinality: true);


public OwnedIterator make_one_shot_iterator() public OwnedIterator make_one_shot_iterator()
{ {


+ 12
- 4
src/TensorFlowNET.Core/Data/ParallelMapDataset.cs View File

@@ -15,18 +15,26 @@ namespace Tensorflow
bool preserve_cardinality = false, bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset) bool use_legacy_function = false) : base(input_dataset)
{ {
var func = new ConcreteFunction(map_func,
input_dataset.element_spec.Select(x => x.dtype).ToArray(),
input_dataset.element_spec.Select(x => x.shape).ToArray());
var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");
func.Enter();
var inputs = new Tensors();
foreach (var input in input_dataset.element_spec)
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
var outputs = map_func(inputs);
func.ToGraph(inputs, outputs);
func.Exit();


structure = func.OutputStructure; structure = func.OutputStructure;

var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64, var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64,
name: "num_parallel_calls"); name: "num_parallel_calls");
variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor, variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor,
_num_parallel_calls, _num_parallel_calls,
func, func,
output_types, output_types,
output_shapes);
output_shapes,
use_inter_op_parallelism: use_inter_op_parallelism,
preserve_cardinality: preserve_cardinality);
} }
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Functions
func_graph.Exit(); func_graph.Exit();
} }


public ConcreteFunction(Func<Tensors, Tensors> func,
/*public ConcreteFunction(Func<Tensors, Tensors> func,
TF_DataType[] dtypes, TensorShape[] shapes) TF_DataType[] dtypes, TensorShape[] shapes)
{ {
string func_name = $"{func.Method.Name}_{ops.uid_function()}"; string func_name = $"{func.Method.Name}_{ops.uid_function()}";
@@ -89,7 +89,7 @@ namespace Tensorflow.Functions
var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
func_graph.ToGraph(opers, inputs, Outputs, null); func_graph.ToGraph(opers, inputs, Outputs, null);
func_graph.Exit(); func_graph.Exit();
}
}*/


public void ToGraph(Tensors inputs, Tensors outputs) public void ToGraph(Tensors inputs, Tensors outputs)
{ {


+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -38,6 +38,8 @@ namespace Tensorflow
} }
} }


public Tensor this[params string[] slices]
=> items.First()[slices];
public Tensors(params Tensor[] tensors) public Tensors(params Tensor[] tensors)
{ {
items.AddRange(tensors); items.AddRange(tensors);


+ 4
- 0
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -585,6 +585,10 @@ would not be rank 1.", tensor.op.get_attr("axis")));
else else
return $"['{string.Join("', '", tensor.StringData().Take(25))}']"; return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
} }
else if(dtype == TF_DataType.TF_VARIANT)
{
return "<unprintable>";
}


var nd = tensor.numpy(); var nd = tensor.numpy();




+ 1
- 0
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -100,6 +100,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
using var data_iterator = new OwnedIterator(_dataset); using var data_iterator = new OwnedIterator(_dataset);
yield return (epoch, data_iterator); yield return (epoch, data_iterator);
} }
// _adapter.on_epoch_end()
} }


public IEnumerable<int> steps() public IEnumerable<int> steps()


+ 12
- 12
src/TensorFlowNET.Keras/Engine/Model.Compile.cs View File

@@ -33,22 +33,22 @@ namespace Tensorflow.Keras.Engine


public void compile(string optimizer, string loss, string[] metrics) public void compile(string optimizer, string loss, string[] metrics)
{ {
switch (optimizer)
var _optimizer = optimizer switch
{ {
case "rmsprop":
this.optimizer = new RMSprop(new RMSpropArgs
{
"rmsprop" => new RMSprop(new RMSpropArgs
{


});
break;
}
}),
_ => throw new NotImplementedException("")
};


int experimental_steps_per_execution = 1;
_configure_steps_per_execution(experimental_steps_per_execution);

_reset_compile_cache();
var _loss = loss switch
{
"mse" => new MeanSquaredError(),
_ => throw new NotImplementedException("")
};


_is_compiled = true;
compile(optimizer: _optimizer, loss: _loss, metrics: metrics);
} }
} }
} }

+ 26
- 1
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -49,7 +49,32 @@ namespace Tensorflow.Keras.Engine
Binding.tf_output_redirect.WriteLine($"Testing..."); Binding.tf_output_redirect.WriteLine($"Testing...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{ {
// reset_metrics();
reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = test_function(iterator);
}
Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}

public void evaluate(IDatasetV2 x)
{
data_handler = new DataHandler(new DataHandlerArgs
{
Dataset = x,
Model = this,
StepsPerExecution = _steps_per_execution
});

Binding.tf_output_redirect.WriteLine($"Testing...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
// callbacks.on_epoch_begin(epoch) // callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration(); // data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null; IEnumerable<(string, Tensor)> results = null;


+ 8
- 3
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -124,10 +124,11 @@ namespace Tensorflow.Keras


var start_positions_tensor = tf.constant(start_positions); var start_positions_tensor = tf.constant(start_positions);
var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat(); var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat();
var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds);
var r = tf.data.Dataset.range(len(start_positions));
var z = tf.data.Dataset.zip(r, positions_ds);
var indices = z.map(m => var indices = z.map(m =>
{ {
var (i, positions) = (m[0], m[1]);
var (i, positions) = m;
return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
}, num_parallel_calls: -1); }, num_parallel_calls: -1);
var dataset = sequences_from_indices(data, indices, start_index, end_index); var dataset = sequences_from_indices(data, indices, start_index, end_index);
@@ -142,7 +143,11 @@ namespace Tensorflow.Keras
{ {
var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]); var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]);
dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds) dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds)
.map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1);
.map(x =>
{
var (steps, indx) = x;
return array_ops.gather(steps, indx);
}, num_parallel_calls: -1);
return dataset; return dataset;
} }
} }


+ 11
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -147,7 +147,18 @@ namespace TensorFlowNET.UnitTest.Dataset
public void Cardinality() public void Cardinality()
{ {
var dataset = tf.data.Dataset.range(10); var dataset = tf.data.Dataset.range(10);
var cardinality = dataset.dataset_cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
dataset = dataset.map(x => x[0] + 1); dataset = dataset.map(x => x[0] + 1);
cardinality = dataset.dataset_cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
}

[TestMethod]
public void CardinalityWithAutoTune()
{
var dataset = tf.data.Dataset.range(10);
dataset = dataset.map(x => x, num_parallel_calls: -1);
var cardinality = dataset.dataset_cardinality(); var cardinality = dataset.dataset_cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
} }


Loading…
Cancel
Save