diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 35e6d3f5..3a4fb7d8 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -52,7 +52,7 @@ namespace Tensorflow public IDatasetV2 map(Func map_func, bool use_inter_op_parallelism = true, - bool preserve_cardinality = false, + bool preserve_cardinality = true, bool use_legacy_function = false) => new MapDataset(this, map_func, diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs index 03c5f026..daa11202 100644 --- a/src/TensorFlowNET.Core/Data/MapDataset.cs +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -1,6 +1,10 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Functions; +using Tensorflow.Graphs; +using static Tensorflow.Binding; namespace Tensorflow { @@ -15,12 +19,10 @@ namespace Tensorflow bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) { - foreach(var input in input_dataset) - { - var data = map_func(input.Item1); - } + var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); variant_tensor = ops.map_dataset(input_dataset.variant_tensor, + func, output_types, output_shapes); } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index f5c0a8ee..84e27cc6 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -6,6 +6,7 @@ using static Tensorflow.Binding; using Tensorflow.Util; using System.Runtime.InteropServices; using Tensorflow.Contexts; +using Tensorflow.Functions; namespace Tensorflow.Eager { @@ -385,7 +386,10 @@ namespace Tensorflow.Eager status.Check(true); break; case TF_AttrType.TF_ATTR_FUNC: - c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length); + if (value is ConcreteFunction func) + c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); + else + throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); break; default: throw new NotImplementedException($"SetOpAttrScalar for {type}"); diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs new file mode 100644 index 00000000..f05bdbc4 --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -0,0 +1,59 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Graphs; +using static Tensorflow.Binding; + +namespace Tensorflow.Functions +{ + /// + /// + /// + public class ConcreteFunction : IDisposable + { + public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); + IntPtr _handle; + + public ConcreteFunction(Func func, TF_DataType dtype) + { + string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; + + tf.compat.v1.disable_eager_execution(); + + // IntPtr func_handle; + using (var graph = new FuncGraph(func_name)) + { + graph.as_default(); + var input = tf.placeholder(dtype); + var output = func(input); + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + _handle = graph.ToGraph(opers, + new Operation[] { input }, + new Operation[] { output }, + null); + + c_api.TFE_ContextAddFunction(tf.Context.Handle, _handle, tf.Status.Handle); + } + + tf.enable_eager_execution(); + } + + public Tensor Execute(Tensor arg) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + Name, + new[] { arg }, + null, + 1); + return result[0]; + } + + public void Dispose() + { + c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); + } + } +} diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs index 9e800c56..bf93ae74 100644 --- a/src/TensorFlowNET.Core/Functions/c_api.function.cs +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -21,6 +21,9 @@ namespace Tensorflow { public partial class c_api { + [DllImport(TensorFlowLibName)] + public static extern void TF_DeleteFunction(IntPtr handle); + /// /// Write out a serialized representation of `func` (as a FunctionDef protocol /// message) to `output_func_def` (allocated by TF_NewBuffer()). diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index dc74bc0a..3da3b267 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -1,7 +1,9 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.InteropServices; using System.Text; +using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow.Graphs diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 74b203df..b20ca7f2 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Framework.Models; +using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow @@ -419,7 +420,7 @@ namespace Tensorflow /// /// /// - public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes, + public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes, bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null) { if (tf.Context.executing_eagerly()) @@ -428,7 +429,7 @@ namespace Tensorflow "MapDataset", name, null, dataset, new Tensor[0], - "f", "MapDataset", + "f", f, "output_types", output_types, "output_shapes", output_shapes, "use_inter_op_parallelism", use_inter_op_parallelism, diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index af1a91fa..f8bd5475 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -118,17 +118,17 @@ namespace TensorFlowNET.UnitTest.Dataset } } - [TestMethod, Ignore] + [TestMethod] public void Map() { long value = 0; - var dataset = tf.data.Dataset.range(3); - var dataset1 = dataset.map(x => x); + var dataset = tf.data.Dataset.range(0, 2); + dataset = dataset.map(x => x + 10); foreach (var item in dataset) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value + 10, (long)item.Item1); value++; } } diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs b/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs index 24766453..d5adfdb0 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs @@ -2,6 +2,7 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using Tensorflow; +using Tensorflow.Functions; using Tensorflow.Util; using Buffer = Tensorflow.Buffer;