diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index efb6b0fc..e1cce1b0 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -162,6 +162,13 @@ namespace Tensorflow.Contexts return c_api.TFE_ContextHasFunction(_handle, name); } + public void add_function_def(FunctionDef fdef) + { + ensure_initialized(); + var fdef_string = fdef.ToString(); + c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length); + } + public void restore_mode() { context_switches.Pop(); diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 92d5b2a4..fedc02cb 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -358,7 +358,7 @@ namespace Tensorflow.Eager break; case TF_AttrType.TF_ATTR_FUNC: if (value is ConcreteFunction func) - c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); + c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); else throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); break; diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 6930b0c7..e8746c1b 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -30,6 +30,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size); + [DllImport(TensorFlowLibName)] public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.cs b/src/TensorFlowNET.Core/Framework/c_api_util.cs index 9cfbf0d0..e21c3b01 100644 --- a/src/TensorFlowNET.Core/Framework/c_api_util.cs +++ b/src/TensorFlowNET.Core/Framework/c_api_util.cs @@ -111,7 +111,17 @@ namespace Tensorflow public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); - public static Buffer tf_buffer(byte[] data) => new Buffer(data); + public static Buffer tf_buffer(byte[] data = null) + { + if(data is not null) + { + return new Buffer(data); ; + } + else + { + return new Buffer(); + } + } public static IEnumerable new_tf_operations(Graph graph) { diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 23c669b3..9abcc61c 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -15,11 +15,13 @@ namespace Tensorflow.Functions { protected IEnumerable _captured_inputs; internal FuncGraph func_graph; + protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; + protected Dictionary _attrs; internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; - public string Name => func_graph?.FuncName; + public string Name => _delayed_rewrite_functions.forward().Name; public Tensor[] Outputs; public Type ReturnType; @@ -31,6 +33,8 @@ namespace Tensorflow.Functions { func_graph = new FuncGraph(name); _captured_inputs = func_graph.external_captures; + _attrs= new Dictionary(); + _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); } public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) @@ -38,7 +42,9 @@ namespace Tensorflow.Functions func_graph = graph; _captured_inputs = func_graph.external_captures; - ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); + //ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); + _attrs = attrs; + _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); } public ConcreteFunction(Func func, TF_DataType dtype) @@ -57,6 +63,8 @@ namespace Tensorflow.Functions null); func_graph.Exit(); _captured_inputs = func_graph.external_captures; + _attrs = new Dictionary(); + _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); } public ConcreteFunction(Func func, TF_DataType dtype) @@ -78,6 +86,8 @@ namespace Tensorflow.Functions null); func_graph.Exit(); _captured_inputs = func_graph.external_captures; + _attrs = new Dictionary(); + _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); } /*public ConcreteFunction(Func func, @@ -176,7 +186,7 @@ namespace Tensorflow.Functions { g = ops.get_default_graph(); } - // TODO(Rinne); complete it with `_delayed_rewrite_functions`. + _delayed_rewrite_functions.forward().AddToGraph(g); } public void SetExternalCaptures(IEnumerable captures) diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index bfb8aa71..40b61511 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Contexts; using Tensorflow.Graphs; using static Tensorflow.Binding; @@ -11,9 +12,20 @@ namespace Tensorflow.Functions public class EagerDefinedFunction { public int _num_outputs; - public string Name => _func_graph.FuncName; - FuncGraph _func_graph; + FunctionDef _definition; + public string Name => _func_graph.FuncName; + public FunctionDef Definition + { + get + { + if(_definition is null) + { + _definition = _get_definition(); + } + return _definition; + } + } public EagerDefinedFunction(string name, FuncGraph graph, Tensors inputs, Tensors outputs, Dictionary attrs) @@ -46,5 +58,39 @@ namespace Tensorflow.Functions return results; } + + public void AddToGraph(Graph g = null) + { + if(g is null && tf.Context.executing_eagerly()) + { + var ctx = tf.Context; + if (!ctx.has_function(this.Name)) + { + ctx.add_function_def(Definition); + } + } + else + { + if (!g.IsFunction(Name)) + { + g.AddFunction(this); + } + foreach(var f in _func_graph.Functions.Values) + { + if (!g.IsFunction(f.Name)) + { + g.AddFunction(f); + } + } + } + } + + private FunctionDef _get_definition() + { + var buffer = c_api_util.tf_buffer(); + // TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef + var proto_data = c_api.TF_GetBuffer(buffer); + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Core/Functions/monomorphic_function.cs b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs new file mode 100644 index 00000000..df8b6d4e --- /dev/null +++ b/src/TensorFlowNET.Core/Functions/monomorphic_function.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Graphs; + +namespace Tensorflow.Functions +{ + public class DelayedRewriteGradientFunctions + { + static readonly string _INFERENCE_PREFIX = "__inference_"; + static readonly string _BACKWARD_PREFIX = "__backward_"; + static readonly string _FORWARD_PREFIX = "__forward_"; + FuncGraph _func_graph; + EagerDefinedFunction _inference_function; + Dictionary _attrs; + int _num_inference_outputs; + public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary attrs) + { + _func_graph= func_graph; + _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), + _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); + _attrs = attrs; + _num_inference_outputs = _func_graph.Outputs.Length; + } + + public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null) + { + if(input_tangents is not null) + { + throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + + $"a class that does not support forwardprop."); + } + return _inference_function; + } + + private static string _inference_name(string name) + { + return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index 9fe49da2..ffdac931 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -22,7 +22,6 @@ namespace Tensorflow.Graphs public override void OnEntry(MethodExecutionArgs args) { - File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); // TODO: func_name can be cache in FullName + Args func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 333380c4..b086907e 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -15,6 +15,7 @@ public class FuncGraph : Graph, IDisposable public Tensors Inputs { get; set; } = new Tensors(); public Tensors Outputs { get; set; } = new Tensors(); + public string Name { get; set; } public Dictionary Attrs { get; set; } Dictionary _captures @@ -39,7 +40,7 @@ public class FuncGraph : Graph, IDisposable outer_graph = ops.get_default_graph(); while (outer_graph.building_function) outer_graph = outer_graph.OuterGraph; - _graph_key = name; + _graph_key = Name = name; building_function = true; } @@ -48,7 +49,7 @@ public class FuncGraph : Graph, IDisposable outer_graph = ops.get_default_graph(); while (outer_graph.building_function) outer_graph = outer_graph.OuterGraph; - _graph_key = name; + _graph_key = Name = name; building_function = true; Attrs = attrs; // Will to test if FuncGraph has memory leak diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index fccc763e..cf38d6b1 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -19,6 +19,8 @@ using System.Collections; using System.Collections.Generic; using System.Collections.Specialized; using System.Linq; +using Tensorflow.Framework; +using Tensorflow.Functions; using static Tensorflow.Binding; namespace Tensorflow @@ -85,6 +87,12 @@ namespace Tensorflow private int _next_id_counter; private List _unfetchable_ops = new List(); private List _unfeedable_tensors = new List(); + private Dictionary _functions = new(); + private VersionDef _graph_def_versions = new VersionDef() + { + Producer = versions.GRAPH_DEF_VERSION, + MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER + }; public string _name_stack = ""; protected string _graph_key; @@ -120,6 +128,7 @@ namespace Tensorflow protected Graph outer_graph; public Graph OuterGraph => outer_graph; + public Dictionary Functions => _functions; public Graph() { @@ -148,8 +157,23 @@ namespace Tensorflow public bool IsFunction(string name) { - // TODO(Rinne): deal with `_functions`. - throw new NotImplementedException(); + return _functions.ContainsKey(tf.compat.as_str(name)); + } + + public void AddFunction(EagerDefinedFunction function) + { + _check_not_finalized(); + + var name = function.Name; + + // TODO(Rinne): deal with c_graph + + _functions[tf.compat.as_str(name)] = function; + + if(_graph_def_versions.MinConsumer < 12) + { + _graph_def_versions.MinConsumer = 12; + } } private Tensor _as_graph_element(object obj) diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs index 757e8b7f..25697c6e 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -77,11 +77,8 @@ namespace Tensorflow.Training.Saving.SavedModel } Dictionary loaded_gradients = new(); - int aa = 0; - var temp = _sort_function_defs(library, function_deps); - foreach (var fdef in temp) + foreach (var fdef in _sort_function_defs(library, function_deps)) { - aa++; var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) @@ -191,7 +188,7 @@ namespace Tensorflow.Training.Saving.SavedModel { if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") { - var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; + var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; // TODO(Rinne): deal with `op._gradient_function`. } string gradient_op_type = null; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 7441e4a4..3505da93 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -375,6 +375,11 @@ namespace Tensorflow // Re-create everything. foreach (var (node_id, proto) in _iter_all_nodes()) { + if(node_id == 45) + { + // TODelete + Console.WriteLine(); + } if (nodes.ContainsKey(node_id)) { continue; @@ -469,7 +474,7 @@ namespace Tensorflow } } - private void _setup_function_captures(string concrete_function_name, Dictionary, Trackable> nodes) + private void _setup_function_captures(string concrete_function_name, IDictionary, Trackable> nodes) { if (_restored_concrete_functions.Contains(concrete_function_name)) { @@ -572,7 +577,7 @@ namespace Tensorflow { SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), - SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), + SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies), SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), _ => throw new NotImplementedException() @@ -644,7 +649,7 @@ namespace Tensorflow } private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, - Dictionary, Trackable> dependencies) + IDictionary, Trackable> dependencies) { var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); _setup_function_captures(proto.ConcreteFunctionName, dependencies); diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index a24ce727..74f610c8 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -62,7 +62,7 @@ public class SequentialModelLoad [TestMethod] public void Temp() { - var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); + var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); model.summary(); } }