| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| @@ -6,14 +7,14 @@ namespace Tensorflow.Graphs | |||
| { | |||
| public class AutoGraph | |||
| { | |||
| public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func) | |||
| public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32) | |||
| { | |||
| string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||
| var graph = new FuncGraph(func_name); | |||
| graph.as_default(); | |||
| var input = tf.placeholder(tf.int32); | |||
| var input = tf.placeholder(dtype); | |||
| var output = func(input); | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| @@ -26,25 +27,33 @@ namespace Tensorflow.Graphs | |||
| return (Tensor input) => | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| func_name, | |||
| new[] { input }, | |||
| null, | |||
| 1); | |||
| return result[0]; | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| func_name, | |||
| new[] { input }, | |||
| null, | |||
| 1); | |||
| return result[0]; | |||
| } | |||
| using (var s = tf.Session(input.graph)) | |||
| { | |||
| var output = func(input); | |||
| return output; | |||
| } | |||
| }; | |||
| } | |||
| public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func) | |||
| public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func, params TF_DataType[] dtypes) | |||
| { | |||
| string func_name = $"{func.Method.Name}_{ops.uid_function()}"; | |||
| var graph = new FuncGraph(func_name); | |||
| graph.as_default(); | |||
| var input1 = tf.placeholder(tf.int32); | |||
| var input2 = tf.placeholder(tf.int32); | |||
| var input1 = tf.placeholder(dtypes.Length >= 1 ? dtypes[0] : tf.int32); | |||
| var input2 = tf.placeholder(dtypes.Length >= 2 ? dtypes[1] : tf.int32); | |||
| var output = func(input1, input2); | |||
| var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
| @@ -56,13 +65,22 @@ namespace Tensorflow.Graphs | |||
| return (Tensor a, Tensor b) => | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| if (tf.executing_eagerly()) | |||
| { | |||
| var result = tf.Runner.TFE_Execute(tf.Context, | |||
| tf.Context.DeviceName, | |||
| func_name, | |||
| new[] { a, b }, | |||
| null, | |||
| 1); | |||
| return result[0]; | |||
| return result[0]; | |||
| } | |||
| using (var s = tf.Session(a.graph)) | |||
| { | |||
| Debug.Assert(a.graph == b.graph); | |||
| var output = func(a, b); | |||
| return output; | |||
| } | |||
| }; | |||
| } | |||
| } | |||