|
|
|
@@ -0,0 +1,59 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using System.Text; |
|
|
|
using Tensorflow.Graphs; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow.Functions |
|
|
|
{ |
|
|
|
/// <summary> |
|
|
|
/// |
|
|
|
/// </summary> |
|
|
|
public class ConcreteFunction : IDisposable |
|
|
|
{ |
|
|
|
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); |
|
|
|
IntPtr _handle; |
|
|
|
|
|
|
|
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) |
|
|
|
{ |
|
|
|
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; |
|
|
|
|
|
|
|
tf.compat.v1.disable_eager_execution(); |
|
|
|
|
|
|
|
// IntPtr func_handle; |
|
|
|
using (var graph = new FuncGraph(func_name)) |
|
|
|
{ |
|
|
|
graph.as_default(); |
|
|
|
var input = tf.placeholder(dtype); |
|
|
|
var output = func(input); |
|
|
|
|
|
|
|
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); |
|
|
|
_handle = graph.ToGraph(opers, |
|
|
|
new Operation[] { input }, |
|
|
|
new Operation[] { output }, |
|
|
|
null); |
|
|
|
|
|
|
|
c_api.TFE_ContextAddFunction(tf.Context.Handle, _handle, tf.Status.Handle); |
|
|
|
} |
|
|
|
|
|
|
|
tf.enable_eager_execution(); |
|
|
|
} |
|
|
|
|
|
|
|
public Tensor Execute(Tensor arg) |
|
|
|
{ |
|
|
|
var result = tf.Runner.TFE_Execute(tf.Context, |
|
|
|
tf.Context.DeviceName, |
|
|
|
Name, |
|
|
|
new[] { arg }, |
|
|
|
null, |
|
|
|
1); |
|
|
|
return result[0]; |
|
|
|
} |
|
|
|
|
|
|
|
public void Dispose() |
|
|
|
{ |
|
|
|
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); |
|
|
|
} |
|
|
|
} |
|
|
|
} |