| @@ -3,6 +3,7 @@ using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -98,6 +99,20 @@ namespace Tensorflow | |||||
| return dataset; | return dataset; | ||||
| } | } | ||||
| public Tensor dataset_cardinality(string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "DatasetCardinality", name, | |||||
| null, | |||||
| variant_tensor); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; | => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; | ||||
| @@ -117,7 +132,9 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| } | } | ||||
| yield return (results[0], results.Length == 1 ? null : results[1]); | |||||
| yield return results.Length == 2 | |||||
| ? (results[0], results[1]) | |||||
| : (null, results[0]); | |||||
| } | } | ||||
| } | } | ||||
| @@ -74,5 +74,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IDatasetV2 apply_options(); | IDatasetV2 apply_options(); | ||||
| Tensor dataset_cardinality(string name = null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,11 +15,11 @@ namespace Tensorflow | |||||
| bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
| bool use_legacy_function = false) : base(input_dataset) | bool use_legacy_function = false) : base(input_dataset) | ||||
| { | { | ||||
| var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||||
| var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input"); | |||||
| using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||||
| var input = tf.placeholder(input_dataset.element_spec[0].dtype); | |||||
| var output = map_func(input); | var output = map_func(input); | ||||
| func.ToGraph(input, output); | func.ToGraph(input, output); | ||||
| structure = func.OutputStructure; | structure = func.OutputStructure; | ||||
| variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | ||||
| @@ -130,6 +130,9 @@ namespace Tensorflow.Functions | |||||
| return new ForwardBackwardCall(functions, args, tape_watching: true); | return new ForwardBackwardCall(functions, args, tape_watching: true); | ||||
| } | } | ||||
| public override string ToString() | |||||
| => Name; | |||||
| public void Dispose() | public void Dispose() | ||||
| { | { | ||||
| c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | ||||
| @@ -2,10 +2,11 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class TensorLikeDataAdapterArgs | |||||
| public class DataAdapterArgs | |||||
| { | { | ||||
| public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
| public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
| public IDatasetV2 Dataset { get; set; } | |||||
| public int BatchSize { get; set; } = 32; | public int BatchSize { get; set; } = 32; | ||||
| public int Steps { get; set; } | public int Steps { get; set; } | ||||
| public int Epochs { get; set; } | public int Epochs { get; set; } | ||||
| @@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
| public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
| public IDatasetV2 Dataset { get; set; } | |||||
| public int BatchSize { get; set; } = 32; | public int BatchSize { get; set; } = 32; | ||||
| public int StepsPerEpoch { get; set; } = -1; | public int StepsPerEpoch { get; set; } = -1; | ||||
| public int InitialEpoch { get; set; } = 0; | public int InitialEpoch { get; set; } = 0; | ||||