
IDataset cardinality.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent commit 3d943a10c1
10 changed files with 155 additions and 33 deletions
  1. +2  -0   src/TensorFlowNET.Core/APIs/tf.data.cs
  2. +26 -26  src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
  3. +10 -2   src/TensorFlowNET.Core/Data/DatasetV2.cs
  4. +58 -0   src/TensorFlowNET.Core/Data/FilterDataset.cs
  5. +4  -1   src/TensorFlowNET.Core/Data/IDatasetV2.cs
  6. +19 -0   src/TensorFlowNET.Core/Operations/dataset_ops.cs
  7. +13 -0   src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs
  8. +6  -0   src/TensorFlowNET.Core/Tensors/constant_op.cs
  9. +1  -1   src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  10. +16 -3  test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 2  - 0    src/TensorFlowNET.Core/APIs/tf.data.cs

@@ -23,6 +23,8 @@ namespace Tensorflow
    public class DataOps
    {
        public int AUTOTUNE = -1;
+       public int INFINITE_CARDINALITY = -1;
+       public int UNKNOWN_CARDINALITY = -2;
        public DatasetManager Dataset { get; } = new DatasetManager();
    }
}
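
Not part of the diff, but a minimal sketch of how the two new sentinels are intended to be read back, assuming the cardinality() rename and the Tensor == int operator introduced later in this commit:

    var dataset = tf.data.Dataset.range(10).repeat();
    var card = dataset.cardinality();   // scalar int64 tensor
    // relies on NDArray's implicit bool conversion, as in the unit test below
    bool isInfinite = (card == tf.data.INFINITE_CARDINALITY).numpy();
    bool isUnknown = (card == tf.data.UNKNOWN_CARDINALITY).numpy();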


+ 26  - 26    src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs

@@ -29,48 +29,48 @@ namespace Tensorflow.Contexts
    /// </summary>
    public sealed partial class Context
    {
-       // [DebuggerStepThrough]
-       public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args)
-       {
-           Func<Tensors> graphAction = () =>
-           {
-               var keywords = new Dictionary<string, object>();
-               if(args.OpInputArgs != null)
-               {
-                   foreach (var (i, input) in enumerate(args.OpInputArgs))
-                       keywords[$"input_{i}"] = input;
-               }
+       Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args)
+       {
+           var keywords = new Dictionary<string, object>();
+           if (args.OpInputArgs != null)
+           {
+               foreach (var (i, input) in enumerate(args.OpInputArgs))
+                   keywords[$"input_{i}"] = input;
+           }

-               if(args.OpAttrs != null)
-               {
-                   foreach (var attr in args.OpAttrs)
-                       keywords[attr.Key] = attr.Value;
-               }
+           if (args.OpAttrs != null)
+           {
+               foreach (var attr in args.OpAttrs)
+                   keywords[attr.Key] = attr.Value;
+           }

-               return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
-           };
+           return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
+       }

-           Func<Tensors> eagerAction = () =>
-           {
-               var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
-               {
-                   attrs = args.OpAttrs
-               };
-               return tf.Runner.TFE_FastPathExecute(opExecInfo);
-           };
+       Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args)
+       {
+           var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
+           {
+               attrs = args.OpAttrs
+           };
+           return tf.Runner.TFE_FastPathExecute(opExecInfo);
+       }

+       // [DebuggerStepThrough]
+       public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args)
+       {
            if (tf.Context.has_graph_arg(args.OpInputArgs))
            {
                if (executing_eagerly())
                {
                    graph_mode();
-                   var result = graphAction();
+                   var result = ExecGraphAction(opType, name, args);
                    restore_mode();
                    return result;
                }
                else
                {
-                   var result = graphAction();
+                   var result = ExecGraphAction(opType, name, args);
                    if (tf.Runner.MustRecordGradient())
                    {
                        var op = result[0].op;
@@ -92,14 +92,14 @@ namespace Tensorflow.Contexts
                            args1[i + 1] = arg.Value;
                            i += 2;
                        }
-                       tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs);
+                       tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs);
                    }
                    return result;
                }
            }
            else
            {
-               return eagerAction();
+               return ExecEagerAction(opType, name, args);
            }
        }
    }


+ 10  - 2    src/TensorFlowNET.Core/Data/DatasetV2.cs

@@ -70,6 +70,12 @@ namespace Tensorflow
                num_parallel_calls: num_parallel_calls,
                preserve_cardinality: true);

+       public IDatasetV2 filter(Func<Tensors, Tensors> predicate_func)
+           => new FilterDataset(this, predicate_func);
+
+       public IDatasetV2 filter(Func<Tensor, bool> predicate_func)
+           => new FilterDataset(this, predicate_func);
+
        public OwnedIterator make_one_shot_iterator()
        {
            if (tf.Context.executing_eagerly())
@@ -105,13 +111,15 @@ namespace Tensorflow
            // (3) Apply graph rewrite options
            var graph_rewrites = new[]
            {
-               "noop_elimination",
                "map_and_batch_fusion",
+               "map_parallelization",
+               "noop_elimination",
                "shuffle_and_repeat_fusion"
            };
            var graph_rewrite_configs = new string[]
            {
                "autotune_buffer_sizes:autotune:true",
                "batch_parallelization:autotune:true",
                "disable_prefetch_legacy_autotune:autotune:true",
                "enable_gradient_descent:autotune:true",
                "map_parallelization:autotune:true"
@@ -124,7 +132,7 @@ namespace Tensorflow
            return dataset;
        }

-       public Tensor dataset_cardinality(string name = null)
+       public Tensor cardinality(string name = null)
            => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));

        public override string ToString()
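
Not part of the commit, but a short usage sketch of the two filter overloads added above; the element-wise predicate is hypothetical and assumes tf.greater is available:

    var ds = tf.data.Dataset.range(10);
    // Func<Tensor, bool> overload: the C# result is traced as a constant predicate.
    var keepAll = ds.filter(x => true);
    // Func<Tensors, Tensors> overload: build the boolean predicate tensor yourself.
    var bigOnly = ds.filter(x => new Tensors(tf.greater(x[0], tf.constant(5L))));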


+ 58  - 0    src/TensorFlowNET.Core/Data/FilterDataset.cs

@@ -0,0 +1,58 @@
using System;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
{
    /// <summary>
    /// A `Dataset` that filters its input according to a predicate function.
    /// </summary>
    public class FilterDataset : UnaryDataset
    {
        public FilterDataset(IDatasetV2 input_dataset,
            Func<Tensor, bool> predicate_func) : base(input_dataset)
        {
            Func<Tensors, Tensors> predicate_func_update = x =>
            {
                var result = predicate_func(x);
                return constant_op.constant(result);
            };

            var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
            func.Enter();
            var inputs = new Tensors();
            foreach (var input in input_dataset.element_spec)
                inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
            var outputs = predicate_func_update(inputs);
            func.ToGraph(inputs, outputs);
            func.Exit();

            structure = func.OutputStructure;

            variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
                func,
                output_types,
                output_shapes);
        }

        public FilterDataset(IDatasetV2 input_dataset,
            Func<Tensors, Tensors> predicate_func) : base(input_dataset)
        {
            var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
            func.Enter();
            var inputs = new Tensors();
            foreach (var input in input_dataset.element_spec)
                inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
            var outputs = predicate_func(inputs);
            func.ToGraph(inputs, outputs);
            func.Exit();

            structure = func.OutputStructure;

            variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
                func,
                output_types,
                output_shapes);
        }
    }
}

+ 4  - 1    src/TensorFlowNET.Core/Data/IDatasetV2.cs

@@ -72,6 +72,9 @@ namespace Tensorflow
        IDatasetV2 map(Func<Tensors, Tensors> map_func,
            int num_parallel_calls);

+       IDatasetV2 filter(Func<Tensors, Tensors> map_func);
+       IDatasetV2 filter(Func<Tensor, bool> map_func);
+
        OwnedIterator make_one_shot_iterator();

        IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
@@ -91,6 +94,6 @@ namespace Tensorflow
        /// </summary>
        /// <param name="name"></param>
        /// <returns></returns>
-       Tensor dataset_cardinality(string name = null);
+       Tensor cardinality(string name = null);
    }
}

+ 19  - 0    src/TensorFlowNET.Core/Operations/dataset_ops.cs

@@ -249,6 +249,25 @@ namespace Tensorflow
                    preserve_cardinality
                }));

+       /// <summary>
+       /// Creates a dataset containing elements of `input_dataset` matching `predicate`.
+       /// </summary>
+       /// <param name="dataset"></param>
+       /// <param name="predicate"></param>
+       /// <param name="output_types"></param>
+       /// <param name="output_shapes"></param>
+       /// <param name="name"></param>
+       /// <returns></returns>
+       public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, TensorShape[] output_shapes,
+           string name = null)
+           => tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0])
+               .SetAttributes(new
+               {
+                   predicate,
+                   output_types,
+                   output_shapes
+               }));
+
        /// <summary>
        /// Creates a dataset that applies `f` to the outputs of `input_dataset`.
        /// </summary>


+ 13  - 0    src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs

@@ -0,0 +1,13 @@
using System;
using System.Runtime.CompilerServices;

namespace Tensorflow
{
    public partial class Tensor
    {
        public static Tensor operator !=(Tensor x, int y)
            => gen_math_ops.not_equal(x, math_ops.cast(y, dtype: x.dtype));
        public static Tensor operator ==(Tensor x, int y)
            => gen_math_ops.equal(x, math_ops.cast(y, dtype: x.dtype));
    }
}
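
A hedged illustration (not from the commit) of what these operators produce: the int is cast to the tensor's dtype and compared element-wise, yielding a boolean tensor:

    var t = tf.constant(new long[] { 1, 2, 3 });
    var eq = t == 2;   // bool tensor: [false, true, false]
    var ne = t != 2;   // bool tensor: [true, false, true]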

+ 6  - 0    src/TensorFlowNET.Core/Tensors/constant_op.cs

@@ -144,6 +144,12 @@ namespace Tensorflow
                        break;
                    }
                }
+               else if (dtype != TF_DataType.DtInvalid &&
+                   value is NDArray nd &&
+                   dtypes.as_dtype(nd.dtype) != dtype)
+               {
+                   value = nd.astype(dtype.as_numpy_dtype());
+               }

                if (dtype == TF_DataType.TF_STRING && value is byte[] bytes)
                    return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING);
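
For illustration only (not in the diff): with this new branch, an NDArray whose dtype differs from the requested dtype is cast before the tensor is built. The np.array call and the float64-to-float32 case below are assumptions, not taken from the commit:

    var nd = np.array(new[] { 1.0, 2.0, 3.0 });            // float64 NDArray
    var t = tf.constant(nd, dtype: TF_DataType.TF_FLOAT);  // value is cast to float32 first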


+ 1  - 1    src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

@@ -87,7 +87,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
            if (adapter_steps > -1)
                return adapter_steps;

-           var size = dataset.dataset_cardinality();
+           var size = dataset.cardinality();
            return size.numpy();
}



+ 16  - 3    test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

@@ -147,10 +147,10 @@ namespace TensorFlowNET.UnitTest.Dataset
        public void Cardinality()
        {
            var dataset = tf.data.Dataset.range(10);
-           var cardinality = dataset.dataset_cardinality();
+           var cardinality = dataset.cardinality();
            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
            dataset = dataset.map(x => x[0] + 1);
-           cardinality = dataset.dataset_cardinality();
+           cardinality = dataset.cardinality();
            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
        }

@@ -159,10 +159,23 @@ namespace TensorFlowNET.UnitTest.Dataset
        {
            var dataset = tf.data.Dataset.range(10);
            dataset = dataset.map(x => x, num_parallel_calls: -1);
-           var cardinality = dataset.dataset_cardinality();
+           var cardinality = dataset.cardinality();
            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
        }

+       [TestMethod]
+       public void CardinalityWithRepeat()
+       {
+           var dataset = tf.data.Dataset.range(10);
+           dataset = dataset.repeat();
+           var cardinality = dataset.cardinality();
+           Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy());
+
+           dataset = dataset.filter(x => true);
+           cardinality = dataset.cardinality();
+           Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy());
+       }
+
        [TestMethod]
        public void Shuffle()
        {

