| @@ -162,6 +162,13 @@ namespace Tensorflow.Contexts | |||
| return c_api.TFE_ContextHasFunction(_handle, name); | |||
| } | |||
| public void add_function_def(FunctionDef fdef) | |||
| { | |||
| ensure_initialized(); | |||
| var fdef_string = fdef.ToString(); | |||
| c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length); | |||
| } | |||
| public void restore_mode() | |||
| { | |||
| context_switches.Pop(); | |||
| @@ -358,7 +358,7 @@ namespace Tensorflow.Eager | |||
| break; | |||
| case TF_AttrType.TF_ATTR_FUNC: | |||
| if (value is ConcreteFunction func) | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
| else | |||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
| break; | |||
| @@ -30,6 +30,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | |||
| @@ -111,7 +111,17 @@ namespace Tensorflow | |||
| public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | |||
| public static Buffer tf_buffer(byte[] data) => new Buffer(data); | |||
| public static Buffer tf_buffer(byte[] data = null) | |||
| { | |||
| if(data is not null) | |||
| { | |||
| return new Buffer(data); ; | |||
| } | |||
| else | |||
| { | |||
| return new Buffer(); | |||
| } | |||
| } | |||
| public static IEnumerable<Operation> new_tf_operations(Graph graph) | |||
| { | |||
| @@ -15,11 +15,13 @@ namespace Tensorflow.Functions | |||
| { | |||
| protected IEnumerable<Tensor> _captured_inputs; | |||
| internal FuncGraph func_graph; | |||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
| protected Dictionary<string, string> _attrs; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| public string Name => func_graph?.FuncName; | |||
| public string Name => _delayed_rewrite_functions.forward().Name; | |||
| public Tensor[] Outputs; | |||
| public Type ReturnType; | |||
| @@ -31,6 +33,8 @@ namespace Tensorflow.Functions | |||
| { | |||
| func_graph = new FuncGraph(name); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs= new Dictionary<string, string>(); | |||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
| @@ -38,7 +42,9 @@ namespace Tensorflow.Functions | |||
| func_graph = graph; | |||
| _captured_inputs = func_graph.external_captures; | |||
| ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
| //ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||
| _attrs = attrs; | |||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
| @@ -57,6 +63,8 @@ namespace Tensorflow.Functions | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs = new Dictionary<string, string>(); | |||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||
| } | |||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | |||
| @@ -78,6 +86,8 @@ namespace Tensorflow.Functions | |||
| null); | |||
| func_graph.Exit(); | |||
| _captured_inputs = func_graph.external_captures; | |||
| _attrs = new Dictionary<string, string>(); | |||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||
| } | |||
| /*public ConcreteFunction(Func<Tensors, Tensors> func, | |||
| @@ -176,7 +186,7 @@ namespace Tensorflow.Functions | |||
| { | |||
| g = ops.get_default_graph(); | |||
| } | |||
| // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||
| _delayed_rewrite_functions.forward().AddToGraph(g); | |||
| } | |||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| @@ -11,9 +12,20 @@ namespace Tensorflow.Functions | |||
| public class EagerDefinedFunction | |||
| { | |||
| public int _num_outputs; | |||
| public string Name => _func_graph.FuncName; | |||
| FuncGraph _func_graph; | |||
| FunctionDef _definition; | |||
| public string Name => _func_graph.FuncName; | |||
| public FunctionDef Definition | |||
| { | |||
| get | |||
| { | |||
| if(_definition is null) | |||
| { | |||
| _definition = _get_definition(); | |||
| } | |||
| return _definition; | |||
| } | |||
| } | |||
| public EagerDefinedFunction(string name, FuncGraph graph, | |||
| Tensors inputs, Tensors outputs, | |||
| Dictionary<string, string> attrs) | |||
| @@ -46,5 +58,39 @@ namespace Tensorflow.Functions | |||
| return results; | |||
| } | |||
| public void AddToGraph(Graph g = null) | |||
| { | |||
| if(g is null && tf.Context.executing_eagerly()) | |||
| { | |||
| var ctx = tf.Context; | |||
| if (!ctx.has_function(this.Name)) | |||
| { | |||
| ctx.add_function_def(Definition); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| if (!g.IsFunction(Name)) | |||
| { | |||
| g.AddFunction(this); | |||
| } | |||
| foreach(var f in _func_graph.Functions.Values) | |||
| { | |||
| if (!g.IsFunction(f.Name)) | |||
| { | |||
| g.AddFunction(f); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| private FunctionDef _get_definition() | |||
| { | |||
| var buffer = c_api_util.tf_buffer(); | |||
| // TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef | |||
| var proto_data = c_api.TF_GetBuffer(buffer); | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,41 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| namespace Tensorflow.Functions | |||
| { | |||
| public class DelayedRewriteGradientFunctions | |||
| { | |||
| static readonly string _INFERENCE_PREFIX = "__inference_"; | |||
| static readonly string _BACKWARD_PREFIX = "__backward_"; | |||
| static readonly string _FORWARD_PREFIX = "__forward_"; | |||
| FuncGraph _func_graph; | |||
| EagerDefinedFunction _inference_function; | |||
| Dictionary<string, string> _attrs; | |||
| int _num_inference_outputs; | |||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) | |||
| { | |||
| _func_graph= func_graph; | |||
| _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), | |||
| _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); | |||
| _attrs = attrs; | |||
| _num_inference_outputs = _func_graph.Outputs.Length; | |||
| } | |||
| public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null) | |||
| { | |||
| if(input_tangents is not null) | |||
| { | |||
| throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + | |||
| $"a class that does not support forwardprop."); | |||
| } | |||
| return _inference_function; | |||
| } | |||
| private static string _inference_name(string name) | |||
| { | |||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||
| } | |||
| } | |||
| } | |||
| @@ -22,7 +22,6 @@ namespace Tensorflow.Graphs | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); | |||
| // TODO: func_name can be cache in FullName + Args | |||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | |||
| @@ -15,6 +15,7 @@ public class FuncGraph : Graph, IDisposable | |||
| public Tensors Inputs { get; set; } = new Tensors(); | |||
| public Tensors Outputs { get; set; } = new Tensors(); | |||
| public string Name { get; set; } | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| @@ -39,7 +40,7 @@ public class FuncGraph : Graph, IDisposable | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| _graph_key = Name = name; | |||
| building_function = true; | |||
| } | |||
| @@ -48,7 +49,7 @@ public class FuncGraph : Graph, IDisposable | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| _graph_key = Name = name; | |||
| building_function = true; | |||
| Attrs = attrs; | |||
| // Will to test if FuncGraph has memory leak | |||
| @@ -19,6 +19,8 @@ using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Collections.Specialized; | |||
| using System.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -85,6 +87,12 @@ namespace Tensorflow | |||
| private int _next_id_counter; | |||
| private List<Operation> _unfetchable_ops = new List<Operation>(); | |||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | |||
| private Dictionary<string, EagerDefinedFunction> _functions = new(); | |||
| private VersionDef _graph_def_versions = new VersionDef() | |||
| { | |||
| Producer = versions.GRAPH_DEF_VERSION, | |||
| MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||
| }; | |||
| public string _name_stack = ""; | |||
| protected string _graph_key; | |||
| @@ -120,6 +128,7 @@ namespace Tensorflow | |||
| protected Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
| public Graph() | |||
| { | |||
| @@ -148,8 +157,23 @@ namespace Tensorflow | |||
| public bool IsFunction(string name) | |||
| { | |||
| // TODO(Rinne): deal with `_functions`. | |||
| throw new NotImplementedException(); | |||
| return _functions.ContainsKey(tf.compat.as_str(name)); | |||
| } | |||
| public void AddFunction(EagerDefinedFunction function) | |||
| { | |||
| _check_not_finalized(); | |||
| var name = function.Name; | |||
| // TODO(Rinne): deal with c_graph | |||
| _functions[tf.compat.as_str(name)] = function; | |||
| if(_graph_def_versions.MinConsumer < 12) | |||
| { | |||
| _graph_def_versions.MinConsumer = 12; | |||
| } | |||
| } | |||
| private Tensor _as_graph_element(object obj) | |||
| @@ -77,11 +77,8 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| } | |||
| Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||
| int aa = 0; | |||
| var temp = _sort_function_defs(library, function_deps); | |||
| foreach (var fdef in temp) | |||
| foreach (var fdef in _sort_function_defs(library, function_deps)) | |||
| { | |||
| aa++; | |||
| var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
| if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
| @@ -191,7 +188,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
| { | |||
| var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; | |||
| var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
| // TODO(Rinne): deal with `op._gradient_function`. | |||
| } | |||
| string gradient_op_type = null; | |||
| @@ -375,6 +375,11 @@ namespace Tensorflow | |||
| // Re-create everything. | |||
| foreach (var (node_id, proto) in _iter_all_nodes()) | |||
| { | |||
| if(node_id == 45) | |||
| { | |||
| // TODelete | |||
| Console.WriteLine(); | |||
| } | |||
| if (nodes.ContainsKey(node_id)) | |||
| { | |||
| continue; | |||
| @@ -469,7 +474,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes) | |||
| private void _setup_function_captures(string concrete_function_name, IDictionary<Maybe<string, int>, Trackable> nodes) | |||
| { | |||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | |||
| { | |||
| @@ -572,7 +577,7 @@ namespace Tensorflow | |||
| { | |||
| SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), | |||
| SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | |||
| SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | |||
| SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies), | |||
| SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | |||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | |||
| _ => throw new NotImplementedException() | |||
| @@ -644,7 +649,7 @@ namespace Tensorflow | |||
| } | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
| IDictionary<Maybe<string, int>, Trackable> dependencies) | |||
| { | |||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||
| @@ -62,7 +62,7 @@ public class SequentialModelLoad | |||
| [TestMethod] | |||
| public void Temp() | |||
| { | |||
| var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); | |||
| var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); | |||
| model.summary(); | |||
| } | |||
| } | |||