|
|
|
@@ -6,6 +6,7 @@ using Tensorflow.Eager; |
|
|
|
using Tensorflow.Framework.Models; |
|
|
|
using Tensorflow.Graphs; |
|
|
|
using Tensorflow.Train; |
|
|
|
using Tensorflow.Util; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
namespace Tensorflow.Functions |
|
|
|
@@ -21,6 +22,7 @@ namespace Tensorflow.Functions |
|
|
|
protected Dictionary<string, string> _attrs; |
|
|
|
protected FunctionSpec _function_spec; |
|
|
|
protected FunctionSpec _pre_initialized_function_spec = null; |
|
|
|
protected EagerDefinedFunction _inference_function; |
|
|
|
internal ForwardBackwardCall forward_backward; |
|
|
|
public Tensor[] Inputs => func_graph.Inputs; |
|
|
|
public Tensor[] CapturedInputs => func_graph.external_captures; |
|
|
|
@@ -39,6 +41,7 @@ namespace Tensorflow.Functions |
|
|
|
_captured_inputs = func_graph.external_captures; |
|
|
|
_attrs= new Dictionary<string, string>(); |
|
|
|
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); |
|
|
|
_inference_function = _delayed_rewrite_functions.Forward(); |
|
|
|
} |
|
|
|
|
|
|
|
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) |
|
|
|
@@ -49,6 +52,7 @@ namespace Tensorflow.Functions |
|
|
|
//ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); |
|
|
|
_attrs = attrs; |
|
|
|
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); |
|
|
|
_inference_function = _delayed_rewrite_functions.Forward(); |
|
|
|
} |
|
|
|
|
|
|
|
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) |
|
|
|
@@ -69,6 +73,7 @@ namespace Tensorflow.Functions |
|
|
|
_captured_inputs = func_graph.external_captures; |
|
|
|
_attrs = new Dictionary<string, string>(); |
|
|
|
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); |
|
|
|
_inference_function = _delayed_rewrite_functions.Forward(); |
|
|
|
} |
|
|
|
|
|
|
|
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) |
|
|
|
@@ -92,6 +97,7 @@ namespace Tensorflow.Functions |
|
|
|
_captured_inputs = func_graph.external_captures; |
|
|
|
_attrs = new Dictionary<string, string>(); |
|
|
|
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); |
|
|
|
_inference_function = _delayed_rewrite_functions.Forward(); |
|
|
|
} |
|
|
|
|
|
|
|
/*public ConcreteFunction(Func<Tensors, Tensors> func, |
|
|
|
@@ -154,9 +160,10 @@ namespace Tensorflow.Functions |
|
|
|
{ |
|
|
|
tensor_inputs.Add(arg); |
|
|
|
// If we're graph building, shape inference is on. |
|
|
|
if (!executing_eagerly) |
|
|
|
{ |
|
|
|
} |
|
|
|
} |
|
|
|
if (!executing_eagerly) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
tensor_inputs.AddRange(captured_inputs); |
|
|
|
|
|
|
|
@@ -166,12 +173,13 @@ namespace Tensorflow.Functions |
|
|
|
// No tape is watching; skip to running the function. |
|
|
|
if (possible_gradient_type == 0 && executing_eagerly) |
|
|
|
{ |
|
|
|
var attrs = new object[] |
|
|
|
{ |
|
|
|
"executor_type", "", |
|
|
|
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() |
|
|
|
}; |
|
|
|
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); |
|
|
|
return _build_call_outputs(_inference_function.Call(args)); |
|
|
|
//var attrs = new object[] |
|
|
|
//{ |
|
|
|
// "executor_type", "", |
|
|
|
// "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() |
|
|
|
//}; |
|
|
|
//return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); |
|
|
|
} |
|
|
|
|
|
|
|
if (forward_backward == null) |
|
|
|
@@ -184,10 +192,11 @@ namespace Tensorflow.Functions |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
// TODO(Rinne): add `default_graph._override_gradient_function`. |
|
|
|
flat_outputs = forward_function.Call(args_with_tangents); |
|
|
|
} |
|
|
|
forward_backward.Record(flat_outputs); |
|
|
|
return flat_outputs; |
|
|
|
return _build_call_outputs(flat_outputs); |
|
|
|
} |
|
|
|
|
|
|
|
public void AddTograph(Graph? g = null) |
|
|
|
@@ -262,6 +271,13 @@ namespace Tensorflow.Functions |
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
private Tensors _build_call_outputs(Tensors result) |
|
|
|
{ |
|
|
|
// TODO(Rinne): dwal with `func_graph.structured_outputs` |
|
|
|
|
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
public override string ToString() |
|
|
|
=> Name; |
|
|
|
} |
|
|
|
|