Browse Source

fix tensor_slice_dataset.

tags/v0.30
Oceania2018 5 years ago
parent
commit
2763f7c433
13 changed files with 205 additions and 20 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Data/DatasetManager.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +24
    -0
      src/TensorFlowNET.Core/Data/FlatMapDataset.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  5. +9
    -0
      src/TensorFlowNET.Core/Data/TensorDataset.cs
  6. +9
    -0
      src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
  7. +18
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  8. +6
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs
  9. +19
    -14
      src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs
  10. +51
    -1
      src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  11. +25
    -2
      src/TensorFlowNET.Core/Keras/Engine/Model.cs
  12. +32
    -1
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Operations/random_ops.cs

+ 6
- 0
src/TensorFlowNET.Core/Data/DatasetManager.cs View File

@@ -19,12 +19,18 @@ namespace Tensorflow
public IDatasetV2 from_tensor(NDArray tensors) public IDatasetV2 from_tensor(NDArray tensors)
=> new TensorDataset(tensors); => new TensorDataset(tensors);


/// <summary>
/// Builds a <see cref="TensorDataset"/> from a features tensor and a labels tensor.
/// </summary>
/// <param name="features">Tensor passed as the first constructor argument.</param>
/// <param name="labels">Tensor passed as the second constructor argument.</param>
/// <returns>The newly constructed dataset.</returns>
public IDatasetV2 from_tensor(Tensor features, Tensor labels)
{
    return new TensorDataset(features, labels);
}

public IDatasetV2 from_tensor(Tensor tensors) public IDatasetV2 from_tensor(Tensor tensors)
=> new TensorDataset(tensors); => new TensorDataset(tensors);


public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
=> new TensorSliceDataset(features, labels); => new TensorSliceDataset(features, labels);


/// <summary>
/// Builds a <see cref="TensorSliceDataset"/> from a single tensor.
/// </summary>
/// <param name="tensor">Tensor handed to the <see cref="TensorSliceDataset"/> constructor.</param>
/// <returns>The newly constructed dataset.</returns>
public IDatasetV2 from_tensor_slices(Tensor tensor)
{
    return new TensorSliceDataset(tensor);
}

public IDatasetV2 from_tensor_slices(string[] array) public IDatasetV2 from_tensor_slices(string[] array)
=> new TensorSliceDataset(array); => new TensorSliceDataset(array);




+ 3
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -60,6 +60,9 @@ namespace Tensorflow
preserve_cardinality: preserve_cardinality, preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function); use_legacy_function: use_legacy_function);


/// <summary>
/// Wraps this dataset in a <see cref="FlatMapDataset"/> driven by <paramref name="map_func"/>.
/// </summary>
/// <param name="map_func">Function that takes a tensor and returns a dataset.</param>
/// <returns>The new <see cref="FlatMapDataset"/>, typed as <see cref="IDatasetV2"/>.</returns>
public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
{
    return new FlatMapDataset(this, map_func);
}

public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
=> new ModelDataset(this, algorithm, cpu_budget); => new ModelDataset(this, algorithm, cpu_budget);




+ 24
- 0
src/TensorFlowNET.Core/Data/FlatMapDataset.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Functions;

namespace Tensorflow
{
/// <summary>
/// Unary dataset whose variant tensor comes from <c>ops.flat_map_dataset</c>,
/// applied to an input dataset and a mapping function.
/// </summary>
public class FlatMapDataset : UnaryDataset
{
    /// <summary>
    /// Wraps <paramref name="map_func"/> in a <see cref="ConcreteFunction"/> and
    /// creates the flat-map variant tensor over <paramref name="input_dataset"/>.
    /// </summary>
    /// <param name="input_dataset">Dataset supplying the input variant tensor and element spec.</param>
    /// <param name="map_func">Function that takes a tensor and returns a dataset.</param>
    public FlatMapDataset(IDatasetV2 input_dataset,
        Func<Tensor, IDatasetV2> map_func) : base(input_dataset)
    {
        // Only the dtype of the first element spec is handed to the function wrapper.
        var element_dtype = input_dataset.element_spec[0].dtype;
        var wrapped_func = new ConcreteFunction(map_func, element_dtype);

        variant_tensor = ops.flat_map_dataset(
            input_dataset.variant_tensor,
            wrapped_func,
            output_types,
            output_shapes);
    }
}
}

+ 2
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -62,6 +62,8 @@ namespace Tensorflow
bool preserve_cardinality = false, bool preserve_cardinality = false,
bool use_legacy_function = false); bool use_legacy_function = false);


/// <summary>
/// Builds a dataset from this one using <paramref name="map_func"/>.
/// </summary>
/// <param name="map_func">Function that takes a tensor and returns a dataset.</param>
/// <returns>The resulting dataset.</returns>
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);

IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);


/// <summary> /// <summary>


+ 9
- 0
src/TensorFlowNET.Core/Data/TensorDataset.cs View File

@@ -12,6 +12,15 @@ namespace Tensorflow
/// </summary> /// </summary>
public class TensorDataset : DatasetSource public class TensorDataset : DatasetSource
{ {
/// <summary>
/// Creates a dataset backed by a feature tensor and a label tensor.
/// </summary>
/// <param name="feature">Stored as the first tensor component.</param>
/// <param name="label">Stored as the second tensor component.</param>
public TensorDataset(Tensor feature, Tensor label)
{
    _tensors = new[] { feature, label };

    // Two passes: first a spec per tensor, then _unbatch() on every spec.
    var component_specs = _tensors
        .Select(component => component.ToTensorSpec())
        .ToArray();
    structure = component_specs
        .Select(spec => spec._unbatch())
        .ToArray();

    variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
}
public TensorDataset(Tensor element) public TensorDataset(Tensor element)
{ {
_tensors = new[] { element }; _tensors = new[] { element };


+ 9
- 0
src/TensorFlowNET.Core/Data/TensorSliceDataset.cs View File

@@ -31,6 +31,15 @@ namespace Tensorflow.Data
variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
} }


/// <summary>
/// Creates a slice dataset backed by a single tensor.
/// </summary>
/// <param name="tensor">The only tensor component of the dataset.</param>
public TensorSliceDataset(Tensor tensor)
{
    _tensors = new[] { tensor };

    // One component, so the structure is the single unbatched spec of that tensor.
    var unbatched_spec = tensor.ToTensorSpec()._unbatch();
    structure = new[] { unbatched_spec };

    variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
}

public TensorSliceDataset(Tensor features, Tensor labels) public TensorSliceDataset(Tensor features, Tensor labels)
{ {
_tensors = new[] { features, labels }; _tensors = new[] { features, labels };


+ 18
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -33,6 +33,24 @@ namespace Tensorflow.Functions
} }
} }


/// <summary>
/// Builds the function handle from a delegate that maps a tensor to a dataset.
/// </summary>
/// <param name="func">Delegate invoked once with a placeholder tensor.</param>
/// <param name="dtype">Data type of the placeholder passed to <paramref name="func"/>.</param>
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
{
    // The GUID makes the generated name differ between instances.
    string graph_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

    using (var func_graph = new FuncGraph(graph_name))
    {
        var placeholder = tf.placeholder(dtype);

        // The returned dataset is not used; presumably the call matters for the
        // nodes it records into the function graph - TODO confirm.
        func(placeholder);

        var operations = func_graph._nodes_by_name.Values
            .Select(node => node as Operation)
            .ToArray();

        // NOTE(review): the outputs array is empty - confirm this is intended.
        _handle = func_graph.ToGraph(
            operations,
            new Operation[] { placeholder },
            new Operation[] { },
            null);
    }
}

public Tensor Execute(Tensor arg) public Tensor Execute(Tensor arg)
{ {
var result = tf.Runner.TFE_Execute(tf.Context, var result = tf.Runner.TFE_Execute(tf.Context,


+ 6
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
@@ -8,9 +9,13 @@ namespace Tensorflow.Keras.ArgsDefinition
{ {
public Tensor X { get; set; } public Tensor X { get; set; }
public Tensor Y { get; set; } public Tensor Y { get; set; }
public int BatchSize { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; } public int Steps { get; set; }
public int Epochs { get; set; } public int Epochs { get; set; }
public bool Shuffle { get; set; } public bool Shuffle { get; set; }
public int MaxQueueSize { get; set; }
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public Model Model { get; set; }
} }
} }

+ 19
- 14
src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -11,25 +11,30 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public class DataHandler public class DataHandler
{ {
DataHandlerArgs args; DataHandlerArgs args;

Tensor x => args.X;
Tensor y => args.Y;
int batch_size => args.BatchSize;
int steps_per_epoch => args.StepsPerEpoch;
int initial_epoch => args.InitialEpoch;
int epochs => args.Epochs;
bool shuffle => args.Shuffle;
int max_queue_size => args.MaxQueueSize;
int workers => args.Workers;
bool use_multiprocessing => args.UseMultiprocessing;
Model model => args.Model;
IVariableV1 steps_per_execution => args.StepsPerExecution;
IDataAdapter _adapter;


public DataHandler(DataHandlerArgs args) public DataHandler(DataHandlerArgs args)
{ {
this.args = args; this.args = args;


var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { });
_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs
{
X = args.X,
Y = args.Y,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
}

Tensor _infer_steps(IDatasetV2 dataset)
{
throw new NotImplementedException("");
} }
} }
} }

+ 51
- 1
src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -11,14 +11,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters
/// </summary> /// </summary>
public class TensorLikeDataAdapter : IDataAdapter public class TensorLikeDataAdapter : IDataAdapter
{ {
TensorLikeDataAdapterArgs args;
int _size;
int _batch_size;
int num_samples;
int num_full_batches;

public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
{ {
tf.data.Dataset.range(5);
this.args = args;
_process_tensorlike();
num_samples = args.X.shape[0];
var batch_size = args.BatchSize;
_batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f)));
num_full_batches = num_samples / batch_size;
var _partial_batch_size = num_samples % batch_size;

var indices_dataset = tf.data.Dataset.range(1);
indices_dataset = indices_dataset.repeat();
indices_dataset = indices_dataset.map(permutation).prefetch(1);
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
}

/// <summary>
/// Produces the int64 range of sample indices, shuffled when <c>args.Shuffle</c> is set.
/// </summary>
/// <param name="tensor">Not read by this method.</param>
/// <returns>The index tensor, in range order or shuffled.</returns>
Tensor permutation(Tensor tensor)
{
    var sample_indices = math_ops.range(num_samples, dtype: dtypes.int64);

    return args.Shuffle
        ? random_ops.random_shuffle(sample_indices)
        : sample_indices;
}

/// <summary>
/// Turns a tensor of indices into a dataset built from rows of batched indices.
/// </summary>
/// <param name="indices">Index tensor to slice and reshape.</param>
/// <returns>Dataset created by <c>from_tensor_slices</c> over the reshaped indices.</returns>
IDatasetV2 slice_batch_indices(Tensor indices)
{
    // Only the leading indices that fill complete batches are kept.
    var full_batch_total = num_full_batches * _batch_size;
    var begin = new int[] { 0 };
    var size = new int[] { full_batch_total };

    // Reshape the kept indices to { num_full_batches, _batch_size }.
    var batched_indices = array_ops.reshape(
        array_ops.slice(indices, begin, size),
        new int[] { num_full_batches, _batch_size });

    return tf.data.Dataset.from_tensor_slices(batched_indices);
}

void slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y)
{
var dataset = tf.data.Dataset.from_tensor(x, y);
} }


public bool CanHandle(Tensor x, Tensor y = null) public bool CanHandle(Tensor x, Tensor y = null)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }

/// <summary>
/// Currently empty: the method contains no statements and has no effect.
/// </summary>
void _process_tensorlike()
{
}
} }
} }

+ 25
- 2
src/TensorFlowNET.Core/Keras/Engine/Model.cs View File

@@ -84,14 +84,37 @@ namespace Tensorflow.Keras.Engine
/// <param name="verbose"></param> /// <param name="verbose"></param>
/// <param name="validation_split"></param> /// <param name="validation_split"></param>
/// <param name="shuffle"></param> /// <param name="shuffle"></param>
public void fit(NDArray x, NDArray y,
public void fit(NDArray x, NDArray y,
int batch_size = -1, int batch_size = -1,
int epochs = 1, int epochs = 1,
int verbose = 1, int verbose = 1,
float validation_split = 0f, float validation_split = 0f,
bool shuffle = true)
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{ {
int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];
var val_x = x[new Slice(train_count)];
var val_y = y[new Slice(train_count)];


var data_handler = new DataHandler(new DataHandlerArgs
{
X = train_x,
Y = train_y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Model = this,
StepsPerExecution = _steps_per_execution
});
} }


void _configure_steps_per_execution(int steps_per_execution) void _configure_steps_per_execution(int steps_per_execution)


+ 32
- 1
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -49,7 +49,11 @@ namespace Tensorflow
return results[0]; return results[0];
} }


throw new NotImplementedException("");
var _op = tf.OpDefLib._apply_op_helper("TensorSliceDataset",
name: name,
args: new { components, output_shapes });

return _op.outputs[0];
} }


public Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) public Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null)
@@ -440,6 +444,33 @@ namespace Tensorflow
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


/// <summary>
/// Creates a dataset that applies `f` to the outputs of `input_dataset`.
/// Only eager execution is handled; otherwise a <see cref="NotImplementedException"/> is thrown.
/// </summary>
/// <param name="dataset">Variant tensor of the input dataset.</param>
/// <param name="f">Function passed as the op's "f" attribute.</param>
/// <param name="output_types">Passed as the op's "output_types" attribute.</param>
/// <param name="output_shapes">Passed as the op's "output_shapes" attribute.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The first result of the eager "FlatMapDataset" execution.</returns>
public Tensor flat_map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes,
    string name = null)
{
    // No non-eager path exists yet.
    if (!tf.Context.executing_eagerly())
    {
        throw new NotImplementedException("");
    }

    // The empty tensor array is presumably the op's extra-arguments input - TODO confirm.
    var eager_outputs = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "FlatMapDataset", name,
        null,
        dataset, new Tensor[0],
        "f", f,
        "output_types", output_types,
        "output_shapes", output_shapes);

    return eager_outputs[0];
}

/// <summary> /// <summary>
/// A container for an iterator resource. /// A container for an iterator resource.
/// </summary> /// </summary>


+ 1
- 1
src/TensorFlowNET.Core/Operations/random_ops.cs View File

@@ -116,7 +116,7 @@ namespace Tensorflow
public static Tensor random_shuffle(Tensor value, int? seed = null, string name = null) public static Tensor random_shuffle(Tensor value, int? seed = null, string name = null)
{ {
var (seed1, seed2) = random_seed.get_seed(seed); var (seed1, seed2) = random_seed.get_seed(seed);
return gen_random_ops.random_shuffle(value, seed: seed1.Value, seed2: seed2.Value, name: name);
return gen_random_ops.random_shuffle(value, seed: seed1 ?? 0, seed2: seed2 ?? 0, name: name);
} }


public static Tensor truncated_normal(int[] shape, public static Tensor truncated_normal(int[] shape,


Loading…
Cancel
Save