@@ -66,7 +66,9 @@ namespace Tensorflow
             use_legacy_function: use_legacy_function);

         public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
-            => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
+            => new ParallelMapDataset(this, map_func,
+                num_parallel_calls: num_parallel_calls,
+                preserve_cardinality: true);

         public OwnedIterator make_one_shot_iterator()
         {
@@ -15,18 +15,26 @@ namespace Tensorflow
             bool preserve_cardinality = false,
             bool use_legacy_function = false) : base(input_dataset)
         {
-            var func = new ConcreteFunction(map_func,
-                input_dataset.element_spec.Select(x => x.dtype).ToArray(),
-                input_dataset.element_spec.Select(x => x.shape).ToArray());
+            var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");
+            func.Enter();
+            var inputs = new Tensors();
+            foreach (var input in input_dataset.element_spec)
+                inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
+            var outputs = map_func(inputs);
+            func.ToGraph(inputs, outputs);
+            func.Exit();
             structure = func.OutputStructure;
             var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64,
                 name: "num_parallel_calls");
             variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor,
                 _num_parallel_calls,
                 func,
                 output_types,
-                output_shapes);
+                output_shapes,
+                use_inter_op_parallelism: use_inter_op_parallelism,
+                preserve_cardinality: preserve_cardinality);
         }
     }
}
@@ -71,7 +71,7 @@ namespace Tensorflow.Functions
             func_graph.Exit();
         }

-        public ConcreteFunction(Func<Tensors, Tensors> func,
+        /*public ConcreteFunction(Func<Tensors, Tensors> func,
             TF_DataType[] dtypes, TensorShape[] shapes)
         {
             string func_name = $"{func.Method.Name}_{ops.uid_function()}";
@@ -89,7 +89,7 @@ namespace Tensorflow.Functions
             var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
             func_graph.ToGraph(opers, inputs, Outputs, null);
             func_graph.Exit();
-        }
+        }*/

         public void ToGraph(Tensors inputs, Tensors outputs)
         {
@@ -38,6 +38,8 @@ namespace Tensorflow
             }
         }
+        public Tensor this[params string[] slices]
+            => items.First()[slices];

         public Tensors(params Tensor[] tensors)
         {
             items.AddRange(tensors);
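`Tensors` gains a string-slice indexer that simply forwards to its first tensor. A hedged usage sketch, assuming `Tensor` itself accepts this slice syntax:

    var ts = new Tensors(tf.constant(new[] { 1, 2, 3, 4 }));
    var tail = ts["1:"];   // forwards to items.First()["1:"]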
@@ -585,6 +585,10 @@ would not be rank 1.", tensor.op.get_attr("axis")));
                 else
                     return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
             }
+            else if (dtype == TF_DataType.TF_VARIANT)
+            {
+                return "<unprintable>";
+            }

             var nd = tensor.numpy();
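Variant tensors (for example a dataset's `variant_tensor`) are opaque handles with no numeric buffer, so the display path now short-circuits instead of attempting `numpy()` on them. An illustration, assuming `ToString` routes through this helper:

    var ds = tf.data.Dataset.range(3);
    Console.WriteLine(ds.variant_tensor);   // data portion renders as "<unprintable>"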
@@ -100,6 +100,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
                 using var data_iterator = new OwnedIterator(_dataset);
                 yield return (epoch, data_iterator);
             }
+            // _adapter.on_epoch_end()
         }

         public IEnumerable<int> steps()
@@ -33,22 +33,22 @@ namespace Tensorflow.Keras.Engine
         public void compile(string optimizer, string loss, string[] metrics)
         {
-            switch (optimizer)
+            var _optimizer = optimizer switch
             {
-                case "rmsprop":
-                    this.optimizer = new RMSprop(new RMSpropArgs
-                    {
-                    });
-                    break;
-            }
+                "rmsprop" => new RMSprop(new RMSpropArgs
+                {
+                }),
+                _ => throw new NotImplementedException("")
+            };
+            int experimental_steps_per_execution = 1;
+            _configure_steps_per_execution(experimental_steps_per_execution);
+            _reset_compile_cache();
+            var _loss = loss switch
+            {
+                "mse" => new MeanSquaredError(),
+                _ => throw new NotImplementedException("")
+            };
+            _is_compiled = true;
+            compile(optimizer: _optimizer, loss: _loss, metrics: metrics);
         }
     }
}
@@ -49,7 +49,32 @@ namespace Tensorflow.Keras.Engine
             Binding.tf_output_redirect.WriteLine($"Testing...");
             foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
             {
-                // reset_metrics();
+                reset_metrics();
+                // callbacks.on_epoch_begin(epoch)
+                // data_handler.catch_stop_iteration();
+                IEnumerable<(string, Tensor)> results = null;
+                foreach (var step in data_handler.steps())
+                {
+                    // callbacks.on_train_batch_begin(step)
+                    results = test_function(iterator);
+                }
+                Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
+            }
+        }
+
+        public void evaluate(IDatasetV2 x)
+        {
+            data_handler = new DataHandler(new DataHandlerArgs
+            {
+                Dataset = x,
+                Model = this,
+                StepsPerExecution = _steps_per_execution
+            });
+            Binding.tf_output_redirect.WriteLine($"Testing...");
+            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+            {
+                reset_metrics();
                 // callbacks.on_epoch_begin(epoch)
                 // data_handler.catch_stop_iteration();
                 IEnumerable<(string, Tensor)> results = null;
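This adds an `evaluate` overload that consumes a `tf.data` pipeline directly, reusing the same `DataHandler` epoch/step loop as the existing path. A hedged end-to-end sketch (the dataset contents are illustrative; a real pipeline would yield feature/label pairs):

    var data = tf.data.Dataset.range(10).map(x => x[0] + 1);
    model.evaluate(data);   // prints per-epoch metric results via tf_output_redirect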
@@ -124,10 +124,11 @@ namespace Tensorflow.Keras
             var start_positions_tensor = tf.constant(start_positions);
             var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat();
-            var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds);
+            var r = tf.data.Dataset.range(len(start_positions));
+            var z = tf.data.Dataset.zip(r, positions_ds);
             var indices = z.map(m =>
             {
-                var (i, positions) = (m[0], m[1]);
+                var (i, positions) = m;
                 return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
             }, num_parallel_calls: -1);
             var dataset = sequences_from_indices(data, indices, start_index, end_index);
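Both timeseries hunks (this one and the next) replace positional indexing (`m[0]`, `m[1]`) with tuple deconstruction of `Tensors`, which presumably relies on a two-element `Deconstruct` helper. A hedged sketch of the idiom:

    var pair = new Tensors(tf.constant(1), tf.constant(2));
    var (first, second) = pair;   // assumes Tensors exposes Deconstruct(out Tensor, out Tensor)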
@@ -142,7 +143,11 @@ namespace Tensorflow.Keras
         {
             var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]);
             dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds)
-                .map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1);
+                .map(x =>
+                {
+                    var (steps, indx) = x;
+                    return array_ops.gather(steps, indx);
+                }, num_parallel_calls: -1);
             return dataset;
         }
     }
@@ -147,7 +147,18 @@ namespace TensorFlowNET.UnitTest.Dataset
         public void Cardinality()
         {
             var dataset = tf.data.Dataset.range(10);
+            var cardinality = dataset.dataset_cardinality();
+            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
             dataset = dataset.map(x => x[0] + 1);
+            cardinality = dataset.dataset_cardinality();
+            Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
+        }
+
+        [TestMethod]
+        public void CardinalityWithAutoTune()
+        {
+            var dataset = tf.data.Dataset.range(10);
+            dataset = dataset.map(x => x, num_parallel_calls: -1);
             var cardinality = dataset.dataset_cardinality();
             Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
         }