From f2e41a17916b25ff6fd3baf20ed6fc0d651fb4c2 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Thu, 2 Feb 2023 17:34:50 +0800 Subject: [PATCH] Support autograph.to_graph under graph mode. --- src/TensorFlowNET.Core/Graphs/AutoGraph.cs | 46 +++++++++++++++------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index 2af1a372..ceeca8ab 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Linq; using static Tensorflow.Binding; @@ -6,14 +7,14 @@ namespace Tensorflow.Graphs { public class AutoGraph { - public Func to_graph(Func func) + public Func to_graph(Func func, TF_DataType dtype = TF_DataType.TF_INT32) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input = tf.placeholder(tf.int32); + var input = tf.placeholder(dtype); var output = func(input); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -26,25 +27,33 @@ namespace Tensorflow.Graphs return (Tensor input) => { - var result = tf.Runner.TFE_Execute(tf.Context, - tf.Context.DeviceName, - func_name, - new[] { input }, - null, - 1); - return result[0]; + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, + tf.Context.DeviceName, + func_name, + new[] { input }, + null, + 1); + return result[0]; + } + using (var s = tf.Session(input.graph)) + { + var output = func(input); + return output; + } }; } - public Func to_graph(Func func) + public Func to_graph(Func func, params TF_DataType[] dtypes) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; var graph = new FuncGraph(func_name); graph.as_default(); - var input1 = tf.placeholder(tf.int32); - var input2 = tf.placeholder(tf.int32); + var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); + var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); var output = func(input1, input2); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); @@ -56,13 +65,22 @@ namespace Tensorflow.Graphs return (Tensor a, Tensor b) => { - var result = tf.Runner.TFE_Execute(tf.Context, + if (tf.executing_eagerly()) + { + var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { a, b }, null, 1); - return result[0]; + return result[0]; + } + using (var s = tf.Session(a.graph)) + { + Debug.Assert(a.graph == b.graph); + var output = func(a, b); + return output; + } }; } }