Browse Source

Add ConcreteFunction to support dataset map.

v0.20-tensorflow2.3
Oceania2018 Haiping 5 years ago
parent
commit
6c72af1503
9 changed files with 84 additions and 12 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +6
    -4
      src/TensorFlowNET.Core/Data/MapDataset.cs
  3. +5
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +59
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  5. +3
    -0
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  7. +3
    -2
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  8. +4
    -4
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
  9. +1
    -0
      test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs

+ 1
- 1
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -52,7 +52,7 @@ namespace Tensorflow

public IDatasetV2 map(Func<Tensor, Tensor> map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = false,
bool preserve_cardinality = true,
bool use_legacy_function = false)
=> new MapDataset(this,
map_func,


+ 6
- 4
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -1,6 +1,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -15,12 +19,10 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
foreach(var input in input_dataset)
{
var data = map_func(input.Item1);
}
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);

variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
func,
output_types,
output_shapes);
}


+ 5
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -6,6 +6,7 @@ using static Tensorflow.Binding;
using Tensorflow.Util;
using System.Runtime.InteropServices;
using Tensorflow.Contexts;
using Tensorflow.Functions;

namespace Tensorflow.Eager
{
@@ -385,7 +386,10 @@ namespace Tensorflow.Eager
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length);
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;
default:
throw new NotImplementedException($"SetOpAttrScalar for {type}");


+ 59
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -0,0 +1,59 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
{
/// <summary>
///
/// </summary>
public class ConcreteFunction : IDisposable
{
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
IntPtr _handle;

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

tf.compat.v1.disable_eager_execution();

// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
{
graph.as_default();
var input = tf.placeholder(dtype);
var output = func(input);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers,
new Operation[] { input },
new Operation[] { output },
null);

c_api.TFE_ContextAddFunction(tf.Context.Handle, _handle, tf.Status.Handle);
}

tf.enable_eager_execution();
}

public Tensor Execute(Tensor arg)
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
Name,
new[] { arg },
null,
1);
return result[0];
}

public void Dispose()
{
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -21,6 +21,9 @@ namespace Tensorflow
{
public partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern void TF_DeleteFunction(IntPtr handle);

/// <summary>
/// Write out a serialized representation of `func` (as a FunctionDef protocol
/// message) to `output_func_def` (allocated by TF_NewBuffer()).


+ 2
- 0
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -1,7 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs


+ 3
- 2
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -419,7 +420,7 @@ namespace Tensorflow
/// <param name="iterator"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes,
public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes,
bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null)
{
if (tf.Context.executing_eagerly())
@@ -428,7 +429,7 @@ namespace Tensorflow
"MapDataset", name,
null,
dataset, new Tensor[0],
"f", "MapDataset",
"f", f,
"output_types", output_types,
"output_shapes", output_shapes,
"use_inter_op_parallelism", use_inter_op_parallelism,


+ 4
- 4
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -118,17 +118,17 @@ namespace TensorFlowNET.UnitTest.Dataset
}
}

[TestMethod, Ignore]
[TestMethod]
public void Map()
{
long value = 0;

var dataset = tf.data.Dataset.range(3);
var dataset1 = dataset.map(x => x);
var dataset = tf.data.Dataset.range(0, 2);
dataset = dataset.map(x => x + 10);

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value + 10, (long)item.Item1);
value++;
}
}


+ 1
- 0
test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs View File

@@ -2,6 +2,7 @@
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Tensorflow;
using Tensorflow.Functions;
using Tensorflow.Util;
using Buffer = Tensorflow.Buffer;



Loading…
Cancel
Save