| @@ -19,12 +19,18 @@ namespace Tensorflow | |||
| public IDatasetV2 from_tensor(NDArray tensors) | |||
| => new TensorDataset(tensors); | |||
| public IDatasetV2 from_tensor(Tensor features, Tensor labels) | |||
| => new TensorDataset(features, labels); | |||
| public IDatasetV2 from_tensor(Tensor tensors) | |||
| => new TensorDataset(tensors); | |||
| public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | |||
| => new TensorSliceDataset(features, labels); | |||
| public IDatasetV2 from_tensor_slices(Tensor tensor) | |||
| => new TensorSliceDataset(tensor); | |||
| public IDatasetV2 from_tensor_slices(string[] array) | |||
| => new TensorSliceDataset(array); | |||
| @@ -60,6 +60,9 @@ namespace Tensorflow | |||
| preserve_cardinality: preserve_cardinality, | |||
| use_legacy_function: use_legacy_function); | |||
| public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) | |||
| => new FlatMapDataset(this, map_func); | |||
| public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) | |||
| => new ModelDataset(this, algorithm, cpu_budget); | |||
| @@ -0,0 +1,24 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Functions; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| public class FlatMapDataset : UnaryDataset | |||
| { | |||
| public FlatMapDataset(IDatasetV2 input_dataset, | |||
| Func<Tensor, IDatasetV2> map_func) : base(input_dataset) | |||
| { | |||
| var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); | |||
| variant_tensor = ops.flat_map_dataset(input_dataset.variant_tensor, | |||
| func, | |||
| output_types, | |||
| output_shapes); | |||
| } | |||
| } | |||
| } | |||
| @@ -62,6 +62,8 @@ namespace Tensorflow | |||
| bool preserve_cardinality = false, | |||
| bool use_legacy_function = false); | |||
| IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | |||
| IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); | |||
| /// <summary> | |||
| @@ -12,6 +12,15 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class TensorDataset : DatasetSource | |||
| { | |||
| public TensorDataset(Tensor feature, Tensor label) | |||
| { | |||
| _tensors = new[] { feature, label }; | |||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||
| structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||
| variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | |||
| } | |||
| public TensorDataset(Tensor element) | |||
| { | |||
| _tensors = new[] { element }; | |||
| @@ -31,6 +31,15 @@ namespace Tensorflow.Data | |||
| variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); | |||
| } | |||
| public TensorSliceDataset(Tensor tensor) | |||
| { | |||
| _tensors = new[] { tensor }; | |||
| var batched_spec = new[] { tensor.ToTensorSpec() }; | |||
| structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||
| variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); | |||
| } | |||
| public TensorSliceDataset(Tensor features, Tensor labels) | |||
| { | |||
| _tensors = new[] { features, labels }; | |||
| @@ -33,6 +33,24 @@ namespace Tensorflow.Functions | |||
| } | |||
| } | |||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
| { | |||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
| // IntPtr func_handle; | |||
| using (var graph = new FuncGraph(func_name)) | |||
| { | |||
| var input = tf.placeholder(dtype); | |||
| var output = func(input); | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| _handle = graph.ToGraph(opers, | |||
| new Operation[] { input }, | |||
| new Operation[] { }, | |||
| null); | |||
| } | |||
| } | |||
| public Tensor Execute(Tensor arg) | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| @@ -8,9 +9,13 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
| { | |||
| public Tensor X { get; set; } | |||
| public Tensor Y { get; set; } | |||
| public int BatchSize { get; set; } | |||
| public int BatchSize { get; set; } = 32; | |||
| public int Steps { get; set; } | |||
| public int Epochs { get; set; } | |||
| public bool Shuffle { get; set; } | |||
| public int MaxQueueSize { get; set; } | |||
| public int Worker { get; set; } | |||
| public bool UseMultiprocessing { get; set; } | |||
| public Model Model { get; set; } | |||
| } | |||
| } | |||
| @@ -11,25 +11,30 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| public class DataHandler | |||
| { | |||
| DataHandlerArgs args; | |||
| Tensor x => args.X; | |||
| Tensor y => args.Y; | |||
| int batch_size => args.BatchSize; | |||
| int steps_per_epoch => args.StepsPerEpoch; | |||
| int initial_epoch => args.InitialEpoch; | |||
| int epochs => args.Epochs; | |||
| bool shuffle => args.Shuffle; | |||
| int max_queue_size => args.MaxQueueSize; | |||
| int workers => args.Workers; | |||
| bool use_multiprocessing => args.UseMultiprocessing; | |||
| Model model => args.Model; | |||
| IVariableV1 steps_per_execution => args.StepsPerExecution; | |||
| IDataAdapter _adapter; | |||
| public DataHandler(DataHandlerArgs args) | |||
| { | |||
| this.args = args; | |||
| var adapter_cls = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs { }); | |||
| _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs | |||
| { | |||
| X = args.X, | |||
| Y = args.Y, | |||
| BatchSize = args.BatchSize, | |||
| Steps = args.StepsPerEpoch, | |||
| Epochs = args.Epochs - args.InitialEpoch, | |||
| Shuffle = args.Shuffle, | |||
| MaxQueueSize = args.MaxQueueSize, | |||
| Worker = args.Workers, | |||
| UseMultiprocessing = args.UseMultiprocessing, | |||
| Model = args.Model | |||
| }); | |||
| } | |||
| Tensor _infer_steps(IDatasetV2 dataset) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| } | |||
| } | |||
| @@ -11,14 +11,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| /// </summary> | |||
| public class TensorLikeDataAdapter : IDataAdapter | |||
| { | |||
| TensorLikeDataAdapterArgs args; | |||
| int _size; | |||
| int _batch_size; | |||
| int num_samples; | |||
| int num_full_batches; | |||
| public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | |||
| { | |||
| tf.data.Dataset.range(5); | |||
| this.args = args; | |||
| _process_tensorlike(); | |||
| num_samples = args.X.shape[0]; | |||
| var batch_size = args.BatchSize; | |||
| _batch_size = batch_size; | |||
| _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f))); | |||
| num_full_batches = num_samples / batch_size; | |||
| var _partial_batch_size = num_samples % batch_size; | |||
| var indices_dataset = tf.data.Dataset.range(1); | |||
| indices_dataset = indices_dataset.repeat(); | |||
| indices_dataset = indices_dataset.map(permutation).prefetch(1); | |||
| indices_dataset = indices_dataset.flat_map(slice_batch_indices); | |||
| } | |||
| Tensor permutation(Tensor tensor) | |||
| { | |||
| var indices = math_ops.range(num_samples, dtype: dtypes.int64); | |||
| if (args.Shuffle) | |||
| indices = random_ops.random_shuffle(indices); | |||
| return indices; | |||
| } | |||
| /// <summary> | |||
| /// Convert a Tensor of indices into a dataset of batched indices. | |||
| /// </summary> | |||
| /// <param name="tensor"></param> | |||
| /// <returns></returns> | |||
| IDatasetV2 slice_batch_indices(Tensor indices) | |||
| { | |||
| var num_in_full_batch = num_full_batches * _batch_size; | |||
| var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | |||
| first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | |||
| var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | |||
| return flat_dataset; | |||
| } | |||
| void slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y) | |||
| { | |||
| var dataset = tf.data.Dataset.from_tensor(x, y); | |||
| } | |||
| public bool CanHandle(Tensor x, Tensor y = null) | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| void _process_tensorlike() | |||
| { | |||
| } | |||
| } | |||
| } | |||
| @@ -84,14 +84,37 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="verbose"></param> | |||
| /// <param name="validation_split"></param> | |||
| /// <param name="shuffle"></param> | |||
| public void fit(NDArray x, NDArray y, | |||
| public void fit(NDArray x, NDArray y, | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| float validation_split = 0f, | |||
| bool shuffle = true) | |||
| bool shuffle = true, | |||
| int initial_epoch = 0, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| bool use_multiprocessing = false) | |||
| { | |||
| int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); | |||
| var train_x = x[new Slice(0, train_count)]; | |||
| var train_y = y[new Slice(0, train_count)]; | |||
| var val_x = x[new Slice(train_count)]; | |||
| var val_y = y[new Slice(train_count)]; | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = train_x, | |||
| Y = train_y, | |||
| BatchSize = batch_size, | |||
| InitialEpoch = initial_epoch, | |||
| Epochs = epochs, | |||
| Shuffle = shuffle, | |||
| MaxQueueSize = max_queue_size, | |||
| Workers = workers, | |||
| UseMultiprocessing = use_multiprocessing, | |||
| Model = this, | |||
| StepsPerExecution = _steps_per_execution | |||
| }); | |||
| } | |||
| void _configure_steps_per_execution(int steps_per_execution) | |||
| @@ -49,7 +49,11 @@ namespace Tensorflow | |||
| return results[0]; | |||
| } | |||
| throw new NotImplementedException(""); | |||
| var _op = tf.OpDefLib._apply_op_helper("TensorSliceDataset", | |||
| name: name, | |||
| args: new { components, output_shapes }); | |||
| return _op.outputs[0]; | |||
| } | |||
| public Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) | |||
| @@ -440,6 +444,33 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// Creates a dataset that applies `f` to the outputs of `input_dataset`. | |||
| /// </summary> | |||
| /// <param name="dataset"></param> | |||
| /// <param name="f"></param> | |||
| /// <param name="output_types"></param> | |||
| /// <param name="output_shapes"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor flat_map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes, | |||
| string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "FlatMapDataset", name, | |||
| null, | |||
| dataset, new Tensor[0], | |||
| "f", f, | |||
| "output_types", output_types, | |||
| "output_shapes", output_shapes); | |||
| return results[0]; | |||
| } | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// A container for an iterator resource. | |||
| /// </summary> | |||
| @@ -116,7 +116,7 @@ namespace Tensorflow | |||
| public static Tensor random_shuffle(Tensor value, int? seed = null, string name = null) | |||
| { | |||
| var (seed1, seed2) = random_seed.get_seed(seed); | |||
| return gen_random_ops.random_shuffle(value, seed: seed1.Value, seed2: seed2.Value, name: name); | |||
| return gen_random_ops.random_shuffle(value, seed: seed1 ?? 0, seed2: seed2 ?? 0, name: name); | |||
| } | |||
| public static Tensor truncated_normal(int[] shape, | |||