| @@ -52,7 +52,7 @@ namespace Tensorflow | |||||
| public IDatasetV2 map(Func<Tensor, Tensor> map_func, | public IDatasetV2 map(Func<Tensor, Tensor> map_func, | ||||
| bool use_inter_op_parallelism = true, | bool use_inter_op_parallelism = true, | ||||
| bool preserve_cardinality = false, | |||||
| bool preserve_cardinality = true, | |||||
| bool use_legacy_function = false) | bool use_legacy_function = false) | ||||
| => new MapDataset(this, | => new MapDataset(this, | ||||
| map_func, | map_func, | ||||
| @@ -1,6 +1,10 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -15,12 +19,10 @@ namespace Tensorflow | |||||
| bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
| bool use_legacy_function = false) : base(input_dataset) | bool use_legacy_function = false) : base(input_dataset) | ||||
| { | { | ||||
| foreach(var input in input_dataset) | |||||
| { | |||||
| var data = map_func(input.Item1); | |||||
| } | |||||
| var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); | |||||
| variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | ||||
| func, | |||||
| output_types, | output_types, | ||||
| output_shapes); | output_shapes); | ||||
| } | } | ||||
| @@ -6,6 +6,7 @@ using static Tensorflow.Binding; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
| using Tensorflow.Functions; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| @@ -385,7 +386,10 @@ namespace Tensorflow.Eager | |||||
| status.Check(true); | status.Check(true); | ||||
| break; | break; | ||||
| case TF_AttrType.TF_ATTR_FUNC: | case TF_AttrType.TF_ATTR_FUNC: | ||||
| c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length); | |||||
| if (value is ConcreteFunction func) | |||||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); | |||||
| else | |||||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||||
| break; | break; | ||||
| default: | default: | ||||
| throw new NotImplementedException($"SetOpAttrScalar for {type}"); | throw new NotImplementedException($"SetOpAttrScalar for {type}"); | ||||
| @@ -0,0 +1,59 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| public class ConcreteFunction : IDisposable | |||||
| { | |||||
| public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
| IntPtr _handle; | |||||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||||
| { | |||||
| string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||||
| tf.compat.v1.disable_eager_execution(); | |||||
| // IntPtr func_handle; | |||||
| using (var graph = new FuncGraph(func_name)) | |||||
| { | |||||
| graph.as_default(); | |||||
| var input = tf.placeholder(dtype); | |||||
| var output = func(input); | |||||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
| _handle = graph.ToGraph(opers, | |||||
| new Operation[] { input }, | |||||
| new Operation[] { output }, | |||||
| null); | |||||
| c_api.TFE_ContextAddFunction(tf.Context.Handle, _handle, tf.Status.Handle); | |||||
| } | |||||
| tf.enable_eager_execution(); | |||||
| } | |||||
| public Tensor Execute(Tensor arg) | |||||
| { | |||||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||||
| tf.Context.DeviceName, | |||||
| Name, | |||||
| new[] { arg }, | |||||
| null, | |||||
| 1); | |||||
| return result[0]; | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -21,6 +21,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class c_api | public partial class c_api | ||||
| { | { | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_DeleteFunction(IntPtr handle); | |||||
| /// <summary> | /// <summary> | ||||
| /// Write out a serialized representation of `func` (as a FunctionDef protocol | /// Write out a serialized representation of `func` (as a FunctionDef protocol | ||||
| /// message) to `output_func_def` (allocated by TF_NewBuffer()). | /// message) to `output_func_def` (allocated by TF_NewBuffer()). | ||||
| @@ -1,7 +1,9 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Graphs | namespace Tensorflow.Graphs | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -419,7 +420,7 @@ namespace Tensorflow | |||||
| /// <param name="iterator"></param> | /// <param name="iterator"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null) | bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| @@ -428,7 +429,7 @@ namespace Tensorflow | |||||
| "MapDataset", name, | "MapDataset", name, | ||||
| null, | null, | ||||
| dataset, new Tensor[0], | dataset, new Tensor[0], | ||||
| "f", "MapDataset", | |||||
| "f", f, | |||||
| "output_types", output_types, | "output_types", output_types, | ||||
| "output_shapes", output_shapes, | "output_shapes", output_shapes, | ||||
| "use_inter_op_parallelism", use_inter_op_parallelism, | "use_inter_op_parallelism", use_inter_op_parallelism, | ||||
| @@ -118,17 +118,17 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| } | } | ||||
| } | } | ||||
| [TestMethod, Ignore] | |||||
| [TestMethod] | |||||
| public void Map() | public void Map() | ||||
| { | { | ||||
| long value = 0; | long value = 0; | ||||
| var dataset = tf.data.Dataset.range(3); | |||||
| var dataset1 = dataset.map(x => x); | |||||
| var dataset = tf.data.Dataset.range(0, 2); | |||||
| dataset = dataset.map(x => x + 10); | |||||
| foreach (var item in dataset) | foreach (var item in dataset) | ||||
| { | { | ||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| Assert.AreEqual(value + 10, (long)item.Item1); | |||||
| value++; | value++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Diagnostics.CodeAnalysis; | using System.Diagnostics.CodeAnalysis; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Functions; | |||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||