Browse Source

Support autograph.to_graph under graph mode.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
f2e41a1791
1 changed files with 32 additions and 14 deletions
  1. +32
    -14
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs

+ 32
- 14
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;

@@ -6,14 +7,14 @@ namespace Tensorflow.Graphs
{
public class AutoGraph
{
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32)
{
string func_name = $"{func.Method.Name}_{ops.uid_function()}";

var graph = new FuncGraph(func_name);
graph.as_default();

var input = tf.placeholder(tf.int32);
var input = tf.placeholder(dtype);
var output = func(input);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
@@ -26,25 +27,33 @@ namespace Tensorflow.Graphs

return (Tensor input) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
}
using (var s = tf.Session(input.graph))
{
var output = func(input);
return output;
}
};
}

public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes)
{
string func_name = $"{func.Method.Name}_{ops.uid_function()}";

var graph = new FuncGraph(func_name);
graph.as_default();

var input1 = tf.placeholder(tf.int32);
var input2 = tf.placeholder(tf.int32);
var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32);
var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32);
var output = func(input1, input2);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
@@ -56,13 +65,22 @@ namespace Tensorflow.Graphs
return (Tensor a, Tensor b) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
if (tf.executing_eagerly())
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { a, b },
null,
1);
return result[0];
return result[0];
}
using (var s = tf.Session(a.graph))
{
Debug.Assert(a.graph == b.graph);
var output = func(a, b);
return output;
}
};
}
}


Loading…
Cancel
Save